clementine_core/task/
manager.rs

1use super::status_monitor::{TaskStatusMonitorTask, TASK_STATUS_MONITOR_POLL_DELAY};
2use super::{IntoTask, Task, TaskExt, TaskVariant};
3use crate::rpc::clementine::StoppedTasks;
4use crate::utils::timed_request;
5use clementine_errors::BridgeError;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{oneshot, RwLock};
10use tokio::task::{AbortHandle, JoinHandle};
11#[cfg(test)]
12use tokio::time::sleep;
13
/// Lifecycle state recorded for a task in the task registry.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TaskStatus {
    /// The task is registered and has not been observed to stop.
    Running,
    /// The task has stopped; the payload is a textual reason (for example
    /// completion, failure, cancellation, or panic).
    NotRunning(String),
}
19
/// Registry of managed tasks, keyed by task variant.
///
/// Each entry holds the task's current status, the abort handle of its
/// spawned background task, and an optional one-shot sender used to deliver
/// the cancel signal (`None` once the signal has been taken and sent).
pub type TaskRegistry =
    HashMap<TaskVariant, (TaskStatus, AbortHandle, Option<oneshot::Sender<()>>)>;
22
/// Time limit handed to `timed_request` when `get_stopped_tasks` reads the
/// task registry.
const TASK_STATUS_FETCH_TIMEOUT: Duration = Duration::from_secs(60);
24
/// A background task manager that can hold and manage multiple tasks. When
/// dropped, it requests an abort of all registered tasks. In test builds, a
/// graceful shutdown can be performed with `graceful_shutdown`.
#[derive(Debug)]
pub struct BackgroundTaskManager {
    /// Registry of every task started through this manager. It is shared
    /// (behind an async read-write lock) with the spawned watcher tasks that
    /// record when a task stops, and with the task status monitor.
    task_registry: Arc<RwLock<TaskRegistry>>,
}
32
33impl Default for BackgroundTaskManager {
34    fn default() -> Self {
35        Self {
36            task_registry: Arc::new(RwLock::new(HashMap::new())),
37        }
38    }
39}
40
41impl BackgroundTaskManager {
42    /// Monitors the spawned task. If any task stops running, logs the reason
43    /// why and updates the task registry to register the task as not running.
44    fn monitor_spawned_task(
45        &self,
46        handle: JoinHandle<Result<(), BridgeError>>,
47        task_variant: TaskVariant,
48    ) {
49        let task_registry = Arc::downgrade(&self.task_registry);
50
51        tokio::spawn(async move {
52            let exit_reason = match handle.await {
53                Ok(Ok(_)) => {
54                    // Task completed successfully
55                    tracing::debug!("Task {task_variant:?} completed successfully");
56                    "Completed successfully".to_owned()
57                }
58                Ok(Err(e)) => {
59                    // Task returned an error
60                    tracing::error!("Task {task_variant:?} failed with error: {e:?}");
61                    format!("Failed due to error: {e:?}")
62                }
63                Err(e) => {
64                    if e.is_cancelled() {
65                        // Task was cancelled, which is expected during cleanup
66                        tracing::debug!("Task {task_variant:?} was cancelled");
67                        "Cancelled".to_owned()
68                    } else {
69                        // Task panicked or was aborted
70                        tracing::error!("Task {task_variant:?} panicked: {e:?}");
71                        format!("Panicked due to {e:?}")
72                    }
73                }
74            };
75
76            let Some(task_registry) = task_registry.upgrade() else {
77                tracing::debug!(
78                    "Task manager has been dropped, task {:?} no longer monitored",
79                    task_variant
80                );
81                return;
82            };
83
84            let mut task_registry = task_registry.write().await;
85
86            if !task_registry.contains_key(&task_variant) {
87                tracing::error!(
88                    "Invariant violated: Monitored task {:?} not registered in the task registry",
89                    task_variant
90                );
91                return;
92            }
93
94            task_registry
95                .entry(task_variant)
96                .and_modify(|(status, _, _)| {
97                    *status = TaskStatus::NotRunning(exit_reason);
98                });
99        });
100    }
101
102    /// Checks if a task is running by checking the task registry
103    async fn is_task_running(&self, variant: TaskVariant) -> bool {
104        self.task_registry
105            .read()
106            .await
107            .get(&variant)
108            .map(|(status, _, _)| status == &TaskStatus::Running)
109            .unwrap_or(false)
110    }
111
112    /// Gets all tasks that are not running
113    /// Returns an error if the task status fetch takes too long
114    pub async fn get_stopped_tasks(&self) -> Result<StoppedTasks, BridgeError> {
115        timed_request(TASK_STATUS_FETCH_TIMEOUT, "get_stopped_tasks", async {
116            let mut stopped_tasks = vec![];
117            let task_registry = self.task_registry.read().await;
118            for (variant, (status, _, _)) in task_registry.iter() {
119                match status {
120                    TaskStatus::Running => {}
121                    TaskStatus::NotRunning(reason) => {
122                        stopped_tasks.push(format!("{variant:?}: {reason}"));
123                    }
124                }
125            }
126            Ok(StoppedTasks { stopped_tasks })
127        })
128        .await
129    }
130
131    /// Gets the status of a single task by checking the task registry
132    #[cfg(test)]
133    pub async fn get_task_status(&self, variant: TaskVariant) -> Option<TaskStatus> {
134        self.task_registry
135            .read()
136            .await
137            .get(&variant)
138            .map(|(status, _, _)| status.clone())
139    }
140
141    /// Wraps the task in a cancelable loop and spawns it, registers it in the
142    /// task registry. If a task with the same TaskVariant is already running,
143    /// it will not be started.
144    pub async fn ensure_task_looping<S, U: IntoTask<Task = S>>(&self, task: U)
145    where
146        S: Task + Sized + std::fmt::Debug,
147        <S as Task>::Output: Into<bool>,
148    {
149        self.ensure_monitor_running().await;
150
151        let variant = S::VARIANT;
152
153        // do not start the same task if it is already running
154        if self.is_task_running(variant).await {
155            tracing::debug!("Task {:?} is already running, skipping", variant);
156            return;
157        }
158
159        let task = task.into_task();
160        let (task, cancel_tx) = task.cancelable_loop();
161
162        let join_handle = task.into_bg();
163        let abort_handle = join_handle.abort_handle();
164
165        self.task_registry.write().await.insert(
166            variant,
167            (TaskStatus::Running, abort_handle, Some(cancel_tx)),
168        );
169
170        self.monitor_spawned_task(join_handle, variant);
171    }
172
173    async fn ensure_monitor_running(&self) {
174        if !self.is_task_running(TaskVariant::TaskStatusMonitor).await {
175            let task = TaskStatusMonitorTask::new(self.task_registry.clone())
176                .with_delay(TASK_STATUS_MONITOR_POLL_DELAY);
177
178            let variant = TaskVariant::TaskStatusMonitor;
179            let (task, cancel_tx) = task.cancelable_loop();
180            let bg_task = task.into_bg();
181            let abort_handle = bg_task.abort_handle();
182
183            self.task_registry.write().await.insert(
184                variant,
185                (TaskStatus::Running, abort_handle, Some(cancel_tx)),
186            );
187
188            self.monitor_spawned_task(bg_task, variant);
189        }
190    }
191
192    /// Sends cancel signals to all tasks that have a cancel_tx
193    #[cfg(test)]
194    async fn send_cancel_signals(&self) {
195        let mut task_registry = self.task_registry.write().await;
196        for (_, (_, _, cancel_tx)) in task_registry.iter_mut() {
197            let oneshot_tx = cancel_tx.take();
198            if let Some(oneshot_tx) = oneshot_tx {
199                // send can fail, but if it fails the task is dropped.
200                let _ = oneshot_tx.send(());
201            }
202        }
203    }
204
205    /// Abort all tasks by dropping their cancellation senders
206    pub fn abort_all(&mut self) {
207        tracing::info!("Aborting all tasks");
208
209        // only one thread must have &mut self, so lock should be able to be acquired
210        if let Ok(task_registry) = self.task_registry.try_read() {
211            for (_, (_, abort_handle, _)) in task_registry.iter() {
212                abort_handle.abort();
213            }
214        }
215    }
216
217    /// Graceful shutdown of all tasks
218    ///
219    /// This function does not have any timeout, please use
220    /// `graceful_shutdown_with_timeout` instead for cases where you need a
221    /// timeout. The function polls tasks until they are finished with a 100ms
222    /// poll interval.
223    #[cfg(test)]
224    pub async fn graceful_shutdown(&mut self) {
225        tracing::info!("Gracefully shutting down all tasks");
226
227        self.send_cancel_signals().await;
228
229        loop {
230            let mut all_finished = true;
231            let task_registry = self.task_registry.read().await;
232
233            for (_, (_, abort_handle, _)) in task_registry.iter() {
234                if !abort_handle.is_finished() {
235                    all_finished = false;
236                    break;
237                }
238            }
239
240            if all_finished {
241                break;
242            }
243
244            sleep(Duration::from_millis(100)).await;
245        }
246    }
247
248    /// Graceful shutdown of all tasks with a timeout. All tasks will be aborted
249    /// if the timeout is reached.
250    ///
251    /// # Arguments
252    ///
253    /// * `timeout` - The timeout duration for the graceful shutdown. Since the
254    ///   `graceful_shutdown` function polls tasks until they are finished with a
255    ///   100ms poll interval, the timeout should be at least 100ms for the
256    ///   timeout to be effective.
257    #[cfg(test)]
258    pub async fn graceful_shutdown_with_timeout(&mut self, timeout: Duration) {
259        let timeout_handle = tokio::time::timeout(timeout, self.graceful_shutdown());
260
261        if timeout_handle.await.is_err() {
262            self.abort_all();
263        }
264    }
265}
266
impl Drop for BackgroundTaskManager {
    /// Logs that the manager is being dropped and then calls `abort_all` to
    /// request an abort of every registered task.
    fn drop(&mut self) {
        tracing::info!("Dropping BackgroundTaskManager, aborting all tasks");

        self.abort_all();
    }
}