clementine_core/task/manager.rs

1use super::status_monitor::{TaskStatusMonitorTask, TASK_STATUS_MONITOR_POLL_DELAY};
2use super::{IntoTask, Task, TaskExt, TaskVariant};
3use crate::errors::BridgeError;
4use crate::rpc::clementine::StoppedTasks;
5use crate::utils::timed_request;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{oneshot, RwLock};
10use tokio::task::{AbortHandle, JoinHandle};
11use tokio::time::sleep;
12
/// Last known state of a background task tracked in the [`TaskRegistry`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TaskStatus {
    /// The task has been spawned and has not yet been observed to exit.
    Running,
    /// The task has exited; the payload is a human-readable exit reason.
    NotRunning(String),
}
18
19pub type TaskRegistry =
20    HashMap<TaskVariant, (TaskStatus, AbortHandle, Option<oneshot::Sender<()>>)>;
21
/// Upper bound on how long collecting task statuses may take before the
/// request is reported as failed (one minute).
const TASK_STATUS_FETCH_TIMEOUT: Duration = Duration::from_millis(60_000);
23
24/// A background task manager that can hold and manage multiple tasks. When
25/// dropped, it will abort all tasks. Graceful shutdown can be performed with
26/// `graceful_shutdown`
27#[derive(Debug)]
28pub struct BackgroundTaskManager {
29    task_registry: Arc<RwLock<TaskRegistry>>,
30}
31
32impl Default for BackgroundTaskManager {
33    fn default() -> Self {
34        Self {
35            task_registry: Arc::new(RwLock::new(HashMap::new())),
36        }
37    }
38}
39
40impl BackgroundTaskManager {
41    /// Monitors the spawned task. If any task stops running, logs the reason
42    /// why and updates the task registry to register the task as not running.
43    fn monitor_spawned_task(
44        &self,
45        handle: JoinHandle<Result<(), BridgeError>>,
46        task_variant: TaskVariant,
47    ) {
48        let task_registry = Arc::downgrade(&self.task_registry);
49
50        tokio::spawn(async move {
51            let exit_reason = match handle.await {
52                Ok(Ok(_)) => {
53                    // Task completed successfully
54                    tracing::debug!("Task {task_variant:?} completed successfully");
55                    "Completed successfully".to_owned()
56                }
57                Ok(Err(e)) => {
58                    // Task returned an error
59                    tracing::error!("Task {task_variant:?} failed with error: {e:?}");
60                    format!("Failed due to error: {e:?}")
61                }
62                Err(e) => {
63                    if e.is_cancelled() {
64                        // Task was cancelled, which is expected during cleanup
65                        tracing::debug!("Task {task_variant:?} was cancelled");
66                        "Cancelled".to_owned()
67                    } else {
68                        // Task panicked or was aborted
69                        tracing::error!("Task {task_variant:?} panicked: {e:?}");
70                        format!("Panicked due to {e:?}")
71                    }
72                }
73            };
74
75            let Some(task_registry) = task_registry.upgrade() else {
76                tracing::debug!(
77                    "Task manager has been dropped, task {:?} no longer monitored",
78                    task_variant
79                );
80                return;
81            };
82
83            let mut task_registry = task_registry.write().await;
84
85            if !task_registry.contains_key(&task_variant) {
86                tracing::error!(
87                    "Invariant violated: Monitored task {:?} not registered in the task registry",
88                    task_variant
89                );
90                return;
91            }
92
93            task_registry
94                .entry(task_variant)
95                .and_modify(|(status, _, _)| {
96                    *status = TaskStatus::NotRunning(exit_reason);
97                });
98        });
99    }
100
101    /// Checks if a task is running by checking the task registry
102    async fn is_task_running(&self, variant: TaskVariant) -> bool {
103        self.task_registry
104            .read()
105            .await
106            .get(&variant)
107            .map(|(status, _, _)| status == &TaskStatus::Running)
108            .unwrap_or(false)
109    }
110
111    /// Gets all tasks that are not running
112    /// Returns an error if the task status fetch takes too long
113    pub async fn get_stopped_tasks(&self) -> Result<StoppedTasks, BridgeError> {
114        timed_request(TASK_STATUS_FETCH_TIMEOUT, "get_stopped_tasks", async {
115            let mut stopped_tasks = vec![];
116            let task_registry = self.task_registry.read().await;
117            for (variant, (status, _, _)) in task_registry.iter() {
118                match status {
119                    TaskStatus::Running => {}
120                    TaskStatus::NotRunning(reason) => {
121                        stopped_tasks.push(format!("{variant:?}: {reason}"));
122                    }
123                }
124            }
125            Ok(StoppedTasks { stopped_tasks })
126        })
127        .await
128    }
129
130    /// Gets the status of a single task by checking the task registry
131    pub async fn get_task_status(&self, variant: TaskVariant) -> Option<TaskStatus> {
132        self.task_registry
133            .read()
134            .await
135            .get(&variant)
136            .map(|(status, _, _)| status.clone())
137    }
138
139    /// Wraps the task in a cancelable loop and spawns it, registers it in the
140    /// task registry. If a task with the same TaskVariant is already running,
141    /// it will not be started.
142    pub async fn ensure_task_looping<S, U: IntoTask<Task = S>>(&self, task: U)
143    where
144        S: Task + Sized + std::fmt::Debug,
145        <S as Task>::Output: Into<bool>,
146    {
147        self.ensure_monitor_running().await;
148
149        let variant = S::VARIANT;
150
151        // do not start the same task if it is already running
152        if self.is_task_running(variant).await {
153            tracing::debug!("Task {:?} is already running, skipping", variant);
154            return;
155        }
156
157        let task = task.into_task();
158        let (task, cancel_tx) = task.cancelable_loop();
159
160        let join_handle = task.into_bg();
161        let abort_handle = join_handle.abort_handle();
162
163        self.task_registry.write().await.insert(
164            variant,
165            (TaskStatus::Running, abort_handle, Some(cancel_tx)),
166        );
167
168        self.monitor_spawned_task(join_handle, variant);
169    }
170
171    async fn ensure_monitor_running(&self) {
172        if !self.is_task_running(TaskVariant::TaskStatusMonitor).await {
173            let task = TaskStatusMonitorTask::new(self.task_registry.clone())
174                .with_delay(TASK_STATUS_MONITOR_POLL_DELAY);
175
176            let variant = TaskVariant::TaskStatusMonitor;
177            let (task, cancel_tx) = task.cancelable_loop();
178            let bg_task = task.into_bg();
179            let abort_handle = bg_task.abort_handle();
180
181            self.task_registry.write().await.insert(
182                variant,
183                (TaskStatus::Running, abort_handle, Some(cancel_tx)),
184            );
185
186            self.monitor_spawned_task(bg_task, variant);
187        }
188    }
189
190    /// Sends cancel signals to all tasks that have a cancel_tx
191    async fn send_cancel_signals(&self) {
192        let mut task_registry = self.task_registry.write().await;
193        for (_, (_, _, cancel_tx)) in task_registry.iter_mut() {
194            let oneshot_tx = cancel_tx.take();
195            if let Some(oneshot_tx) = oneshot_tx {
196                // send can fail, but if it fails the task is dropped.
197                let _ = oneshot_tx.send(());
198            }
199        }
200    }
201
202    /// Abort all tasks by dropping their cancellation senders
203    pub fn abort_all(&mut self) {
204        tracing::info!("Aborting all tasks");
205
206        // only one thread must have &mut self, so lock should be able to be acquired
207        if let Ok(task_registry) = self.task_registry.try_read() {
208            for (_, (_, abort_handle, _)) in task_registry.iter() {
209                abort_handle.abort();
210            }
211        }
212    }
213
214    /// Graceful shutdown of all tasks
215    ///
216    /// This function does not have any timeout, please use
217    /// `graceful_shutdown_with_timeout` instead for cases where you need a
218    /// timeout. The function polls tasks until they are finished with a 100ms
219    /// poll interval.
220    pub async fn graceful_shutdown(&mut self) {
221        tracing::info!("Gracefully shutting down all tasks");
222
223        self.send_cancel_signals().await;
224
225        loop {
226            let mut all_finished = true;
227            let task_registry = self.task_registry.read().await;
228
229            for (_, (_, abort_handle, _)) in task_registry.iter() {
230                if !abort_handle.is_finished() {
231                    all_finished = false;
232                    break;
233                }
234            }
235
236            if all_finished {
237                break;
238            }
239
240            sleep(Duration::from_millis(100)).await;
241        }
242    }
243
244    /// Graceful shutdown of all tasks with a timeout. All tasks will be aborted
245    /// if the timeout is reached.
246    ///
247    /// # Arguments
248    ///
249    /// * `timeout` - The timeout duration for the graceful shutdown. Since the
250    ///   `graceful_shutdown` function polls tasks until they are finished with a
251    ///   100ms poll interval, the timeout should be at least 100ms for the
252    ///   timeout to be effective.
253    pub async fn graceful_shutdown_with_timeout(&mut self, timeout: Duration) {
254        let timeout_handle = tokio::time::timeout(timeout, self.graceful_shutdown());
255
256        if timeout_handle.await.is_err() {
257            self.abort_all();
258        }
259    }
260}
261
262impl Drop for BackgroundTaskManager {
263    fn drop(&mut self) {
264        tracing::info!("Dropping BackgroundTaskManager, aborting all tasks");
265
266        self.abort_all();
267    }
268}