Browse Source

task manager

Thomas 1 year ago
parent
commit
3fcc33b98e
5 changed files with 547 additions and 3 deletions
  1. 0 1
      Cargo.toml
  2. 53 2
      src/lib.rs
  3. 158 0
      src/rr.rs
  4. 70 0
      src/runn.rs
  5. 266 0
      src/runners.rs

+ 0 - 1
Cargo.toml

@@ -12,4 +12,3 @@ indicatif = "0.17.8"
 tokio = { version = "1", features = ["full"] }
 futures = "0.3.30"
 regex = "1.10.4"
-

+ 53 - 2
src/lib.rs

@@ -1,12 +1,19 @@
 pub mod progs;
+// pub mod runners;
+pub mod runn;
 pub mod utils;
 
 #[cfg(test)]
 mod tests {
-    use env_logger::Env;
+    use std::time;
+
     use crate::utils::Run;
+    use anyhow::anyhow;
+    use env_logger::Env;
+    use log::info;
+    use tokio::sync::mpsc::Sender;
 
-    use self::progs::cramino::Cramino;
+    use self::{progs::cramino::Cramino, runn::TaskManager};
 
     use super::*;
     fn init() {
@@ -27,4 +34,48 @@ mod tests {
         assert!(cramino.results.unwrap().is_woman()?);
         Ok(())
     }
+
+    #[tokio::test]
+    async fn run_detached() {
+        init();
+        struct Cr {
+            value: i32,
+        }
+
+        let task_manager: TaskManager<Cramino> = TaskManager::new();
+
+        let task_id = task_manager
+            .spawn(|sender| async {
+                let mut cramino = Cramino::default()
+                    .with_threads(150)
+                    .with_result_path(
+                        "/data/longreads_basic_pipe/CAMARA/diag/CAMARA_diag_hs1_cramino.txt",
+                    )
+                    .with_bam("/data/longreads_basic_pipe/CAMARA/diag/CAMARA_diag_hs1.bam")?;
+                cramino.run()?;
+
+                if let Err(_) = sender.send(cramino) {
+                    return Err(anyhow!("the receiver dropped"));
+                }
+
+                Ok(())
+            })
+            .await;
+
+        loop {
+            let task_finished = task_manager.is_finished(task_id).await.unwrap();
+
+            if task_finished {
+                break;
+            }
+
+            info!("Waiting task to finish.");
+            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+        }
+
+        if let Some(r) = task_manager.try_recv(task_id).await {
+            println!("{r:?}");
+        }
+
+    }
 }

+ 158 - 0
src/rr.rs

@@ -0,0 +1,158 @@
+use crate::utils::Run;
+use futures::stream::{AbortHandle, Abortable};
+use tokio::sync::mpsc::Sender;
+use std::collections::HashMap;
+use std::future::Future;
+use std::sync::Arc;
+use tokio::sync::{mpsc, Mutex};
+use tokio::task::JoinHandle;
+pub struct BasicProgress {
+    pub length: usize,
+    pub step: usize,
+}
+
+pub trait Inc {
+    fn inc(&mut self, delta: usize);
+}
+
+impl Inc for BasicProgress {
+    fn inc(&mut self, delta: usize) {
+        if self.step < self.length {
+            self.step += delta;
+        }
+    }
+}
+
+pub struct IndProgress {
+    inner: indicatif::ProgressBar,
+}
+
+impl Inc for IndProgress {
+    fn inc(&mut self, delta: usize) {
+        self.inner.inc(delta as u64)
+    }
+}
+
+impl IndProgress {
+    pub fn new(len: u64) -> Self {
+        Self {
+            inner: indicatif::ProgressBar::new(len),
+        }
+    }
+}
+
+struct Actor<R: Run> {
+    task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+    abort_handle: Option<AbortHandle>,
+    result_sender: Sender<anyhow::Result<R>>,
+    progress: Arc<Mutex<Box<dyn Inc + Send>>>,
+}
+
+impl<R: 'static + Run> Actor<R> {
+    pub fn new<P: 'static + Inc + Send>(progress: P) -> Self {
+        let (result_sender, _) = mpsc::channel(1); // Only one message can be in flight
+        Actor {
+            task_handle: Arc::new(Mutex::new(None)),
+            abort_handle: None,
+            result_sender,
+            progress: Arc::new(Mutex::new(Box::new(progress))),
+        }
+    }
+
+    async fn spawn<F, Fut>(&mut self, task: F)
+    where
+        F: FnOnce(Arc<Mutex<Box<dyn Inc + Send>>>) -> (),
+        Fut: Future<Output = anyhow::Result<R>>,
+    {
+        let (abort_handle, abort_registration) = AbortHandle::new_pair();
+        let result_sender = self.result_sender.clone();
+        let progress = self.progress.clone();
+
+        let task = Abortable::new(
+            tokio::spawn(async move {
+                task(progress);
+            }),
+            abort_registration,
+        );
+
+        let task_handle = self.task_handle.clone();
+        let handle = tokio::spawn(async move {
+            match task.await {
+                Ok(_) => println!("Task completed successfully"),
+                Err(_) => println!("Task aborted"),
+            }
+            *task_handle.lock().await = None;
+        });
+
+        *self.task_handle.lock().await = Some(handle);
+        self.abort_handle = Some(abort_handle);
+    }
+
+    async fn has_finished(&self) -> bool {
+        self.task_handle.lock().await.is_none()
+    }
+
+    fn abort(&mut self) {
+        if let Some(abort_handle) = self.abort_handle.take() {
+            abort_handle.abort();
+        }
+        *self.task_handle.lock().unwrap() = None;
+    }
+
+    fn get_result_receiver(&self) -> mpsc::Receiver<Result<R, String>> {
+        self.result_sender.clone().into_receiver()
+    }
+}
+
+// Define the Runners struct to manage multiple actors
+struct Runners<R: Run> {
+    actors: HashMap<String, Actor<R>>,
+}
+
+impl<R: 'static + Run> Runners<R> {
+    fn new() -> Self {
+        Runners {
+            actors: HashMap::new(),
+        }
+    }
+
+    fn add_actor(&mut self, id: String, actor: Actor<R>) {
+        self.actors.insert(id, actor);
+    }
+
+    fn spawn_task<F, Fut>(&mut self, id: &str, task: F)
+    where
+        F: FnOnce() -> Fut + Send + 'static,
+        Fut: Future<Output = Result<R, String>> + 'static,
+    {
+        if let Some(actor) = self.actors.get_mut(id) {
+            actor.spawn(task);
+        }
+    }
+
+    async fn has_finished(&self, id: &str) -> bool {
+        if let Some(actor) = self.actors.get(id) {
+            return actor.has_finished().await;
+        }
+        false
+    }
+
+    fn abort_task(&mut self, id: &str) {
+        if let Some(actor) = self.actors.get_mut(id) {
+            actor.abort();
+        }
+    }
+
+    fn get_result_receiver(&self, id: &str) -> Option<mpsc::Receiver<Result<R, String>>> {
+        self.actors.get(id).map(|actor| actor.get_result_receiver())
+    }
+
+    async fn receive_result(&mut self, id: &str) -> Option<Result<R, String>> {
+        if let Some(actor) = self.actors.get_mut(id) {
+            let receiver = actor.get_result_receiver();
+            receiver.recv().await
+        } else {
+            None
+        }
+    }
+}

