< 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960
  1. use anyhow::Context;
  2. use log::{info, warn};
  3. use ordered_float::Float;
  4. use pandora_lib_graph::cytoband::{svg_chromosome, AdditionalRect, RectPosition};
  5. use plotly::{color::Rgb, common::Marker, layout::BarMode, Bar, Layout, Plot, Scatter};
  6. use rand::{thread_rng, Rng};
  7. use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
  8. use serde::{
  9. de::{self, Visitor},
  10. Deserialize, Deserializer, Serialize,
  11. };
  12. use statrs::{
  13. distribution::{Continuous, Discrete},
  14. statistics::Statistics,
  15. };
  16. use std::{
  17. collections::{BTreeMap, HashMap, HashSet},
  18. f64, fmt,
  19. fs::File,
  20. io::{BufRead, BufReader, Write},
  21. str::FromStr,
  22. };
  23. #[derive(Debug, Clone)]
  24. pub struct Count {
  25. pub position: CountRange,
  26. pub n_reads: u32,
  27. pub n_low_mapq: u32,
  28. pub frac_sa: f32,
  29. pub sa_outlier: bool,
  30. pub frac_se: f32,
  31. pub se_outlier: bool,
  32. pub annotation: Vec<CountAnnotation>,
  33. }
  34. impl fmt::Display for Count {
  35. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  36. write!(
  37. f,
  38. "{}\t{}\t{}\t{:.6}\t{}\t{:.6}\t{}",
  39. self.position,
  40. self.n_reads,
  41. self.n_low_mapq,
  42. self.frac_sa,
  43. self.sa_outlier,
  44. self.frac_se,
  45. self.se_outlier
  46. )
  47. }
  48. }
  49. // inclusive 0 based
  50. #[derive(Debug, Clone)]
  51. pub struct CountRange {
  52. pub contig: String,
  53. pub start: u32,
  54. pub end: u32,
  55. }
  56. impl fmt::Display for CountRange {
  57. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  58. write!(f, "{}:{}-{}", self.contig, self.start, self.end)
  59. }
  60. }
  61. #[derive(Debug, Clone, Hash, Eq, PartialEq)]
  62. pub enum CountAnnotation {
  63. MaskedLowMRD,
  64. MaskedQuality,
  65. }
  66. impl<'de> Deserialize<'de> for Count {
  67. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  68. where
  69. D: Deserializer<'de>,
  70. {
  71. struct CountVisitor;
  72. impl<'de> Visitor<'de> for CountVisitor {
  73. type Value = Count;
  74. fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
  75. formatter.write_str("a string in the format 'chr:start-end n_reads n_low_mapq frac_sa sa_outlier frac_se se_outlier'")
  76. }
  77. fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
  78. where
  79. E: de::Error,
  80. {
  81. let parts: Vec<&str> = s.split_whitespace().collect();
  82. if parts.len() != 7 {
  83. return Err(E::custom("incorrect number of fields"));
  84. }
  85. let position_parts: Vec<&str> = parts[0].split(&[':', '-'][..]).collect();
  86. if position_parts.len() != 3 {
  87. return Err(E::custom("incorrect position format"));
  88. }
  89. Ok(Count {
  90. position: CountRange {
  91. contig: position_parts[0].to_string(),
  92. start: u32::from_str(position_parts[1]).map_err(E::custom)?,
  93. end: u32::from_str(position_parts[2]).map_err(E::custom)?,
  94. },
  95. n_reads: u32::from_str(parts[1]).map_err(E::custom)?,
  96. n_low_mapq: u32::from_str(parts[2]).map_err(E::custom)?,
  97. frac_sa: f32::from_str(parts[3]).map_err(E::custom)?,
  98. sa_outlier: bool::from_str(parts[4]).map_err(E::custom)?,
  99. frac_se: f32::from_str(parts[5]).map_err(E::custom)?,
  100. se_outlier: bool::from_str(parts[6]).map_err(E::custom)?,
  101. annotation: Vec::new(),
  102. })
  103. }
  104. }
  105. deserializer.deserialize_str(CountVisitor)
  106. }
  107. }
  108. pub fn read_counts_from_file(filename: &str) -> anyhow::Result<Vec<Count>> {
  109. let file = File::open(filename)?;
  110. let reader = BufReader::new(file);
  111. let mut counts = Vec::new();
  112. for line in reader.lines() {
  113. let line = line?;
  114. let count: Count = serde_json::from_str(&format!("\"{}\"", escape_control_chars(&line)))?;
  115. counts.push(count);
  116. }
  117. Ok(counts)
  118. }
  119. fn escape_control_chars(s: &str) -> String {
  120. s.chars()
  121. .map(|c| {
  122. if c.is_control() {
  123. format!("\\u{:04x}", c as u32)
  124. } else {
  125. c.to_string()
  126. }
  127. })
  128. .collect()
  129. }
  130. #[derive(Debug)]
  131. pub struct Counts {
  132. pub data: HashMap<String, Vec<Count>>,
  133. pub mrd: HashMap<String, Vec<Count>>,
  134. }
  135. impl Counts {
  136. pub fn from_files(paths: Vec<String>) -> Self {
  137. let counts: Vec<Vec<Count>> = paths
  138. .par_iter()
  139. .map(|path| match read_counts_from_file(path) {
  140. Ok(c) => c,
  141. Err(e) => {
  142. warn!("Couldnt load {path}: {e}");
  143. Vec::new()
  144. }
  145. })
  146. .filter(|v| !v.is_empty())
  147. .collect();
  148. let mut data = HashMap::new();
  149. for count in counts {
  150. let contig = count.first().unwrap().position.contig.clone();
  151. data.insert(contig, count);
  152. }
  153. Counts {
  154. data,
  155. mrd: HashMap::new(),
  156. }
  157. }
  158. pub fn mrd_from_files(&mut self, paths: Vec<String>) {
  159. let counts: Vec<Vec<Count>> = paths
  160. .par_iter()
  161. .map(|path| match read_counts_from_file(path) {
  162. Ok(c) => c,
  163. Err(e) => {
  164. warn!("Couldnt load {path}: {e}");
  165. Vec::new()
  166. }
  167. })
  168. .filter(|v| !v.is_empty())
  169. .collect();
  170. let mut data = HashMap::new();
  171. for count in counts {
  172. let contig = count.first().unwrap().position.contig.clone();
  173. data.insert(contig, count);
  174. }
  175. self.mrd = data;
  176. }
  177. pub fn mask_low_mrd(&mut self, contig: &str, min_reads: u32) -> anyhow::Result<()> {
  178. if let (Some(mrd), Some(diag)) = (self.mrd.get(contig), self.data.get_mut(contig)) {
  179. for (m, d) in mrd.iter().zip(diag) {
  180. if m.n_reads < min_reads {
  181. d.annotation.push(CountAnnotation::MaskedLowMRD);
  182. }
  183. }
  184. Ok(())
  185. } else {
  186. anyhow::bail!("No {contig} in both mrd and diag.")
  187. }
  188. }
  189. pub fn mask_low_quality(&mut self, contig: &str, max_ratio: f64) -> anyhow::Result<()> {
  190. if let Some(diag) = self.data.get_mut(contig) {
  191. for d in diag.iter_mut() {
  192. if (d.n_low_mapq as f64 / (d.n_reads + d.n_low_mapq) as f64) > max_ratio {
  193. d.annotation.push(CountAnnotation::MaskedQuality);
  194. }
  195. }
  196. Ok(())
  197. } else {
  198. anyhow::bail!("No {contig} in both mrd and diag.")
  199. }
  200. }
  201. pub fn frequencies(&self, contig: &str) -> anyhow::Result<Vec<(f64, f64)>> {
  202. let data = self.get(contig)?;
  203. let mut frequencies = HashMap::new();
  204. for count in data.iter() {
  205. *frequencies.entry(*count).or_insert(0) += 1;
  206. }
  207. let mut frequencies: Vec<(u32, f64)> =
  208. frequencies.iter().map(|(x, y)| (*x, *y as f64)).collect();
  209. frequencies.sort_by_key(|v| v.0);
  210. Ok(frequencies.iter().map(|(x, y)| (*x as f64, *y)).collect())
  211. }
  212. pub fn percentile(&self, contig: &str, percentile: f64) -> anyhow::Result<u32> {
  213. let mut data = self.get(contig)?;
  214. data.sort_unstable();
  215. let total_count = data.len();
  216. let index = |percentile: f64| -> usize {
  217. (percentile / 100.0 * (total_count - 1) as f64).round() as usize
  218. };
  219. Ok(*data.get(index(percentile)).context("Error in percentile")?)
  220. }
  221. pub fn save_contig(
  222. &mut self,
  223. contig: &str,
  224. prefix: &str,
  225. breaks: Vec<u32>,
  226. ) -> anyhow::Result<CountsStats> {
  227. self.mask_low_mrd(contig, 6)?;
  228. self.mask_low_quality(contig, 0.1)?;
  229. let data: Vec<f64> = self.get(contig)?.iter().map(|v| *v as f64).collect();
  230. let n_final = data.len();
  231. let frequencies = self.frequencies(contig)?;
  232. let percentile_99 = self.percentile(contig, 99.0)?;
  233. let mut data_x = Vec::new();
  234. let mut data_y = Vec::new();
  235. frequencies.iter().for_each(|(x, y)| {
  236. if *x <= percentile_99 as f64 {
  237. data_x.push(*x);
  238. data_y.push(*y / n_final as f64);
  239. }
  240. });
  241. // Distribution plot
  242. let distribution_path = format!("{prefix}_{contig}_distrib.svg");
  243. info!("Saving graph: {distribution_path}");
  244. let mut plot = Plot::new();
  245. let colors: Vec<Rgb> = data_x
  246. .iter()
  247. .map(|&x| match x {
  248. x if x < 2.0 => Rgb::new(193, 18, 31),
  249. x if x < 6.0 => Rgb::new(243, 114, 44),
  250. x if x < 15.0 => Rgb::new(255, 202, 58),
  251. _ => Rgb::new(138, 201, 38),
  252. })
  253. .collect();
  254. let bars = Bar::new(data_x.clone(), data_y.clone())
  255. .show_legend(false)
  256. .marker(Marker::new().color_array(colors));
  257. plot.add_trace(bars);
  258. let sum: f64 = data.iter().sum();
  259. let mean = (&data).mean();
  260. let count = data.len() as f64;
  261. let std_dev = (&data).std_dev();
  262. // Normal
  263. let normal = statrs::distribution::Normal::new(mean, std_dev)?;
  264. let data_y: Vec<f64> = data_x.iter().map(|x| normal.pdf(*x)).collect();
  265. let trace = Scatter::new(data_x.clone(), data_y).name("Normal");
  266. plot.add_trace(trace);
  267. // // Gamma
  268. // let shape = mean * mean / variance;
  269. // let rate = mean / variance;
  270. //
  271. // let gamma = statrs::distribution::Gamma::new(shape, rate).unwrap();
  272. // let data_y: Vec<f64> = data_x.iter().map(|x| gamma.pdf(*x)).collect();
  273. // let trace = Scatter::new(data_x.clone(), data_y).name("Gamma");
  274. // plot.add_trace(trace);
  275. // Poisson
  276. let lambda = sum / count;
  277. let poisson = statrs::distribution::Poisson::new(lambda)?;
  278. let data_y = data_x.iter().map(|x| poisson.pmf(*x as u64)).collect();
  279. let trace = Scatter::new(data_x.clone(), data_y).name("Poisson");
  280. plot.add_trace(trace);
  281. plot.write_image(distribution_path, plotly::ImageFormat::SVG, 800, 600, 1.0);
  282. // Fractions
  283. let mut breaks_values = Vec::new();
  284. for (i, b) in breaks.iter().enumerate() {
  285. if i == 0 {
  286. let total: f64 = frequencies
  287. .iter()
  288. .filter(|(x, _)| *x < *b as f64)
  289. .map(|(_, y)| *y / count)
  290. .sum();
  291. breaks_values.push((format!("< {b}"), total));
  292. } else {
  293. let last = breaks[i - 1];
  294. let total: f64 = frequencies
  295. .iter()
  296. .filter(|(x, _)| *x < *b as f64 && *x >= last as f64)
  297. .map(|(_, y)| *y / count)
  298. .sum();
  299. breaks_values.push((format!("[{last} - {b}["), total));
  300. }
  301. }
  302. let last = *breaks.last().unwrap();
  303. let total: f64 = frequencies
  304. .iter()
  305. .filter(|(x, _)| *x >= last as f64)
  306. .map(|(_, y)| *y / count)
  307. .sum();
  308. breaks_values.push((format!(">= {last}"), total));
  309. // Chromosome
  310. let tol = 25;
  311. let chromosome_path = format!("{prefix}_{contig}_chromosome.svg");
  312. info!("Saving graph: {chromosome_path}");
  313. let target_annotations: HashSet<CountAnnotation> = vec![
  314. CountAnnotation::MaskedLowMRD,
  315. CountAnnotation::MaskedQuality,
  316. ]
  317. .into_iter()
  318. .collect();
  319. let d: Vec<u32> = self
  320. .data
  321. .get(contig)
  322. .unwrap()
  323. .iter()
  324. .map(|c| {
  325. if c.annotation
  326. .iter()
  327. .any(|ann| target_annotations.contains(ann))
  328. {
  329. 10_000u32
  330. } else {
  331. c.n_reads
  332. }
  333. })
  334. .collect();
  335. let hm = self.counts_annotations(contig)?;
  336. let len = d.len();
  337. let mut masked: Vec<(String, f64)> = hm
  338. .iter()
  339. .map(|(k, v)| (format!("{:?}", k), *v as f64 / len as f64))
  340. .collect();
  341. masked.push(("Un masked".to_string(), n_final as f64 / len as f64));
  342. let under_6_rects: Vec<AdditionalRect> = ranges_under(&d, 5, tol)
  343. .iter()
  344. .filter(|(s, e)| e > s)
  345. .map(|(start, end)| AdditionalRect {
  346. start: *start as u32 * 1000,
  347. end: *end as u32 * 1000,
  348. color: String::from("red"),
  349. position: RectPosition::Below(1),
  350. })
  351. .collect();
  352. let over_6_rects: Vec<AdditionalRect> = ranges_between(&d, 6, 9999, tol)
  353. .iter()
  354. .filter(|(s, e)| e > s)
  355. .map(|(start, end)| AdditionalRect {
  356. start: *start as u32 * 1000,
  357. end: *end as u32 * 1000,
  358. color: String::from("green"),
  359. position: RectPosition::Below(2),
  360. })
  361. .collect();
  362. let masked_rec: Vec<AdditionalRect> = ranges_over(&d, 10000, tol)
  363. .iter()
  364. .filter(|(s, e)| e > s)
  365. .map(|(start, end)| AdditionalRect {
  366. start: *start as u32 * 1000,
  367. end: *end as u32 * 1000,
  368. color: String::from("grey"),
  369. position: RectPosition::Below(0),
  370. })
  371. .collect();
  372. let mut all = Vec::new();
  373. all.extend(under_6_rects);
  374. all.extend(over_6_rects);
  375. // all.extend(over15);
  376. all.extend(masked_rec);
  377. svg_chromosome(
  378. contig,
  379. 1000,
  380. 50,
  381. "/data/ref/hs1/cytoBandMapped.bed",
  382. &chromosome_path,
  383. &all,
  384. &Vec::new(),
  385. )
  386. .unwrap();
  387. let stats = CountsStats {
  388. sum,
  389. mean,
  390. std_dev,
  391. breaks_values,
  392. masked,
  393. };
  394. // Save stats
  395. let json_path = format!("{prefix}_{contig}_stats.json");
  396. info!("Saving stats into: {json_path}");
  397. let json = serde_json::to_string_pretty(&stats)?;
  398. let mut file = File::create(json_path)?;
  399. file.write_all(json.as_bytes())?;
  400. Ok(stats)
  401. }
  402. pub fn save_contigs(
  403. &mut self,
  404. contigs: &Vec<String>,
  405. prefix: &str,
  406. breaks: Vec<u32>,
  407. ) -> anyhow::Result<()> {
  408. let mut stats = Vec::new();
  409. let mut proportions = HashMap::new();
  410. for contig in contigs {
  411. let stat = self.save_contig(contig, prefix, breaks.clone())?;
  412. let un_masked: Vec<&(String, f64)> = stat.masked.iter().filter(|(s, _)| s == "Un masked").collect();
  413. let un_masked = un_masked.first().unwrap().1;
  414. let masked = 1.0 - un_masked;
  415. let mut props: Vec<(String, f64)> = stat.breaks_values.iter().map(|(s, v)| (s.to_string(), *v * un_masked)).collect();
  416. props.push(("masked".to_string(), masked));
  417. props.iter().for_each(|(s, v)| {
  418. proportions.entry(s.to_string()).or_insert(vec![]).push(*v);
  419. });
  420. stats.push(stat);
  421. }
  422. let mut plot = Plot::new();
  423. let layout = Layout::new().bar_mode(BarMode::Stack);
  424. for (k, v) in proportions {
  425. plot.add_trace(Bar::new(contigs.clone(), v.to_vec()).name(k));
  426. }
  427. println!("{:?}", contigs);
  428. plot.set_layout(layout);
  429. plot.write_image(path, plotly::ImageFormat::SVG, 800, 600, 1.0);
  430. Ok(())
  431. }
  432. pub fn counts_annotations(
  433. &self,
  434. contig: &str,
  435. ) -> anyhow::Result<HashMap<CountAnnotation, u64>> {
  436. if let Some(d) = self.data.get(contig) {
  437. let mut counts = HashMap::new();
  438. for c in d {
  439. for a in &c.annotation {
  440. *counts.entry(a.clone()).or_insert(0) += 1;
  441. }
  442. }
  443. Ok(counts)
  444. } else {
  445. anyhow::bail!("No {contig} in counts.")
  446. }
  447. }
  448. pub fn get(&self, contig: &str) -> anyhow::Result<Vec<u32>> {
  449. if let Some(ccounts) = self.data.get(contig) {
  450. let target_annotations: HashSet<CountAnnotation> = vec![
  451. CountAnnotation::MaskedLowMRD,
  452. CountAnnotation::MaskedQuality,
  453. ]
  454. .into_iter()
  455. .collect();
  456. Ok(ccounts
  457. .iter()
  458. .filter(|count| {
  459. !count
  460. .annotation
  461. .iter()
  462. .any(|ann| target_annotations.contains(ann))
  463. })
  464. .map(|c| c.n_reads)
  465. .collect())
  466. } else {
  467. anyhow::bail!("No {contig} in counts.")
  468. }
  469. }
  470. pub fn mrd(&self, contig: &str) -> anyhow::Result<Vec<u32>> {
  471. if let Some(ccounts) = self.mrd.get(contig) {
  472. Ok(ccounts.iter().map(|c| c.n_reads).collect())
  473. } else {
  474. anyhow::bail!("No {contig} in counts.")
  475. }
  476. }
  477. pub fn calculate_percentiles(
  478. &self,
  479. contig: &str,
  480. percentiles: &[f64],
  481. ) -> anyhow::Result<Vec<f64>> {
  482. if let Some(ccounts) = self.data.get(contig) {
  483. let mut n_reads: Vec<u32> = ccounts.iter().map(|c| c.n_reads).collect();
  484. n_reads.sort_unstable();
  485. let cdf = ND::new(n_reads.clone());
  486. // println!("CDF at 13: {:?}", cdf.cdf(13));
  487. println!("Percentile at 99: {:?}", cdf.percentile(99.0));
  488. println!("above 15X: {:?}", cdf.proportion_above(15));
  489. // println!("above 15.1X: {:?}", cdf.fitted_proportion_above(&15.1));
  490. println!("under 6X: {:?}", cdf.proportion_under(6));
  491. Ok(percentiles
  492. .iter()
  493. .map(|&p| {
  494. let index = (p * (n_reads.len() - 1) as f64).round() as usize;
  495. n_reads[index] as f64
  496. })
  497. .collect())
  498. } else {
  499. anyhow::bail!("No {contig} in counts.")
  500. }
  501. }
  502. pub fn nd_reads(&self, contig: &str) -> anyhow::Result<ND> {
  503. if let Some(ccounts) = self.data.get(contig) {
  504. Ok(ND::new(ccounts.iter().map(|c| c.n_reads).collect()))
  505. } else {
  506. anyhow::bail!("No {contig} in counts")
  507. }
  508. }
  509. pub fn distribution(&self, contig: &str) -> anyhow::Result<ND> {
  510. Ok(ND::new(self.get(contig)?))
  511. }
  512. pub fn save_stats(&self) -> anyhow::Result<()> {
  513. Ok(())
  514. }
  515. pub fn save_global_proportions_graph(
  516. &self,
  517. path: &str,
  518. contigs: &Vec<String>,
  519. breaks: Vec<u32>,
  520. ) {
  521. let mut breaks_str = Vec::new();
  522. for (i, b) in breaks.iter().enumerate() {
  523. if i == 0 {
  524. breaks_str.push(format!("< {b}"));
  525. } else {
  526. let last = breaks[i - 1];
  527. breaks_str.push(format!("[{last} - {b}["))
  528. }
  529. }
  530. breaks_str.push(format!(">= {}", breaks.last().unwrap()));
  531. let mut proportions = Vec::new();
  532. for contig in contigs.iter() {
  533. let d = self.get(contig).unwrap();
  534. let nd = ND::new(d);
  535. proportions.push(nd.frequencies(&breaks));
  536. }
  537. let mut plot = Plot::new();
  538. let layout = Layout::new().bar_mode(BarMode::Stack);
  539. for (i, v) in transpose(proportions).iter().enumerate() {
  540. plot.add_trace(Bar::new(contigs.clone(), v.to_vec()).name(&breaks_str[i]));
  541. }
  542. println!("{:?}", contigs);
  543. plot.set_layout(layout);
  544. plot.write_image(path, plotly::ImageFormat::SVG, 800, 600, 1.0);
  545. }
  546. pub fn save_global_distribution_graph(&self, path: &str, contigs: &Vec<String>) {
  547. let d: Vec<u32> = contigs.iter().flat_map(|c| self.get(c).unwrap()).collect();
  548. let mut data_sorted = d.clone();
  549. data_sorted.sort_unstable();
  550. let nd = ND::new(d.clone());
  551. let mut plot = Plot::new();
  552. let bar_x: Vec<u32> = (1..=nd.percentile(99.0).unwrap()).collect();
  553. let colors: Vec<plotly::color::Rgb> = bar_x
  554. .iter()
  555. .map(|&x| {
  556. if x <= 2 {
  557. plotly::color::Rgb::new(193, 18, 31)
  558. } else if x >= 15 {
  559. plotly::color::Rgb::new(138, 201, 38)
  560. } else if x <= 6 {
  561. plotly::color::Rgb::new(243, 114, 44)
  562. } else {
  563. plotly::color::Rgb::new(255, 202, 58)
  564. }
  565. })
  566. .collect();
  567. let data: Vec<u32> = d.iter().filter(|x| **x >= 1).copied().collect();
  568. // frequencies
  569. let mut frequencies = HashMap::new();
  570. for &value in &data {
  571. *frequencies.entry(value).or_insert(0) += 1;
  572. }
  573. let bars = Bar::new(
  574. bar_x.clone(),
  575. bar_x
  576. .iter()
  577. .map(|x| *frequencies.get(x).unwrap_or(&0) as f64 / data.len() as f64)
  578. .collect(),
  579. )
  580. .show_legend(false)
  581. .marker(Marker::new().color_array(colors));
  582. plot.add_trace(bars);
  583. let data_x = generate_range(0.0, nd.percentile(99.0).unwrap().into(), 100);
  584. let data_y: Vec<f64> = data_x.iter().map(|x| nd.fitted_normal.pdf(x)).collect();
  585. let trace = Scatter::new(data_x.clone(), data_y).name("Gaussian");
  586. plot.add_trace(trace);
  587. // Gamma
  588. let data: Vec<f64> = d.iter().map(|x| *x as f64).collect();
  589. let sum: f64 = data.iter().sum();
  590. let mean = (&data).mean();
  591. let variance = (&data).variance();
  592. let count = d.len() as f64;
  593. let shape = mean * mean / variance;
  594. let rate = mean / variance;
  595. let gamma = statrs::distribution::Gamma::new(shape, rate).unwrap();
  596. let data_y: Vec<f64> = data_x.iter().map(|x| gamma.pdf(*x)).collect();
  597. let trace = Scatter::new(data_x.clone(), data_y).name("Gamma");
  598. plot.add_trace(trace);
  599. // Poisson
  600. let lambda = sum / count;
  601. let poisson = statrs::distribution::Poisson::new(lambda).unwrap();
  602. let data_y = data_x.iter().map(|x| poisson.pmf(*x as u64)).collect();
  603. let trace = Scatter::new(data_x.clone(), data_y).name("Poisson");
  604. plot.add_trace(trace);
  605. plot.write_image(path, plotly::ImageFormat::SVG, 800, 600, 1.0);
  606. println!("> 15x: {:?}", nd.frequencies(&vec![1, 6, 15]));
  607. }
  608. }
  609. pub struct ND {
  610. pub data: Vec<u32>,
  611. pub distribution: BTreeMap<u32, f64>,
  612. pub total_count: usize,
  613. pub frequency: HashMap<u32, usize>,
  614. pub fitted_normal: UvNormal,
  615. }
  616. use rstat::{fitting::MLE, normal::UvNormal, ContinuousDistribution};
  617. impl ND {
  618. fn new(mut data: Vec<u32>) -> Self {
  619. data.sort_unstable();
  620. let n = data.len();
  621. info!("n values {n}");
  622. let mut distribution = BTreeMap::new();
  623. let mut frequency = HashMap::new();
  624. for &value in &data {
  625. *frequency.entry(value).or_insert(0) += 1;
  626. }
  627. let mut cumulative_count = 0;
  628. for (&value, &count) in &frequency {
  629. cumulative_count += count;
  630. let cumulative_prob = cumulative_count as f64 / n as f64;
  631. distribution.insert(value, cumulative_prob);
  632. }
  633. // Fit normal distribution
  634. let fitted_normal = rstat::univariate::normal::Normal::fit_mle(
  635. &data
  636. .iter()
  637. .filter(|x| *x >= &1u32)
  638. .map(|x| *x as f64)
  639. .collect::<Vec<f64>>(),
  640. )
  641. .unwrap();
  642. Self {
  643. data,
  644. distribution,
  645. frequency,
  646. total_count: n,
  647. fitted_normal,
  648. }
  649. }
  650. pub fn frequency(&self, x: u32) -> usize {
  651. *self.frequency.get(&x).unwrap_or(&0)
  652. }
  653. pub fn frequencies(&self, breaks: &Vec<u32>) -> Vec<f64> {
  654. let mut last_prop_under = 0.0;
  655. let mut res = Vec::new();
  656. for brk in breaks {
  657. let v = self.proportion_under(*brk) - last_prop_under;
  658. res.push(v);
  659. last_prop_under += v;
  660. }
  661. let per99 = self.percentile(99.0).unwrap();
  662. res.push(self.proportion_under(per99) - last_prop_under);
  663. res
  664. }
  665. pub fn percentile(&self, percentile: f64) -> Option<u32> {
  666. if !(0.0..=100.0).contains(&percentile) {
  667. return None;
  668. }
  669. let index = (percentile / 100.0 * (self.total_count - 1) as f64).round() as usize;
  670. self.data.get(index).cloned()
  671. }
  672. pub fn proportion_under(&self, x: u32) -> f64 {
  673. let count = self
  674. .frequency
  675. .iter()
  676. .filter(|(&value, _)| value < x)
  677. .map(|(_, &count)| count)
  678. .sum::<usize>();
  679. count as f64 / self.total_count as f64
  680. }
  681. pub fn proportion_above(&self, x: u32) -> f64 {
  682. let count = self
  683. .frequency
  684. .iter()
  685. .filter(|(&value, _)| value > x)
  686. .map(|(_, &count)| count)
  687. .sum::<usize>();
  688. count as f64 / self.total_count as f64
  689. }
  690. }
  691. pub fn generate_range(start: f64, end: f64, steps: usize) -> Vec<f64> {
  692. if steps == 0 {
  693. return vec![];
  694. }
  695. if steps == 1 {
  696. return vec![start];
  697. }
  698. let step_size = (end - start) / (steps - 1) as f64;
  699. (0..steps).map(|i| start + i as f64 * step_size).collect()
  700. }
  701. use rayon::prelude::*;
  702. pub fn ranges_under(vec: &[u32], x: u32, tolerance: usize) -> Vec<(usize, usize)> {
  703. get_ranges_parallel(vec, x, tolerance, |val, threshold| val <= threshold)
  704. }
  705. pub fn ranges_over(vec: &[u32], x: u32, tolerance: usize) -> Vec<(usize, usize)> {
  706. get_ranges_parallel(vec, x, tolerance, |val, threshold| val >= threshold)
  707. }
  708. pub fn ranges_between(
  709. vec: &[u32],
  710. lower: u32,
  711. upper: u32,
  712. tolerance: usize,
  713. ) -> Vec<(usize, usize)> {
  714. get_ranges_parallel(vec, (lower, upper), tolerance, |val, (l, u)| {
  715. val >= l && val <= u
  716. })
  717. }
  718. pub fn get_ranges_parallel<T, F>(
  719. vec: &[u32],
  720. threshold: T,
  721. tolerance: usize,
  722. compare: F,
  723. ) -> Vec<(usize, usize)>
  724. where
  725. F: Fn(u32, T) -> bool + Sync,
  726. T: Copy + Sync,
  727. {
  728. if vec.is_empty() {
  729. return Vec::new();
  730. }
  731. let chunk_size = (vec.len() / rayon::current_num_threads()).max(1);
  732. vec.par_chunks(chunk_size)
  733. .enumerate()
  734. .flat_map(|(chunk_index, chunk)| {
  735. let mut local_ranges = Vec::new();
  736. let mut current_range: Option<(usize, usize)> = None;
  737. let offset = chunk_index * chunk_size;
  738. for (i, &val) in chunk.iter().enumerate() {
  739. let global_index = offset + i;
  740. if compare(val, threshold) {
  741. match current_range {
  742. Some((start, end)) if global_index <= end + tolerance + 1 => {
  743. current_range = Some((start, global_index));
  744. }
  745. Some((start, end)) => {
  746. local_ranges.push((start, end));
  747. current_range = Some((global_index, global_index));
  748. }
  749. None => {
  750. current_range = Some((global_index, global_index));
  751. }
  752. }
  753. } else if let Some((start, end)) = current_range {
  754. if global_index > end + tolerance + 1 {
  755. local_ranges.push((start, end));
  756. current_range = None;
  757. }
  758. }
  759. }
  760. if let Some(range) = current_range {
  761. local_ranges.push(range);
  762. }
  763. local_ranges
  764. })
  765. .collect()
  766. }
  767. pub fn transpose(v: Vec<Vec<f64>>) -> Vec<Vec<f64>> {
  768. assert!(!v.is_empty());
  769. let len = v[0].len();
  770. let mut result = vec![Vec::with_capacity(v.len()); len];
  771. for row in v {
  772. for (i, val) in row.into_iter().enumerate() {
  773. result[i].push(val);
  774. }
  775. }
  776. result
  777. }
  778. #[derive(Debug, Serialize)]
  779. pub struct CountsStats {
  780. pub sum: f64,
  781. pub mean: f64,
  782. pub std_dev: f64,
  783. pub breaks_values: Vec<(String, f64)>,
  784. pub masked: Vec<(String, f64)>,
  785. }
  786. // pub fn save_barplota
  787. // data: Vec<f64>,
  788. // data_x: Vec<f64>,
  789. // data_y: Vec<f64>,
  790. // path: &str,
  791. // ) -> anyhow::Result<CountsStats> {
  792. // let mut plot = Plot::new();
  793. //
  794. // let colors: Vec<plotly::color::Rgb> = data_x
  795. // .iter()
  796. // .map(|&x| {
  797. // if x <= 2.0 {
  798. // plotly::color::Rgb::new(193, 18, 31)
  799. // } else if x >= 15.0 {
  800. // plotly::color::Rgb::new(138, 201, 38)
  801. // } else if x <= 6.0 {
  802. // plotly::color::Rgb::new(243, 114, 44)
  803. // } else {
  804. // plotly::color::Rgb::new(255, 202, 58)
  805. // }
  806. // })
  807. // .collect();
  808. //
  809. // let bars = Bar::new(data_x.clone(), data_y.clone())
  810. // .show_legend(false)
  811. // .marker(Marker::new().color_array(colors));
  812. //
  813. // plot.add_trace(bars);
  814. //
  815. // let sum: f64 = data.iter().sum();
  816. // let mean = (&data).mean();
  817. // let count = data.len() as f64;
  818. // let std_dev = (&data).std_dev();
  819. // println!("mean {mean}");
  820. //
  821. // // Normal
  822. // let normal = statrs::distribution::Normal::new(mean, std_dev)?;
  823. // let data_y: Vec<f64> = data_x.iter().map(|x| normal.pdf(*x)).collect();
  824. // let trace = Scatter::new(data_x.clone(), data_y).name("Normal");
  825. // plot.add_trace(trace);
  826. //
  827. // // // Gamma
  828. // // let shape = mean * mean / variance;
  829. // // let rate = mean / variance;
  830. // //
  831. // // let gamma = statrs::distribution::Gamma::new(shape, rate).unwrap();
  832. // // let data_y: Vec<f64> = data_x.iter().map(|x| gamma.pdf(*x)).collect();
  833. // // let trace = Scatter::new(data_x.clone(), data_y).name("Gamma");
  834. // // plot.add_trace(trace);
  835. //
  836. // // Poisson
  837. // let lambda = sum / count;
  838. // let poisson = statrs::distribution::Poisson::new(lambda)?;
  839. // let data_y = data_x.iter().map(|x| poisson.pmf(*x as u64)).collect();
  840. // let trace = Scatter::new(data_x.clone(), data_y).name("Poisson");
  841. // plot.add_trace(trace);
  842. //
  843. // plot.use_local_plotly();
  844. // plot.write_image(path, plotly::ImageFormat::SVG, 800, 600, 1.0);
  845. // Ok(CountsStats { sum, mean, std_dev })
  846. // }