clementine_core/task/
manager.rs1use super::status_monitor::{TaskStatusMonitorTask, TASK_STATUS_MONITOR_POLL_DELAY};
2use super::{IntoTask, Task, TaskExt, TaskVariant};
3use crate::rpc::clementine::StoppedTasks;
4use crate::utils::timed_request;
5use clementine_errors::BridgeError;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{oneshot, RwLock};
10use tokio::task::{AbortHandle, JoinHandle};
11#[cfg(test)]
12use tokio::time::sleep;
13
/// Lifecycle state of a task tracked in the background-task registry.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TaskStatus {
    /// The task has been spawned and no exit has been recorded for it yet.
    Running,
    /// The task has stopped; the payload is a human-readable exit reason.
    NotRunning(String),
}
19
20pub type TaskRegistry =
21 HashMap<TaskVariant, (TaskStatus, AbortHandle, Option<oneshot::Sender<()>>)>;
22
/// Upper bound on how long `get_stopped_tasks` may take to read the task registry.
const TASK_STATUS_FETCH_TIMEOUT: Duration = Duration::from_millis(60_000);
24
25#[derive(Debug)]
29pub struct BackgroundTaskManager {
30 task_registry: Arc<RwLock<TaskRegistry>>,
31}
32
33impl Default for BackgroundTaskManager {
34 fn default() -> Self {
35 Self {
36 task_registry: Arc::new(RwLock::new(HashMap::new())),
37 }
38 }
39}
40
41impl BackgroundTaskManager {
42 fn monitor_spawned_task(
45 &self,
46 handle: JoinHandle<Result<(), BridgeError>>,
47 task_variant: TaskVariant,
48 ) {
49 let task_registry = Arc::downgrade(&self.task_registry);
50
51 tokio::spawn(async move {
52 let exit_reason = match handle.await {
53 Ok(Ok(_)) => {
54 tracing::debug!("Task {task_variant:?} completed successfully");
56 "Completed successfully".to_owned()
57 }
58 Ok(Err(e)) => {
59 tracing::error!("Task {task_variant:?} failed with error: {e:?}");
61 format!("Failed due to error: {e:?}")
62 }
63 Err(e) => {
64 if e.is_cancelled() {
65 tracing::debug!("Task {task_variant:?} was cancelled");
67 "Cancelled".to_owned()
68 } else {
69 tracing::error!("Task {task_variant:?} panicked: {e:?}");
71 format!("Panicked due to {e:?}")
72 }
73 }
74 };
75
76 let Some(task_registry) = task_registry.upgrade() else {
77 tracing::debug!(
78 "Task manager has been dropped, task {:?} no longer monitored",
79 task_variant
80 );
81 return;
82 };
83
84 let mut task_registry = task_registry.write().await;
85
86 if !task_registry.contains_key(&task_variant) {
87 tracing::error!(
88 "Invariant violated: Monitored task {:?} not registered in the task registry",
89 task_variant
90 );
91 return;
92 }
93
94 task_registry
95 .entry(task_variant)
96 .and_modify(|(status, _, _)| {
97 *status = TaskStatus::NotRunning(exit_reason);
98 });
99 });
100 }
101
102 async fn is_task_running(&self, variant: TaskVariant) -> bool {
104 self.task_registry
105 .read()
106 .await
107 .get(&variant)
108 .map(|(status, _, _)| status == &TaskStatus::Running)
109 .unwrap_or(false)
110 }
111
112 pub async fn get_stopped_tasks(&self) -> Result<StoppedTasks, BridgeError> {
115 timed_request(TASK_STATUS_FETCH_TIMEOUT, "get_stopped_tasks", async {
116 let mut stopped_tasks = vec![];
117 let task_registry = self.task_registry.read().await;
118 for (variant, (status, _, _)) in task_registry.iter() {
119 match status {
120 TaskStatus::Running => {}
121 TaskStatus::NotRunning(reason) => {
122 stopped_tasks.push(format!("{variant:?}: {reason}"));
123 }
124 }
125 }
126 Ok(StoppedTasks { stopped_tasks })
127 })
128 .await
129 }
130
131 #[cfg(test)]
133 pub async fn get_task_status(&self, variant: TaskVariant) -> Option<TaskStatus> {
134 self.task_registry
135 .read()
136 .await
137 .get(&variant)
138 .map(|(status, _, _)| status.clone())
139 }
140
141 pub async fn ensure_task_looping<S, U: IntoTask<Task = S>>(&self, task: U)
145 where
146 S: Task + Sized + std::fmt::Debug,
147 <S as Task>::Output: Into<bool>,
148 {
149 self.ensure_monitor_running().await;
150
151 let variant = S::VARIANT;
152
153 if self.is_task_running(variant).await {
155 tracing::debug!("Task {:?} is already running, skipping", variant);
156 return;
157 }
158
159 let task = task.into_task();
160 let (task, cancel_tx) = task.cancelable_loop();
161
162 let join_handle = task.into_bg();
163 let abort_handle = join_handle.abort_handle();
164
165 self.task_registry.write().await.insert(
166 variant,
167 (TaskStatus::Running, abort_handle, Some(cancel_tx)),
168 );
169
170 self.monitor_spawned_task(join_handle, variant);
171 }
172
173 async fn ensure_monitor_running(&self) {
174 if !self.is_task_running(TaskVariant::TaskStatusMonitor).await {
175 let task = TaskStatusMonitorTask::new(self.task_registry.clone())
176 .with_delay(TASK_STATUS_MONITOR_POLL_DELAY);
177
178 let variant = TaskVariant::TaskStatusMonitor;
179 let (task, cancel_tx) = task.cancelable_loop();
180 let bg_task = task.into_bg();
181 let abort_handle = bg_task.abort_handle();
182
183 self.task_registry.write().await.insert(
184 variant,
185 (TaskStatus::Running, abort_handle, Some(cancel_tx)),
186 );
187
188 self.monitor_spawned_task(bg_task, variant);
189 }
190 }
191
192 #[cfg(test)]
194 async fn send_cancel_signals(&self) {
195 let mut task_registry = self.task_registry.write().await;
196 for (_, (_, _, cancel_tx)) in task_registry.iter_mut() {
197 let oneshot_tx = cancel_tx.take();
198 if let Some(oneshot_tx) = oneshot_tx {
199 let _ = oneshot_tx.send(());
201 }
202 }
203 }
204
205 pub fn abort_all(&mut self) {
207 tracing::info!("Aborting all tasks");
208
209 if let Ok(task_registry) = self.task_registry.try_read() {
211 for (_, (_, abort_handle, _)) in task_registry.iter() {
212 abort_handle.abort();
213 }
214 }
215 }
216
217 #[cfg(test)]
224 pub async fn graceful_shutdown(&mut self) {
225 tracing::info!("Gracefully shutting down all tasks");
226
227 self.send_cancel_signals().await;
228
229 loop {
230 let mut all_finished = true;
231 let task_registry = self.task_registry.read().await;
232
233 for (_, (_, abort_handle, _)) in task_registry.iter() {
234 if !abort_handle.is_finished() {
235 all_finished = false;
236 break;
237 }
238 }
239
240 if all_finished {
241 break;
242 }
243
244 sleep(Duration::from_millis(100)).await;
245 }
246 }
247
248 #[cfg(test)]
258 pub async fn graceful_shutdown_with_timeout(&mut self, timeout: Duration) {
259 let timeout_handle = tokio::time::timeout(timeout, self.graceful_shutdown());
260
261 if timeout_handle.await.is_err() {
262 self.abort_all();
263 }
264 }
265}
266
267impl Drop for BackgroundTaskManager {
268 fn drop(&mut self) {
269 tracing::info!("Dropping BackgroundTaskManager, aborting all tasks");
270
271 self.abort_all();
272 }
273}