+ 70 - 0
src/runn.rs

@@ -0,0 +1,70 @@
+use std::{collections::HashMap, sync::Arc};
+use tokio::{sync::{oneshot::{Receiver, Sender}, Mutex}, task};
+use std::future::Future;
+
+// type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
+
+pub struct TaskManager<R> {
+    tasks: Arc<Mutex<HashMap<usize, task::JoinHandle<anyhow::Result<()>>>>>,
+    results_channels: Arc<Mutex<HashMap<usize, Receiver<R>>>>,
+    next_id: Arc<Mutex<usize>>,
+}
+
+impl<R> TaskManager<R> {
+    pub fn new() -> Self {
+        TaskManager {
+            tasks: Arc::new(Mutex::new(HashMap::new())),
+            results_channels: Arc::new(Mutex::new(HashMap::new())),
+            next_id: Arc::new(Mutex::new(0)),
+        }
+    }
+
+    pub async fn spawn<F, T>(&self, f: F) -> usize
+    where
+        F: FnOnce(Sender<R>) -> T + Send + 'static,
+        // T: Future<Output = BoxFuture<Result<(), ()>>> + Send + 'static,
+        T: Future<Output = anyhow::Result<()>> + Send + 'static,
+    {
+        let mut next_id = self.next_id.lock().await;
+        let task_id = *next_id;
+        *next_id += 1;
+
+        let (s, r) = tokio::sync::oneshot::channel::<R>();
+        self.results_channels.lock().await.insert(task_id, r);
+
+        let handle = task::spawn(Box::pin(f(s)));
+
+        self.tasks.lock().await.insert(task_id, handle);
+
+        task_id
+    }
+
+    pub async fn try_recv(&self, task_id: usize) -> Option<R> {
+        if let Some(r) = self.results_channels.lock().await.get_mut(&task_id) {
+            match r.try_recv() {
+                Ok(res) => Some(res),
+                Err(err) => match err {
+                    tokio::sync::oneshot::error::TryRecvError::Empty => None,
+                    tokio::sync::oneshot::error::TryRecvError::Closed => None,
+                },
+            }
+        } else {
+            None
+        }
+    }
+
+    pub async fn is_finished(&self, task_id: usize) -> Option<bool> {
+        let tasks = self.tasks.lock().await;
+        tasks.get(&task_id).map(|handle| handle.is_finished())
+    }
+
+    pub async fn get_result(&self, task_id: usize) -> Option<Result<anyhow::Result<()>, tokio::task::JoinError>> {
+        let mut tasks = self.tasks.lock().await;
+        if let Some(handle) = tasks.remove(&task_id) {
+            Some(handle.await)
+        } else {
+            None
+        }
+    }
+}
+

