Browse Source

task manager

Thomas 1 year ago
parent
commit
3e9f04f7e5
2 changed files with 53 additions and 54 deletions
  1. 5 9
      src/lib.rs
  2. 48 45
      src/runn.rs

+ 5 - 9
src/lib.rs

@@ -51,21 +51,18 @@ mod tests {
                 )
             .with_bam("/data/longreads_basic_pipe/CAMARA/diag/CAMARA_diag_hs1.bam").unwrap();
 
-        let mut prog = BasicProgress { length: 1, step: 0 };
-
-        let task_id = task_manager
-            .spawn_progress(|sender, mut progress: BasicProgress| async move {
+        let _ = task_manager
+            .spawn(|sender| async move {
                 cramino.run()?;
-                progress.inc(1);
                 if let Err(_) = sender.send(cramino) {
                     return Err(anyhow!("the receiver dropped"));
                 }
                 Ok(())
-            }, prog.clone())
+            })
             .await;
 
         loop {
-            let task_finished = task_manager.is_finished(task_id).await.unwrap();
+            let task_finished = task_manager.is_finished().await.unwrap();
 
             if task_finished {
                 break;
@@ -75,9 +72,8 @@ mod tests {
             tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
         }
 
-        if let Some(r) = task_manager.try_recv(task_id).await {
+        if let Some(r) = task_manager.try_recv().await {
             println!("{r:?}");
-            println!("{prog:?}");
         }
 
     }

+ 48 - 45
src/runn.rs

@@ -44,64 +44,63 @@ impl IndProgress {
 
 #[derive(Clone, Debug)]
 pub struct TaskManager<R> {
-    tasks: Arc<Mutex<HashMap<usize, task::JoinHandle<anyhow::Result<()>>>>>,
-    results_channels: Arc<Mutex<HashMap<usize, Receiver<R>>>>,
-    next_id: Arc<Mutex<usize>>,
+    task: Arc<Mutex<Option<task::JoinHandle<anyhow::Result<()>>>>>,
+    results_channel: Arc<Mutex<Option<Receiver<R>>>>,
 }
 
 impl<R> TaskManager<R> {
     pub fn new() -> Self {
         TaskManager {
-            tasks: Arc::new(Mutex::new(HashMap::new())),
-            results_channels: Arc::new(Mutex::new(HashMap::new())),
-            next_id: Arc::new(Mutex::new(0)),
+            task: Arc::new(Mutex::new(None)),
+            results_channel: Arc::new(Mutex::new(None)),
         }
     }
 
-    pub async fn spawn<F, T>(&self, f: F) -> usize
+    pub async fn spawn<F, T>(&self, f: F) 
     where
         F: FnOnce(Sender<R>) -> T + Send + 'static,
         // T: Future<Output = BoxFuture<Result<(), ()>>> + Send + 'static,
         T: Future<Output = anyhow::Result<()>> + Send + 'static,
     {
-        let mut next_id = self.next_id.lock().await;
-        let task_id = *next_id;
-        *next_id += 1;
+        // let mut next_id = self.next_id.lock().await;
+        // let task_id = *next_id;
+        // *next_id += 1;
 
         let (s, r) = tokio::sync::oneshot::channel::<R>();
-        self.results_channels.lock().await.insert(task_id, r);
+        *self.results_channel.lock().await = Some(r);
 
         let handle = task::spawn(Box::pin(f(s)));
 
-        self.tasks.lock().await.insert(task_id, handle);
+        *self.task.lock().await = Some(handle);
 
-        task_id
     }
 
 
-    pub async fn spawn_progress<F, T, I>(&self, f: F, progress: I) -> usize
-    where
-        F: FnOnce(Sender<R>, I) -> T + Send + 'static,
-        I: Inc,
-        // T: Future<Output = BoxFuture<Result<(), ()>>> + Send + 'static,
-        T: Future<Output = anyhow::Result<()>> + Send + 'static,
-    {
-        let mut next_id = self.next_id.lock().await;
-        let task_id = *next_id;
-        *next_id += 1;
-
-        let (s, r) = tokio::sync::oneshot::channel::<R>();
-        self.results_channels.lock().await.insert(task_id, r);
-
-        let handle = task::spawn(Box::pin(f(s, progress)));
-
-        self.tasks.lock().await.insert(task_id, handle);
-
-        task_id
-    }
-
-    pub async fn try_recv(&self, task_id: usize) -> Option<R> {
-        if let Some(r) = self.results_channels.lock().await.get_mut(&task_id) {
+    // pub async fn spawn_progress<F, T, I>(&self, f: F, progress: I) -> usize
+    // where
+    //     F: FnOnce(Sender<R>, I) -> T + Send + 'static,
+    //     I: Inc,
+    //     // T: Future<Output = BoxFuture<Result<(), ()>>> + Send + 'static,
+    //     T: Future<Output = anyhow::Result<()>> + Send + 'static,
+    // {
+    //     let mut next_id = self.next_id.lock().await;
+    //     let task_id = *next_id;
+    //     *next_id += 1;
+    //
+    //     let (s, r) = tokio::sync::oneshot::channel::<R>();
+    //     self.results_channels.lock().await.insert(task_id, r);
+    //
+    //     let handle = task::spawn(Box::pin(f(s, progress)));
+    //
+    //     self.tasks.lock().await.insert(task_id, handle);
+    //
+    //     task_id
+    // }
+
+    pub async fn try_recv(&self) -> Option<R> {
+        let mut chan = self.results_channel.lock().await;
+        
+        if let Some(r) = chan.as_mut() {
             match r.try_recv() {
                 Ok(res) => Some(res),
                 Err(err) => match err {
@@ -114,18 +113,22 @@ impl<R> TaskManager<R> {
         }
     }
 
-    pub async fn is_finished(&self, task_id: usize) -> Option<bool> {
-        let tasks = self.tasks.lock().await;
-        tasks.get(&task_id).map(|handle| handle.is_finished())
-    }
-
-    pub async fn get_result(&self, task_id: usize) -> Option<Result<anyhow::Result<()>, tokio::task::JoinError>> {
-        let mut tasks = self.tasks.lock().await;
-        if let Some(handle) = tasks.remove(&task_id) {
-            Some(handle.await)
+    pub async fn is_finished(&self) -> Option<bool> {
+        let task = self.task.lock().await;
+        if let Some(t) = task.as_ref() {
+            Some(t.is_finished())
         } else {
             None
         }
     }
+
+    // pub async fn get_result(&self, task_id: usize) -> Option<Result<anyhow::Result<()>, tokio::task::JoinError>> {
+    //     let mut tasks = self.tasks.lock().await;
+    //     if let Some(handle) = tasks.remove(&task_id) {
+    //         Some(handle.await)
+    //     } else {
+    //         None
+    //     }
+    // }
 }