clementine_core/states/
task.rs

1use crate::{
2    bitcoin_syncer::{BlockHandler, FinalizedBlockFetcherTask},
3    database::{Database, DatabaseTransaction},
4    task::{BufferedErrors, IntoTask, RecoverableTask, TaskVariant, WithDelay},
5};
6use eyre::{Context as _, OptionExt};
7use pgmq::{Message, PGMQueueExt};
8use std::{sync::Arc, time::Duration};
9use tokio::sync::Mutex;
10use tonic::async_trait;
11
12use crate::{
13    config::protocol::ProtocolParamset,
14    states::SystemEvent,
15    task::{Task, TaskExt},
16};
17use clementine_errors::BridgeError;
18
19use super::{context::Owner, StateManager};
20
/// Delay passed to `with_delay` for the state-manager tasks built in this module.
///
/// Test builds use a short 250 ms delay so tests finish quickly; all other builds use 30 s.
const POLL_DELAY: Duration = match cfg!(test) {
    true => Duration::from_millis(250),
    false => Duration::from_secs(30),
};
26
27/// Block handler that sends events to a PostgreSQL message queue
28#[derive(Debug, Clone)]
29pub struct QueueBlockHandler {
30    queue: PGMQueueExt,
31    queue_name: String,
32}
33
34#[async_trait]
35impl BlockHandler for QueueBlockHandler {
36    /// Handles a new block by sending a new block event to the queue.
37    /// State manager will process the block after reading the event from the queue.
38    async fn handle_new_block(
39        &mut self,
40        dbtx: DatabaseTransaction<'_>,
41        block_id: u32,
42        block: bitcoin::Block,
43        height: u32,
44    ) -> Result<(), BridgeError> {
45        let event = SystemEvent::NewFinalizedBlock {
46            block_id,
47            block,
48            height,
49        };
50
51        self.queue
52            .send_with_cxn(&self.queue_name, &event, &mut **dbtx)
53            .await
54            .wrap_err("Error sending new block event to queue")?;
55        Ok(())
56    }
57}
58
59/// A task that fetches new finalized blocks from Bitcoin and adds them to the state management queue
60#[derive(Debug)]
61pub struct BlockFetcherTask<T: Owner + std::fmt::Debug + 'static> {
62    /// Owner type marker
63    _phantom: std::marker::PhantomData<T>,
64}
65
66impl<T: Owner + std::fmt::Debug + 'static> BlockFetcherTask<T> {
67    /// Creates a new finalized block fetcher task that sends new finalized blocks to the message queue.
68    pub async fn new_finalized_block_fetcher_task(
69        db: Database,
70        paramset: &'static ProtocolParamset,
71    ) -> Result<FinalizedBlockFetcherTask<QueueBlockHandler>, BridgeError> {
72        let queue = PGMQueueExt::new_with_pool(db.get_pool()).await;
73        let queue_name = StateManager::<T>::queue_name();
74
75        let handler = QueueBlockHandler {
76            queue,
77            queue_name: queue_name.clone(),
78        };
79
80        // get the next finalized block height to start from
81        let next_height = db
82            .get_next_finalized_block_height_for_consumer(
83                None,
84                T::FINALIZED_BLOCK_CONSUMER_ID_AUTOMATION,
85                paramset,
86            )
87            .await?;
88
89        tracing::info!(
90            "Creating block fetcher task for owner type {} starting from height {}",
91            T::ENTITY_NAME,
92            next_height
93        );
94
95        Ok(crate::bitcoin_syncer::FinalizedBlockFetcherTask::new(
96            db,
97            T::FINALIZED_BLOCK_CONSUMER_ID_AUTOMATION.to_string(),
98            paramset,
99            next_height,
100            handler,
101        ))
102    }
103}
104
105/// A task that reads new events from the message queue and processes them.
106#[derive(Debug)]
107pub struct MessageConsumerTask<T: Owner + std::fmt::Debug + 'static> {
108    db: Database,
109    inner: StateManager<T>,
110    /// Queue name for this owner type (cached)
111    queue_name: String,
112}
113
114#[async_trait]
115impl<T: Owner + std::fmt::Debug + 'static> Task for MessageConsumerTask<T> {
116    type Output = bool;
117    const VARIANT: TaskVariant = TaskVariant::StateManager;
118
119    async fn run_once(&mut self) -> Result<Self::Output, BridgeError> {
120        let new_event_received = async {
121            let mut dbtx = self.db.begin_transaction().await?;
122
123            // Poll new event
124            let Some(Message {
125                msg_id, message, ..
126            }): Option<Message<SystemEvent>> = self
127                .inner
128                .queue
129                // 2nd param of read_with_cxn is the visibility timeout, set to 0 as we only have 1 consumer of the queue, which is the state machine
130                // visibility timeout is the time after which the message is visible again to other consumers
131                .read_with_cxn(&self.queue_name, 0, &mut *dbtx)
132                .await
133                .wrap_err("Reading event from queue")?
134            else {
135                dbtx.commit().await?;
136                return Ok::<_, BridgeError>(false);
137            };
138
139            let arc_dbtx = Arc::new(Mutex::new(dbtx));
140
141            self.inner.handle_event(message, arc_dbtx.clone()).await?;
142
143            let mut dbtx = Arc::into_inner(arc_dbtx)
144                .ok_or_eyre("Expected single reference to DB tx when committing")?
145                .into_inner();
146
147            // Delete event from queue
148            self.inner
149                .queue
150                .archive_with_cxn(&self.queue_name, msg_id, &mut *dbtx)
151                .await
152                .wrap_err("Deleting event from queue")?;
153
154            dbtx.commit().await?;
155            Ok(true)
156        }
157        .await?;
158
159        Ok(new_event_received)
160    }
161}
162
163#[async_trait]
164impl<T: Owner + std::fmt::Debug + 'static> RecoverableTask for MessageConsumerTask<T> {
165    async fn recover_from_error(&mut self, _error: &BridgeError) -> Result<(), BridgeError> {
166        // in case of any error, reload the state machines from the database
167        self.inner.reload_state_manager_from_db().await
168    }
169}
170
171impl<T: Owner + std::fmt::Debug + 'static> IntoTask for StateManager<T> {
172    type Task = WithDelay<BufferedErrors<MessageConsumerTask<T>>>;
173
174    /// Converts the StateManager into the consumer task with a polling delay.
175    fn into_task(self) -> Self::Task {
176        MessageConsumerTask {
177            db: self.db.clone(),
178            inner: self,
179            queue_name: StateManager::<T>::queue_name(),
180        }
181        .into_buffered_errors(10, 3, Duration::from_secs(10))
182        .with_delay(POLL_DELAY)
183    }
184}
185
186impl<T: Owner + std::fmt::Debug + 'static> StateManager<T> {
187    pub async fn block_fetcher_task(
188        &self,
189    ) -> Result<WithDelay<impl Task<Output = bool> + std::fmt::Debug>, BridgeError> {
190        Ok(BlockFetcherTask::<T>::new_finalized_block_fetcher_task(
191            self.db.clone(),
192            self.config.protocol_paramset,
193        )
194        .await?
195        .into_buffered_errors(20, 3, Duration::from_secs(10))
196        .with_delay(POLL_DELAY))
197    }
198}
199
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc};

    use tokio::{sync::oneshot, task::JoinHandle, time::timeout};
    use tonic::async_trait;

    use crate::{
        builder::transaction::{ContractContext, TxHandler},
        config::BridgeConfig,
        database::DatabaseTransaction,
        extended_bitcoin_rpc::ExtendedBitcoinRpc,
        states::{block_cache, context::DutyResult, Duty},
        test::common::{create_regtest_rpc, create_test_config_with_thread_name},
        utils::NamedEntity,
    };
    use clementine_primitives::TransactionType;

    use super::*;

    /// Owner whose hooks all succeed without doing anything; used to drive the state manager.
    #[derive(Clone, Debug)]
    struct MockHandler;

    impl NamedEntity for MockHandler {
        const ENTITY_NAME: &'static str = "MockHandler";
        const TX_SENDER_CONSUMER_ID: &'static str = "mock_tx_sender";
        const FINALIZED_BLOCK_CONSUMER_ID_NO_AUTOMATION: &'static str =
            "mock_finalized_block_no_automation";
        const FINALIZED_BLOCK_CONSUMER_ID_AUTOMATION: &'static str =
            "mock_finalized_block_automation";
    }

    #[async_trait]
    impl Owner for MockHandler {
        async fn handle_duty(
            &self,
            _dbtx: DatabaseTransaction<'_>,
            _: Duty,
        ) -> Result<DutyResult, BridgeError> {
            Ok(DutyResult::Handled)
        }

        async fn create_txhandlers(
            &self,
            _dbtx: DatabaseTransaction<'_>,
            _: TransactionType,
            _: ContractContext,
        ) -> Result<BTreeMap<TransactionType, TxHandler>, BridgeError> {
            Ok(BTreeMap::new())
        }

        async fn handle_finalized_block(
            &self,
            _dbtx: DatabaseTransaction<'_>,
            _block_id: u32,
            _block_height: u32,
            _block_cache: Arc<block_cache::BlockCache>,
            _light_client_proof_wait_interval_secs: Option<u32>,
        ) -> Result<(), BridgeError> {
            Ok(())
        }
    }

    /// Starts the state manager's consumer loop (for [`MockHandler`]) in the background.
    ///
    /// Returns the background task's join handle and the shutdown sender. Dropping the sender
    /// signals the loop to stop.
    async fn create_state_manager(
        config: &mut BridgeConfig,
    ) -> (JoinHandle<Result<(), BridgeError>>, oneshot::Sender<()>) {
        let db = Database::new(config).await.unwrap();

        let rpc = ExtendedBitcoinRpc::connect(
            config.bitcoin_rpc_url.clone(),
            config.bitcoin_rpc_user.clone(),
            config.bitcoin_rpc_password.clone(),
            None,
        )
        .await
        .expect("Failed to connect to Bitcoin RPC");

        let state_manager = StateManager::new(db, MockHandler, rpc, config.clone())
            .await
            .unwrap();
        let (t, shutdown) = state_manager.into_task().cancelable_loop();
        (t.into_bg(), shutdown)
    }

    #[tokio::test]
    async fn test_run_state_manager() {
        let mut config = create_test_config_with_thread_name().await;
        let cleanup = create_regtest_rpc(&mut config).await;
        cleanup
            .rpc()
            .mine_blocks(config.protocol_paramset.start_height as u64)
            .await
            .unwrap();
        let (handle, shutdown) = create_state_manager(&mut config).await;

        drop(shutdown);

        timeout(Duration::from_secs(1), handle)
            .await
            .expect("state manager should exit after shutdown signal (timed out after 1s)")
            .expect("state manager should shutdown gracefully (thread panic should not happen)")
            .expect("state manager should shutdown gracefully");
    }

    #[tokio::test]
    async fn test_state_mgr_does_not_shutdown() {
        let mut config = create_test_config_with_thread_name().await;
        let cleanup = create_regtest_rpc(&mut config).await;
        cleanup
            .rpc()
            .mine_blocks(config.protocol_paramset.start_height as u64)
            .await
            .unwrap();
        let (mut handle, shutdown) = create_state_manager(&mut config).await;

        // Poll through `&mut handle` so the join handle is not consumed (and the task not
        // detached) when the timeout elapses; it is awaited again below.
        timeout(Duration::from_secs(1), &mut handle).await.expect_err(
            "state manager should not shutdown while shutdown handle is alive (exited within 1s)",
        );

        drop(shutdown);

        // Once the shutdown handle is gone, the loop must actually stop.
        timeout(Duration::from_secs(5), handle)
            .await
            .expect("state manager should exit after shutdown signal (timed out after 5s)")
            .expect("state manager should shutdown gracefully (thread panic should not happen)")
            .expect("state manager should shutdown gracefully");
    }
}