Browse Source

ClairS upate

Thomas 1 week ago
parent
commit
4ac495035c
4 changed files with 318 additions and 83 deletions
  1. 229 63
      src/callers/clairs.rs
  2. 2 3
      src/callers/deep_variant.rs
  3. 1 17
      src/collection/bam_stats.rs
  4. 86 0
      src/helpers.rs

+ 229 - 63
src/callers/clairs.rs

@@ -2,12 +2,13 @@ use crate::{
     annotation::{Annotation, Annotations, Caller, CallerCat, Sample},
     collection::vcf::Vcf,
     commands::{
-        bcftools::{BcftoolsConcat, BcftoolsKeepPass},
-        CapturedOutput, Command as JobCommand, Runner as LocalRunner, SbatchRunner, SlurmParams,
-        SlurmRunner,
+        CapturedOutput, Command as JobCommand, Runner as LocalRunner, SbatchRunner, SlurmParams, SlurmRunner, bcftools::{BcftoolsConcat, BcftoolsKeepPass}, run_many_sbatch
     },
     config::Config,
-    helpers::{is_file_older, remove_dir_if_exists, temp_file_path},
+    helpers::{
+        get_genome_sizes, is_file_older, remove_dir_if_exists, split_genome_into_n_regions,
+        temp_file_path,
+    },
     io::vcf::read_vcf,
     pipes::{Initialize, ShouldRun, Version},
     runners::Run,
@@ -21,8 +22,11 @@ use anyhow::Context;
 use log::{debug, info, warn};
 use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
 use regex::Regex;
+use rust_htslib::bam::{self, Read};
 use std::{
-    fmt, fs, path::Path, process::{Command as ProcessCommand, Stdio}
+    fmt, fs,
+    path::Path,
+    process::{Command as ProcessCommand, Stdio},
 };
 
 /// A pipeline runner for executing ClairS on paired tumor and normal samples.
@@ -41,6 +45,12 @@ pub struct ClairS {
     /// Optional list of regions passed as repeated `-r REGION` args.
     /// When empty, ClairS runs genome-wide.
     pub regions: Vec<String>,
+
+    /// Optional part index for chunked parallel runs (1-indexed).
+    ///
+    /// When `Some(n)`, output files go into a `part{n}` subdirectory and
+    /// PASS VCFs are per-part, later merged into the canonical VCF.
+    pub part_index: Option<usize>,
 }
 
 impl fmt::Display for ClairS {
@@ -55,6 +65,12 @@ impl fmt::Display for ClairS {
                 format!("{} regions", self.regions.len())
             }
         )?;
+        writeln!(
+            f,
+            "  Part      : {}",
+            self.part_index
+                .map_or("full".into(), |n| format!("part{n}"))
+        )?;
         writeln!(f, "  Log dir   : {}", self.log_dir)
     }
 }
