clementine_core/states/
task.rs

1use crate::{
2    bitcoin_syncer::{BlockHandler, FinalizedBlockFetcherTask},
3    database::{Database, DatabaseTransaction},
4    task::{BufferedErrors, IntoTask, RecoverableTask, TaskVariant, WithDelay},
5};
6use eyre::{Context as _, OptionExt};
7use pgmq::{Message, PGMQueueExt};
8use std::{sync::Arc, time::Duration};
9use tokio::sync::Mutex;
10use tonic::async_trait;
11
12use crate::{
13    config::protocol::ProtocolParamset,
14    errors::BridgeError,
15    states::SystemEvent,
16    task::{Task, TaskExt},
17};
18
19use super::{context::Owner, StateManager};
20
/// Pause between consecutive polls of the state-manager tasks.
///
/// Test builds poll every 250 ms so tests finish quickly; other builds poll every 30 s.
const POLL_DELAY: Duration = Duration::from_millis(if cfg!(test) { 250 } else { 30_000 });
26
27/// Block handler that sends events to a PostgreSQL message queue
28#[derive(Debug, Clone)]
29pub struct QueueBlockHandler {
30    queue: PGMQueueExt,
31    queue_name: String,
32}
33
34#[async_trait]
35impl BlockHandler for QueueBlockHandler {
36    /// Handles a new block by sending a new block event to the queue.
37    /// State manager will process the block after reading the event from the queue.
38    async fn handle_new_block(
39        &mut self,
40        dbtx: DatabaseTransaction<'_, '_>,
41        block_id: u32,
42        block: bitcoin::Block,
43        height: u32,
44    ) -> Result<(), BridgeError> {
45        let event = SystemEvent::NewFinalizedBlock {
46            block_id,
47            block,
48            height,
49        };
50
51        self.queue
52            .send_with_cxn(&self.queue_name, &event, &mut **dbtx)
53            .await
54            .wrap_err("Error sending new block event to queue")?;
55        Ok(())
56    }
57}
58
59/// A task that fetches new finalized blocks from Bitcoin and adds them to the state management queue
60#[derive(Debug)]
61pub struct BlockFetcherTask<T: Owner + std::fmt::Debug + 'static> {
62    /// Owner type marker
63    _phantom: std::marker::PhantomData<T>,
64}
65
66impl<T: Owner + std::fmt::Debug + 'static> BlockFetcherTask<T> {
67    /// Creates a new finalized block fetcher task that sends new finalized blocks to the message queue.
68    pub async fn new_finalized_block_fetcher_task(
69        db: Database,
70        paramset: &'static ProtocolParamset,
71    ) -> Result<FinalizedBlockFetcherTask<QueueBlockHandler>, BridgeError> {
72        let queue = PGMQueueExt::new_with_pool(db.get_pool()).await;
73        let queue_name = StateManager::<T>::queue_name();
74
75        let handler = QueueBlockHandler {
76            queue,
77            queue_name: queue_name.clone(),
78        };
79
80        // get the next finalized block height to start from
81        let next_height = db
82            .get_next_finalized_block_height_for_consumer(
83                None,
84                T::FINALIZED_BLOCK_CONSUMER_ID_AUTOMATION,
85                paramset,
86            )
87            .await?;
88
89        tracing::info!(
90            "Creating block fetcher task for owner type {} starting from height {}",
91            T::ENTITY_NAME,
92            next_height
93        );
94
95        Ok(crate::bitcoin_syncer::FinalizedBlockFetcherTask::new(
96            db,
97            T::FINALIZED_BLOCK_CONSUMER_ID_AUTOMATION.to_string(),
98            paramset,
99            next_height,
100            handler,
101        ))
102    }
103}
104
105/// A task that reads new events from the message queue and processes them.
106#[derive(Debug)]
107pub struct MessageConsumerTask<T: Owner + std::fmt::Debug + 'static> {
108    db: Database,
109    inner: StateManager<T>,
110    /// Queue name for this owner type (cached)
111    queue_name: String,
112}
113
114#[async_trait]
115impl<T: Owner + std::fmt::Debug + 'static> Task for MessageConsumerTask<T> {
116    type Output = bool;
117    const VARIANT: TaskVariant = TaskVariant::StateManager;
118
119    async fn run_once(&mut self) -> Result<Self::Output, BridgeError> {
120        let new_event_received = async {
121            let mut dbtx = self.db.begin_transaction().await?;
122
123            // Poll new event
124            let Some(Message {
125                msg_id, message, ..
126            }): Option<Message<SystemEvent>> = self
127                .inner
128                .queue
129                // 2nd param of read_with_cxn is the visibility timeout, set to 0 as we only have 1 consumer of the queue, which is the state machine
130                // visibility timeout is the time after which the message is visible again to other consumers
131                .read_with_cxn(&self.queue_name, 0, &mut *dbtx)
132                .await
133                .wrap_err("Reading event from queue")?
134            else {
135                dbtx.commit().await?;
136                return Ok::<_, BridgeError>(false);
137            };
138
139            let arc_dbtx = Arc::new(Mutex::new(dbtx));
140
141            self.inner.handle_event(message, arc_dbtx.clone()).await?;
142
143            let mut dbtx = Arc::into_inner(arc_dbtx)
144                .ok_or_eyre("Expected single reference to DB tx when committing")?
145                .into_inner();
146
147            // Delete event from queue
148            self.inner
149                .queue
150                .archive_with_cxn(&self.queue_name, msg_id, &mut *dbtx)
151                .await
152                .wrap_err("Deleting event from queue")?;
153
154            dbtx.commit().await?;
155            Ok(true)
156        }
157        .await?;
158
159        Ok(new_event_received)
160    }
161}
162
163#[async_trait]
164impl<T: Owner + std::fmt::Debug + 'static> RecoverableTask for MessageConsumerTask<T> {
165    async fn recover_from_error(&mut self, _error: &BridgeError) -> Result<(), BridgeError> {
166        // in case of any error, reload the state machines from the database
167        self.inner.reload_state_manager_from_db().await
168    }
169}
170
171impl<T: Owner + std::fmt::Debug + 'static> IntoTask for StateManager<T> {
172    type Task = WithDelay<BufferedErrors<MessageConsumerTask<T>>>;
173
174    /// Converts the StateManager into the consumer task with a polling delay.
175    fn into_task(self) -> Self::Task {
176        MessageConsumerTask {
177            db: self.db.clone(),
178            inner: self,
179            queue_name: StateManager::<T>::queue_name(),
180        }
181        .into_buffered_errors(10, 3, Duration::from_secs(10))
182        .with_delay(POLL_DELAY)
183    }
184}
185
186impl<T: Owner + std::fmt::Debug + 'static> StateManager<T> {
187    pub async fn block_fetcher_task(
188        &self,
189    ) -> Result<WithDelay<impl Task<Output = bool> + std::fmt::Debug>, BridgeError> {
190        Ok(BlockFetcherTask::<T>::new_finalized_block_fetcher_task(
191            self.db.clone(),
192            self.config.protocol_paramset,
193        )
194        .await?
195        .into_buffered_errors(20, 3, Duration::from_secs(10))
196        .with_delay(POLL_DELAY))
197    }
198}
199
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc};

    use tokio::{sync::oneshot, task::JoinHandle, time::timeout};
    use tonic::async_trait;

    use crate::{
        builder::transaction::{ContractContext, TransactionType, TxHandler},
        config::BridgeConfig,
        database::DatabaseTransaction,
        extended_bitcoin_rpc::ExtendedBitcoinRpc,
        states::{block_cache, context::DutyResult, Duty},
        test::common::{create_regtest_rpc, create_test_config_with_thread_name},
        utils::NamedEntity,
    };

    use super::*;

    /// Minimal [`Owner`] for driving the state manager in tests.
    ///
    /// Every duty reports `DutyResult::Handled`, no transaction handlers are created, and
    /// finalized blocks are accepted without doing anything.
    #[derive(Clone, Debug)]
    struct MockHandler;

    impl NamedEntity for MockHandler {
        const ENTITY_NAME: &'static str = "MockHandler";
        const TX_SENDER_CONSUMER_ID: &'static str = "mock_tx_sender";
        const FINALIZED_BLOCK_CONSUMER_ID_NO_AUTOMATION: &'static str =
            "mock_finalized_block_no_automation";
        const FINALIZED_BLOCK_CONSUMER_ID_AUTOMATION: &'static str =
            "mock_finalized_block_automation";
    }

    #[async_trait]
    impl Owner for MockHandler {
        async fn handle_duty(
            &self,
            _dbtx: DatabaseTransaction<'_, '_>,
            _: Duty,
        ) -> Result<DutyResult, BridgeError> {
            Ok(DutyResult::Handled)
        }

        async fn create_txhandlers(
            &self,
            _dbtx: DatabaseTransaction<'_, '_>,
            _: TransactionType,
            _: ContractContext,
        ) -> Result<BTreeMap<TransactionType, TxHandler>, BridgeError> {
            Ok(BTreeMap::new())
        }

        async fn handle_finalized_block(
            &self,
            _dbtx: DatabaseTransaction<'_, '_>,
            _block_id: u32,
            _block_height: u32,
            _block_cache: Arc<block_cache::BlockCache>,
            _light_client_proof_wait_interval_secs: Option<u32>,
        ) -> Result<(), BridgeError> {
            Ok(())
        }
    }

    /// Creates a `StateManager<MockHandler>` and runs its consumer task as a cancelable
    /// background loop.
    ///
    /// Returns the background task's join handle and the shutdown sender for that loop.
    async fn create_state_manager(
        config: &mut BridgeConfig,
    ) -> (JoinHandle<Result<(), BridgeError>>, oneshot::Sender<()>) {
        let db = Database::new(config).await.unwrap();

        let rpc = ExtendedBitcoinRpc::connect(
            config.bitcoin_rpc_url.clone(),
            config.bitcoin_rpc_user.clone(),
            config.bitcoin_rpc_password.clone(),
            None,
        )
        .await
        .expect("Failed to connect to Bitcoin RPC");

        let state_manager = StateManager::new(db, MockHandler, rpc, config.clone())
            .await
            .unwrap();
        let (t, shutdown) = state_manager.into_task().cancelable_loop();
        (t.into_bg(), shutdown)
    }

    #[tokio::test]
    async fn test_run_state_manager() {
        let mut config = create_test_config_with_thread_name().await;
        let cleanup = create_regtest_rpc(&mut config).await;
        cleanup
            .rpc()
            .mine_blocks(config.protocol_paramset.start_height as u64)
            .await
            .unwrap();
        let (handle, shutdown) = create_state_manager(&mut config).await;

        drop(shutdown);

        timeout(Duration::from_secs(1), handle)
            .await
            .expect("state manager should exit after shutdown signal (timed out after 1s)")
            .expect("state manager should shutdown gracefully (thread panic should not happen)")
            .expect("state manager should shutdown gracefully");
    }

    #[tokio::test]
    async fn test_state_mgr_does_not_shutdown() {
        let mut config = create_test_config_with_thread_name().await;
        let cleanup = create_regtest_rpc(&mut config).await;
        cleanup
            .rpc()
            .mine_blocks(config.protocol_paramset.start_height as u64)
            .await
            .unwrap();
        let (mut handle, shutdown) = create_state_manager(&mut config).await;

        // Poll through `&mut handle` (JoinHandle is Unpin) so the handle is still available
        // after the timeout. Once shutdown is requested, the task can then be awaited.
        timeout(Duration::from_secs(1), &mut handle)
            .await
            .expect_err(
                "state manager should not shutdown while shutdown handle is alive (exited within 1s)",
            );

        drop(shutdown);

        timeout(Duration::from_secs(1), handle)
            .await
            .expect("state manager should exit after shutdown signal (timed out after 1s)")
            .expect("state manager should shutdown gracefully (thread panic should not happen)")
            .expect("state manager should shutdown gracefully");
    }
}