+ 266 - 0
src/runners.rs

@@ -0,0 +1,266 @@
+use anyhow::{Result, anyhow };
+use futures::{
+    stream::{AbortHandle, Abortable},
+    Future,
+};
+use std::{
+    collections::HashMap,
+    sync::{Arc, Mutex},
+};
+use tokio::{sync::mpsc::{self, Sender}, task::JoinHandle};
+
+use crate::utils::Run;
+
+pub struct Runner<R: Run> {
+    task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
+    abort_handle: Option<AbortHandle>,
+    progress: Arc<Mutex<Box<dyn Inc + Send>>>,
+    result_sender: mpsc::Sender<anyhow::Result<R>>,
+    // result: Arc<Mutex<Box<Option<anyhow::Result<R>>>>>,
+}
+
+pub struct BasicProgress {
+    pub length: usize,
+    pub step: usize,
+}
+
+pub trait Inc {
+    fn inc(&mut self, delta: usize);
+}
+
+impl Inc for BasicProgress {
+    fn inc(&mut self, delta: usize) {
+        if self.step < self.length {
+            self.step += delta;
+        }
+    }
+}
+
+pub struct IndProgress {
+    inner: indicatif::ProgressBar,
+}
+
+impl Inc for IndProgress {
+    fn inc(&mut self, delta: usize) {
+        self.inner.inc(delta as u64)
+    }
+}
+
+impl IndProgress {
+    pub fn new(len: u64) -> Self {
+        Self {
+            inner: indicatif::ProgressBar::new(len),
+        }
+    }
+}
+
+impl<R: 'static + Run> Runner<R> {
+    pub fn new<P: 'static + Inc + Send>(progress: P) -> Self {
+        let (result_sender, _) = mpsc::channel(1);
+        Runner {
+            task_handle: Arc::new(Mutex::new(None)),
+            abort_handle: None,
+            progress: Arc::new(Mutex::new(Box::new(progress))),
+            result_sender,
+            // result: Arc::new(Mutex::new(Box::new(None))),
+        }
+    }
+
+    pub fn spawn<F, Fut>(&mut self, task: F)
+    where
+        F: FnOnce(Arc<Mutex<Box<dyn Inc + Send>>>, Sender<Result<R>>) ,
+        Fut: Future<Output = ()> + Send + 'static,
+    {
+        let (abort_handle, abort_registration) = AbortHandle::new_pair();
+        let progress = self.progress.clone();
+        let result_sender = self.result_sender.clone();
+
+        let task = Abortable::new(
+            tokio::task::spawn_blocking( move || {
+                let result_sender = self.result_sender.clone();
+                task(progress, result_sender);
+            }),
+            abort_registration,
+        );
+
+        let task_handle = self.task_handle.clone();
+        let handle = tokio::spawn(async move {
+            match task.await {
+                Ok(_) => println!("Task completed successfully"),
+                Err(_) => println!("Task aborted"),
+            }
+            *task_handle.lock().unwrap() = None;
+        });
+
+        *self.task_handle.lock().unwrap() = Some(handle);
+        self.abort_handle = Some(abort_handle);
+    }
+
+    pub async fn has_finished(&self) -> bool {
+        let mut guard = self.task_handle.lock().unwrap();
+        if let Some(handle) = guard.take() {
+            if handle.await.is_ok() {
+                return true;
+            }
+        }
+        false
+    }
+
+    pub fn abort(&mut self) {
+        if let Some(abort_handle) = self.abort_handle.take() {
+            abort_handle.abort();
+        }
+        *self.task_handle.lock().unwrap() = None;
+    }
+
+    pub fn get_progress(&self) -> Arc<Mutex<Box<dyn Inc + Send>>> {
+        self.progress.clone()
+    }
+
+    fn get_result(&self) -> anyhow::Result<R> {
+        if let Some(r) = Arc::into_inner(self.result.clone()) {
+            if let Ok(r) = r.into_inner() {
+                if let Some(r) = r {
+                    return r;
+                }
+            }
+        }
+        Err(anyhow!("Error while getting the results."))
+    }
+}
+
+pub struct Runners<R: Run> {
+    runners: HashMap<String, Runner<R>>,
+}
+
+impl<R: 'static + Run + Send> Runners<R> {
+    pub fn new() -> Self {
+        Runners {
+            runners: HashMap::new(),
+        }
+    }
+
+    pub fn add_runner(&mut self, id: &str, runner: Runner<R>) {
+        self.runners.insert(id.to_string(), runner);
+    }
+
+    pub fn spawn_task<F, Fut>(&mut self, id: &str, task: F)
+    where
+        F: FnOnce(Arc<Mutex<Box<dyn Inc + Send>>>) -> Fut + Send + 'static,
+        Fut: Future<Output = anyhow::Result<R>> + Send + 'static,
+    {
+        if let Some(actor) = self.runners.get_mut(id) {
+            actor.spawn(task);
+        }
+    }
+
+    pub async fn has_finished(&self, id: &str) -> bool {
+        if let Some(runner) = self.runners.get(id) {
+            return runner.has_finished().await;
+        }
+        false
+    }
+
+    pub fn abort_task(&mut self, id: &str) {
+        if let Some(runner) = self.runners.get_mut(id) {
+            runner.abort();
+        }
+    }
+
+    pub fn get_progress(&self, id: &str) -> Option<Arc<Mutex<Box<dyn Inc + Send>>>> {
+        self.runners.get(id).map(|runner| runner.get_progress())
+    }
+
+    pub async fn remove_finished_tasks(&mut self) {
+        let mut finished_ids = Vec::new();
+
+        for (id, runner) in &self.runners {
+            if runner.has_finished().await {
+                finished_ids.push(id.clone());
+            }
+        }
+
+        for id in finished_ids {
+            self.runners.remove(&id);
+        }
+    }
+
+    pub fn get_result(&self, id: &str) -> anyhow::Result<R> {
+        if let Some(r) = self.runners.get(id).map(|actor| actor.get_result()) {
+            return r;
+        } else {
+            return Err(anyhow!("Error while getting results."));
+        }
+    }
+}
+
+// #[tokio::main]
+// async fn main() {
+//     let mut runners = Runners::new();
+//
+//     let progress1 = Box::new(Progress { length: 10, step: 0 });
+//     let actor1 = Runner::new(progress1);
+//     runners.add_actor("actor1".to_string(), actor1);
+//
+//     let progress2 = Box::new(Counter { count: 0, max: 10 });
+//     let actor2 = Runner::new(progress2);
+//     runners.add_actor("actor2".to_string(), actor2);
+//
+//     runners.spawn_task("actor1", |progress| {
+//         println!("Task for actor1 started");
+//         for _ in 0..10 {
+//             thread::sleep(Duration::from_secs(1));
+//             let mut progress = progress.lock().unwrap();
+//             progress.inc();
+//             if let Some(p) = progress.as_any().downcast_ref::<Progress>() {
+//                 println!("actor1 Progress: {}/{}", p.step, p.length);
+//             }
+//         }
+//         println!("Task for actor1 finished");
+//     });
+//
+//     runners.spawn_task("actor2", |progress| {
+//         println!("Task for actor2 started");
+//         for _ in 0..10 {
+//             thread::sleep(Duration::from_secs(1));
+//             let mut progress = progress.lock().unwrap();
+//             progress.inc();
+//             if let Some(p) = progress.as_any().downcast_ref::<Counter>() {
+//                 println!("actor2 Counter: {}/{}", p.count, p.max);
+//             }
+//         }
+//         println!("Task for actor2 finished");
+//     });
+//
+//     tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
+//
+//     if let Some(progress) = runners.get_progress("actor1") {
+//         let progress = progress.lock().unwrap();
+//         if let Some(p) = progress.as_any().downcast_ref::<Progress>() {
+//             println!("Current progress for actor1: {}/{}", p.step, p.length);
+//         }
+//     }
+//
+//     if let Some(progress) = runners.get_progress("actor2") {
+//         let progress = progress.lock().unwrap();
+//         if let Some(p) = progress.as_any().downcast_ref::<Counter>()) {
+//             println!("Current progress for actor2: {}/{}", p.count, p.max);
+//         }
+//     }
+//
+//     if runners.has_finished("actor1").await {
+//         println!("Task for actor1 has already finished");
+//     } else {
+//         println!("Task for actor1 is still running");
+//         runners.abort_task("actor1");
+//         println!("Task for actor1 aborted");
+//     }
+//
+//     if runners.has_finished("actor2").await {
+//         println!("Task for actor2 has already finished");
+//     } else {
+//         println!("Task for actor2 is still running");
+//         runners.abort_task("actor2");
+//         println!("Task for actor2 aborted");
+//     }
+// }