@@ -71,6 +87,7 @@ impl Initialize for ClairS {
             log_dir,
             config: config.clone(),
             regions: Vec::new(),
+            part_index: None,
         };
 
         if clairs.config.clairs_force {
@@ -99,7 +116,7 @@ impl ShouldRun for ClairS {
 
 impl JobCommand for ClairS {
     fn init(&mut self) -> anyhow::Result<()> {
-        let output_dir = self.config.clairs_output_dir(&self.id);
+        let output_dir = self.part_output_dir();
 
         fs::create_dir_all(&output_dir)
             .with_context(|| format!("Failed create dir: {output_dir}"))?;
@@ -111,7 +128,7 @@ impl JobCommand for ClairS {
     }
 
     fn cmd(&self) -> String {
-        let output_dir = self.config.clairs_output_dir(&self.id);
+        let output_dir = self.part_output_dir();
 
         // Build repeated -r REGION args if any regions were set (for batched runs)
         let region_args = if self.regions.is_empty() {
@@ -190,38 +207,56 @@ impl Run for ClairS {
 
 impl ClairS {
     fn postprocess_local(&self) -> anyhow::Result<()> {
-        // Germline PASS
-        let clair3_germline_passed = self.config.clairs_germline_passed_vcf(&self.id);
-        if !Path::new(&clair3_germline_passed).exists() {
-            let clair3_germline_normal = self.config.clairs_germline_normal_vcf(&self.id);
-
-            let mut cmd = BcftoolsKeepPass::from_config(
-                &self.config,
-                clair3_germline_normal,
-                clair3_germline_passed.clone(),
-            );
-            let report = <BcftoolsKeepPass as LocalRunner>::run(&mut cmd).with_context(|| {
-                format!(
-                    "Failed to run `bcftools keep PASS` for {}.",
-                    clair3_germline_passed
-                )
-            })?;
-
-            let log_file = format!("{}/bcftools_germline_pass_", self.log_dir);
-            report
-                .save_to_file(&log_file)
-                .with_context(|| format!("Error while writing logs into {log_file}"))?;
-        } else {
-            debug!(
-                "ClairS Germline PASSED VCF already exists for {}, skipping.",
-                self.id
-            );
+        // Germline PASS only once (full run, not per-part)
+        if self.part_index.is_none() {
+            let clair3_germline_passed = self.config.clairs_germline_passed_vcf(&self.id);
+            if !Path::new(&clair3_germline_passed).exists() {
+                let clair3_germline_normal = self.config.clairs_germline_normal_vcf(&self.id);
+
+                let mut cmd = BcftoolsKeepPass::from_config(
+                    &self.config,
+                    clair3_germline_normal,
+                    clair3_germline_passed.clone(),
+                );
+                let report =
+                    <BcftoolsKeepPass as LocalRunner>::run(&mut cmd).with_context(|| {
+                        format!(
+                            "Failed to run `bcftools keep PASS` for {}.",
+                            clair3_germline_passed
+                        )
+                    })?;
+
+                let log_file = format!("{}/bcftools_germline_pass_", self.log_dir);
+                report
+                    .save_to_file(&log_file)
+                    .with_context(|| format!("Error while writing logs into {log_file}"))?;
+            } else {
+                debug!(
+                    "ClairS Germline PASSED VCF already exists for {}, skipping.",
+                    self.id
+                );
+            }
         }
 
-        // Somatic concat + PASS
-        let passed_vcf = self.config.clairs_passed_vcf(&self.id);
+        // Somatic concat + PASS (per-part or full)
+        let passed_vcf = self.somatic_passed_vcf_path();
         if !Path::new(&passed_vcf).exists() {
             let (output_vcf, output_indels_vcf) = self.config.clairs_output_vcfs(&self.id);
+            let output_dir = self.part_output_dir();
+            let output_vcf = format!(
+                "{output_dir}/{}",
+                Path::new(&output_vcf)
+                    .file_name()
+                    .unwrap()
+                    .to_string_lossy()
+            );
+            let output_indels_vcf = format!(
+                "{output_dir}/{}",
+                Path::new(&output_indels_vcf)
+                    .file_name()
+                    .unwrap()
+                    .to_string_lossy()
+            );
 
             let tmp_file = temp_file_path(".vcf.gz")?.to_str().unwrap().to_string();
 
@@ -258,8 +293,8 @@ impl ClairS {
                 .with_context(|| format!("Failed to remove temporary file {tmp_file}"))?;
         } else {
             debug!(
-                "ClairS PASSED VCF already exists for {}, skipping.",
-                self.id
+                "ClairS PASSED VCF already exists for {}, part {:?}, skipping.",
+                self.id, self.part_index
             );
         }
 
@@ -267,34 +302,51 @@ impl ClairS {
     }
 
     fn postprocess_sbatch(&self) -> anyhow::Result<()> {
-        // Germline PASS via Slurm
-        let clair3_germline_passed = self.config.clairs_germline_passed_vcf(&self.id);
-        if !Path::new(&clair3_germline_passed).exists() {
-            let clair3_germline_normal = self.config.clairs_germline_normal_vcf(&self.id);
-
-            let mut cmd = BcftoolsKeepPass::from_config(
-                &self.config,
-                clair3_germline_normal,
-                clair3_germline_passed.clone(),
-            );
-            let report = SlurmRunner::run(&mut cmd)
-                .context("Failed to run `bcftools keep PASS` on Slurm")?;
-
-            let log_file = format!("{}/bcftools_germline_pass_", self.log_dir);
-            report
-                .save_to_file(&log_file)
-                .context("Error while writing logs")?;
-        } else {
-            debug!(
-                "ClairS Germline PASSED VCF already exists for {}, skipping.",
-                self.id
-            );
+        // Germline PASS only once
+        if self.part_index.is_none() {
+            let clair3_germline_passed = self.config.clairs_germline_passed_vcf(&self.id);
+            if !Path::new(&clair3_germline_passed).exists() {
+                let clair3_germline_normal = self.config.clairs_germline_normal_vcf(&self.id);
+
+                let mut cmd = BcftoolsKeepPass::from_config(
+                    &self.config,
+                    clair3_germline_normal,
+                    clair3_germline_passed.clone(),
+                );
+                let report = SlurmRunner::run(&mut cmd)
+                    .context("Failed to run `bcftools keep PASS` on Slurm")?;
+
+                let log_file = format!("{}/bcftools_germline_pass_", self.log_dir);
+                report
+                    .save_to_file(&log_file)
+                    .context("Error while writing logs")?;
+            } else {
+                debug!(
+                    "ClairS Germline PASSED VCF already exists for {}, skipping.",
+                    self.id
+                );
+            }
         }
 
-        // Somatic concat + PASS via Slurm
-        let passed_vcf = self.config.clairs_passed_vcf(&self.id);
+        // Somatic concat + PASS (per-part or full)
+        let passed_vcf = self.somatic_passed_vcf_path();
         if !Path::new(&passed_vcf).exists() {
             let (output_vcf, output_indels_vcf) = self.config.clairs_output_vcfs(&self.id);
+            let output_dir = self.part_output_dir();
+            let output_vcf = format!(
+                "{output_dir}/{}",
+                Path::new(&output_vcf)
+                    .file_name()
+                    .unwrap()
+                    .to_string_lossy()
+            );
+            let output_indels_vcf = format!(
+                "{output_dir}/{}",
+                Path::new(&output_indels_vcf)
+                    .file_name()
+                    .unwrap()
+                    .to_string_lossy()
+            );
 
             let tmp_file = temp_file_path(".vcf.gz")?.to_str().unwrap().to_string();
 
@@ -324,8 +376,8 @@ impl ClairS {
             fs::remove_file(&tmp_file).context("Failed to remove temporary merged VCF")?;
         } else {
             debug!(
-                "ClairS PASSED VCF already exists for {}, skipping.",
-                self.id
+                "ClairS PASSED VCF already exists for {}, part {:?}, skipping.",
+                self.id, self.part_index
             );
         }
 
@@ -365,6 +417,34 @@ impl ClairS {
         self.postprocess_sbatch()?;
         Ok(out)
     }
+
+    /// Per-part output directory.
+    ///
+    /// For chunked runs, this is `{clairs_output_dir(id)}/part{idx}`.
+    /// For full-genome runs, just `clairs_output_dir(id)`.
+    fn part_output_dir(&self) -> String {
+        let base_dir = self.config.clairs_output_dir(&self.id);
+        match self.part_index {
+            Some(idx) => format!("{base_dir}/part{idx}"),
+            None => base_dir,
+        }
+    }
+
+    /// Somatic PASS VCF path for this run.
+    ///
+    /// - When `part_index.is_some()`: per-part intermediate PASS VCF
+    ///   (inside the part dir), later merged.
+    /// - When `None`: canonical final path from `Config::clairs_passed_vcf`.
+    fn somatic_passed_vcf_path(&self) -> String {
+        match self.part_index {
+            Some(idx) => {
+                // Example: {clairs_output_dir(id)}/part{idx}/clairs.part{idx}.pass.vcf.gz
+                let outdir = self.part_output_dir();
+                format!("{outdir}/clairs.part{idx}.pass.vcf.gz")
+            }
+            None => self.config.clairs_passed_vcf(&self.id),
+        }
+    }
 }
 
 /* ---------------- Variant / Label / Version impls ------------------------ */
@@ -503,6 +583,92 @@ impl Version for ClairS {
     }
 }
 
+/// Merge N chunked ClairS PASS VCFs into the final clairs_passed_vcf().
+fn merge_clairs_parts(base: &ClairS, n_parts: usize) -> anyhow::Result<()> {
+    use std::path::PathBuf;
+
+    let mut part_pass_paths: Vec<PathBuf> = Vec::with_capacity(n_parts);
+
+    for i in 1..=n_parts {
+        let mut part = base.clone();
+        part.part_index = Some(i);
+        let part_pass = part.somatic_passed_vcf_path();
+
+        anyhow::ensure!(
+            Path::new(&part_pass).exists(),
+            "Missing ClairS part {i} PASS VCF: {part_pass}"
+        );
+
+        part_pass_paths.push(PathBuf::from(part_pass));
+    }
+
+    let final_passed_vcf = base.config.clairs_passed_vcf(&base.id);
+    let final_tmp = format!("{final_passed_vcf}.tmp");
+
+    if let Some(parent) = Path::new(&final_passed_vcf).parent() {
+        fs::create_dir_all(parent)?;
+    }
+
+    info!(
+        "Concatenating {} ClairS part VCFs into {}",
+        n_parts, final_passed_vcf
+    );
+
+    let mut concat = BcftoolsConcat::from_config(&base.config, part_pass_paths, &final_tmp);
+    SlurmRunner::run(&mut concat).context("Failed to run bcftools concat for ClairS parts")?;
+
+    fs::rename(&final_tmp, &final_passed_vcf).context("Failed to rename merged ClairS PASS VCF")?;
+
+    info!(
+        "Successfully merged {} ClairS parts into {}",
+        n_parts, final_passed_vcf
+    );
+
+    Ok(())
+}
+
+pub fn run_clairs_chunked_sbatch_with_merge(
+    id: &str,
+    config: &Config,
+    n_parts: usize,
+) -> anyhow::Result<Vec<CapturedOutput>> {
+    let base = ClairS::initialize(id, config)?;
+
+    // If final VCF already up-to-date, skip (uses full run ShouldRun logic)
+    if !base.should_run() {
+        debug!("ClairS PASS VCF already up-to-date for {id}, skipping.");
+        return Ok(Vec::new());
+    }
+
+    // Genome sizes from normal BAM header
+    let normal_bam = config.normal_bam(id);
+    let reader =
+        bam::Reader::from_path(&normal_bam).with_context(|| format!("Opening BAM {normal_bam}"))?;
+    let header = bam::Header::from_template(reader.header());
+    let genome_sizes = get_genome_sizes(&header)?;
+    let region_chunks = split_genome_into_n_regions(&genome_sizes, n_parts);
+    let n_parts = region_chunks.len();
+
+    // Build jobs
+    let mut jobs = Vec::with_capacity(n_parts);
+    for (i, regions) in region_chunks.into_iter().enumerate() {
+        let mut job = base.clone();
+        job.part_index = Some(i + 1);
+        job.regions = regions;
+        job.log_dir = format!("{}/part{}", base.log_dir, i + 1);
+        info!("Planned ClairS job:\n{job}");
+        jobs.push(job);
+    }
+
+    // Run all parts via Slurm
+    let outputs = run_many_sbatch(jobs)?;
+
+    // Merge somatic PASS VCFs into final clairs_passed_vcf()
+    merge_clairs_parts(&base, n_parts)?;
+
+    Ok(outputs)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

+ 2 - 3
src/callers/deep_variant.rs

@@ -247,11 +247,10 @@ impl JobCommand for DeepVariant {
             --regions='{regions}' \
             {haploid_flag} \
             --par_regions_bed={par_bed} \
-            --haploid_contigs='chrX,chrY' \
             --output_vcf={output_vcf} \
             --num_shards={threads} \
             --vcf_stats_report=true \
-            --postprocess_cpus={postprocess_cpus},
+            --postprocess_cpus={postprocess_cpus} \
             --logging_dir=/output/{log_dir} \
             --dry_run=false \
             --sample_name={sample_name}",
@@ -303,7 +302,7 @@ impl SbatchRunner for DeepVariant {
         };
         SlurmParams {
             job_name: Some(format!("deepvariant_{}_{}", self.id, self.time_point)),
-            cpus_per_task: Some(10),
+            cpus_per_task: Some(self.config.deepvariant_threads.into()),
             mem: Some("60G".into()),
             partition: Some("gpgpuq".into()),
             gres: Some(format!("gpu:{gpu}:1")),

+ 1 - 17
src/collection/bam_stats.rs

@@ -15,7 +15,7 @@ use rust_htslib::{
 use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
 use serde::{Deserialize, Serialize};
 
-use crate::config::Config;
+use crate::{config::Config, helpers::get_genome_sizes};
 
 /// Flags to skip: unmapped, secondary, QC fail, supplementary
 const SKIP_FLAGS: u16 = (BAM_FUNMAP | BAM_FSECONDARY | BAM_FQCFAIL | BAM_FSUPPLEMENTARY) as u16;
@@ -894,22 +894,6 @@ pub fn n50_from_hist(hist: &BTreeMap<u64, u64>, mapped_yield: u64) -> u64 {
     0
 }
 
-/// Extracts genome sizes from BAM header.
-fn get_genome_sizes(header: &rust_htslib::bam::Header) -> anyhow::Result<FxHashMap<String, u64>> {
-    let mut sizes = FxHashMap::default();
-
-    for (_, records) in header.to_hashmap() {
-        for record in records {
-            if let (Some(sn), Some(ln)) = (record.get("SN"), record.get("LN")) {
-                if let Ok(len) = ln.parse::<u64>() {
-                    sizes.insert(sn.clone(), len);
-                }
-            }
-        }
-    }
-
-    Ok(sizes)
-}
 
 // =============================================================================
 // Display Implementations

+ 86 - 0
src/helpers.rs

@@ -2,6 +2,7 @@ use anyhow::Context;
 use bitcode::{Decode, Encode};
 use glob::glob;
 use log::{debug, error, warn};
+use rustc_hash::FxHashMap;
 use serde::{Deserialize, Serialize};
 use std::{
     cmp::Ordering,
@@ -773,3 +774,88 @@ impl Drop for TempDirGuard {
         }
     }
 }
+
+
+/// Extracts genome sizes from BAM header.
+pub fn get_genome_sizes(header: &rust_htslib::bam::Header) -> anyhow::Result<FxHashMap<String, u64>> {
+    let mut sizes = FxHashMap::default();
+
+    for (_, records) in header.to_hashmap() {
+        for record in records {
+            if let (Some(sn), Some(ln)) = (record.get("SN"), record.get("LN")) {
+                if let Ok(len) = ln.parse::<u64>() {
+                    sizes.insert(sn.clone(), len);
+                }
+            }
+        }
+    }
+
+    Ok(sizes)
+}
+
+/// Split genome into ~n_parts equal-sized chunks (by number of bases),
+/// returning for each chunk a list of regions in `ctg:start-end` form.
+pub fn split_genome_into_n_regions(
+    genome_sizes: &FxHashMap<String, u64>,
+    n_parts: usize,
+) -> Vec<Vec<String>> {
+    if n_parts == 0 || genome_sizes.is_empty() {
+        return Vec::new();
+    }
+
+    // Deterministic contig order
+    let mut contigs: Vec<(String, u64)> = genome_sizes
+        .iter()
+        .map(|(ctg, len)| (ctg.clone(), *len))
+        .collect();
+    contigs.sort_by(|a, b| a.0.cmp(&b.0));
+
+    let total_bases: u64 = contigs.iter().map(|(_, len)| *len).sum();
+    if total_bases == 0 {
+        return Vec::new();
+    }
+
+    let target_chunk_size: u64 = total_bases.div_ceil(n_parts as u64); // ceil
+
+    let mut chunks: Vec<Vec<String>> = vec![Vec::new(); n_parts];
+    let mut current_part = 0usize;
+    let mut remaining_in_part = target_chunk_size;
+
+    for (ctg, len) in contigs {
+        let mut remaining_ctg = len;
+        let mut start: u64 = 1;
+
+        while remaining_ctg > 0 && current_part < n_parts {
+            let take = remaining_in_part.min(remaining_ctg);
+            let end = start + take - 1;
+
+            chunks[current_part].push(format!("{ctg}:{start}-{end}"));
+
+            remaining_ctg -= take;
+            start = end + 1;
+            remaining_in_part -= take;
+
+            if remaining_in_part == 0 {
+                current_part += 1;
+                if current_part >= n_parts {
+                    break;
+                }
+                remaining_in_part = target_chunk_size;
+            }
+        }
+
+        // If we ran out of parts but contig still has bases, dump the rest into the last part
+        if remaining_ctg > 0 && current_part >= n_parts {
+            let end = start + remaining_ctg - 1;
+            chunks[n_parts - 1].push(format!("{ctg}:{start}-{end}"));
+            break;
+        }
+    }
+
+    // Remove empty chunks at the end (e.g. n_parts > total_bases)
+    while chunks.last().is_some_and(|c| c.is_empty()) {
+        chunks.pop();
+    }
+
+    chunks
+}