Browse Source

loader params

Thomas 1 year ago
parent
commit
f76e6d196a
1 changed files with 46 additions and 17 deletions
  1. 46 17
      src/cases.rs

+ 46 - 17
src/cases.rs

@@ -10,7 +10,9 @@ use log::{info, warn};
 use pandora_lib_pileup::get_hts_nt_pileup;
 use pandora_lib_stats::chi_square_test_for_proportions;
 use pandora_lib_variants::variants::*;
+use rand::seq::IteratorRandom;
 use rayon::prelude::*;
+use rust_htslib::bam;
 use serde::Serialize;
 use std::{
     collections::HashMap,
@@ -19,8 +21,6 @@ use std::{
     path::Path,
     sync::{Arc, Mutex},
 };
-use rand::seq::IteratorRandom;
-use rust_htslib::bam;
 
 #[derive(Debug, Clone, Serialize)]
 pub struct Cases {
@@ -29,7 +29,12 @@ pub struct Cases {
 }
 
 impl Cases {
-    pub fn load(mp: MultiProgress, skip_ids: Option<Vec<String>>) -> Result<Self> {
+    pub fn load(
+        mp: MultiProgress,
+        skip_ids: Option<Vec<String>>,
+        check_snp: bool,
+        load_variants: bool,
+    ) -> Result<Self> {
         let ids_to_skip = skip_ids.unwrap_or(vec![]);
         let config = Config::get()?;
         let mut diag_bams = HashMap::new();
@@ -69,27 +74,47 @@ impl Cases {
         }
 
         // Check SNP AF differences between diag and mrd
-        let mut cases = Vec::new();
-        let diff_snp = DiffSnp::init(&config.commun_snp)?;
+        let mut retained_cases = Vec::new();
+        let diff_snp_opt = if check_snp {
+            Some(DiffSnp::init(&config.commun_snp)?)
+        } else {
+            None
+        };
         for (id, diag_bam) in diag_bams {
             if ids_to_skip.contains(&id) {
                 continue;
             }
             if let Some(mrd_bam) = mrd_bams.get(&id) {
                 // verify if both samples match commun snps
-                let diff = diff_snp.diff_prop(diag_bam.path.to_str().unwrap(), mrd_bam.path.to_str().unwrap())?;
+                if let Some(diff_snp) = &diff_snp_opt {
+                    let diff = diff_snp.diff_prop(
+                        diag_bam.path.to_str().unwrap(),
+                        mrd_bam.path.to_str().unwrap(),
+                    )?;
 
-                if diff < config.max_snp_diff_prop {
-                    let mut case = Case::new(id.clone(), diag_bam.clone(), mrd_bam.clone())?;
-                    let dir = diag_bam.path.parent().context("")?.to_str().context("")?;
-                    case.load_variants(&format!("{dir}/{id}_variants.bytes.gz"), &mp)?;
-                    cases.push(case);
-                } else {
-                    warn!("{id} diag and mrd seems to have been sequenced from two patients.");
+                    if diff > config.max_snp_diff_prop {
+                        warn!("{id} diag and mrd seems to have been sequenced from two patients.");
+                        continue;
+                    }
                 }
+                let case = Case::new(id.clone(), diag_bam.clone(), mrd_bam.clone())?;
+                retained_cases.push(case);
             }
         }
-        Ok(Self { cases, config })
+
+        if load_variants {
+            for case in retained_cases.iter_mut() {
+                let dir = case
+                    .diag_bam
+                    .path
+                    .parent()
+                    .context("")?
+                    .to_str()
+                    .context("")?;
+                case.load_variants(&format!("{dir}/{}_variants.bytes.gz", case.id), &mp)?;
+            }
+        }
+        Ok(Self { cases: retained_cases, config })
     }
 
     pub fn stats(&self) {
@@ -179,7 +204,7 @@ impl Case {
                 } else {
                     None
                 }
-            },
+            }
             None => None,
         }
     }
@@ -188,7 +213,7 @@ impl Case {
 #[derive(Debug, Clone, Serialize)]
 pub struct CaseStats {
     id: String,
-    stats: Vec<Stat>
+    stats: Vec<Stat>,
 }
 
 pub struct DiffSnp {
@@ -210,7 +235,11 @@ impl DiffSnp {
 
     pub fn diff_prop(&self, diag_bam_path: &str, mrd_bam_path: &str) -> Result<f64> {
         let mut rng = rand::thread_rng();
-        let lines = self.lines.clone().into_iter().choose_multiple(&mut rng, 10_000);
+        let lines = self
+            .lines
+            .clone()
+            .into_iter()
+            .choose_multiple(&mut rng, 10_000);
         let max_p_val = 0.0001;
         let diff = Arc::new(Mutex::new(0u64));
         let eq = Arc::new(Mutex::new(0u64));