diff --git a/crates/mad/src/lib.rs b/crates/mad/src/lib.rs index 76f9002..935e7df 100644 --- a/crates/mad/src/lib.rs +++ b/crates/mad/src/lib.rs @@ -77,8 +77,11 @@ use futures::stream::FuturesUnordered; use futures_util::StreamExt; -use std::{fmt::Display, sync::Arc, error::Error}; -use tokio::signal::unix::{SignalKind, signal}; +use std::{error::Error, fmt::Display, sync::Arc}; +use tokio::{ + signal::unix::{SignalKind, signal}, + task::JoinError, +}; use tokio_util::sync::CancellationToken; @@ -101,15 +104,13 @@ pub enum MadError { /// Error that occurred during the run phase of a component. #[error(transparent)] - RunError { - run: anyhow::Error - }, + RunError { run: anyhow::Error }, /// Error that occurred during the close phase of a component. #[error("component(s) failed during close")] - CloseError { + CloseError { #[source] - close: anyhow::Error + close: anyhow::Error, }, /// Multiple errors from different components. @@ -180,7 +181,7 @@ impl Display for AggregateError { writeln!(f, "{} component errors occurred:", self.errors.len())?; for (i, error) in self.errors.iter().enumerate() { write!(f, "\n[Component {}] {}", i + 1, error)?; - + // Print the error chain for each component error let mut source = error.source(); let mut level = 1; @@ -501,11 +502,32 @@ impl Mad { tracing::debug!(component = name, "mad running"); + let handle = tokio::spawn(async move { comp.run(job_cancellation).await }); + tokio::select! { _ = cancellation_token.cancelled() => { error_tx.send(CompletionResult { res: Ok(()) , name }).await } - res = comp.run(job_cancellation) => { + res = handle => { + let res = match res { + Ok(res) => res, + Err(join) => { + match join.source() { + Some(error) => { + Err(MadError::RunError{run: anyhow::anyhow!("component aborted: {:?}", error)}) + }, + None => { + if join.is_panic(){ + Err(MadError::RunError { run: anyhow::anyhow!("component panicked: {}", join) }) + } else { + Err(MadError::RunError { run: anyhow::anyhow!("component faced unknown error: {}", join) }) + } + }, + } + }, + }; + + error_tx.send(CompletionResult { res , name }).await } } @@ -796,13 +818,13 @@ mod tests { .context("failed to read configuration") .context("unable to initialize database") .context("service startup failed"); - + let mad_error = MadError::Inner(error); let display = format!("{}", mad_error); - + // Should display the top-level error message assert!(display.contains("service startup failed")); - + // Test error chain iteration if let MadError::Inner(ref e) = mad_error { let chain: Vec = e.chain().map(|c| c.to_string()).collect(); @@ -818,26 +840,26 @@ mod tests { fn test_aggregate_error_display() { let error1 = MadError::Inner( anyhow::anyhow!("database connection failed") - .context("failed to connect to PostgreSQL") + .context("failed to connect to PostgreSQL"), ); - + let error2 = MadError::Inner( anyhow::anyhow!("port already in use") .context("failed to bind to port 8080") - .context("web server initialization failed") + .context("web server initialization failed"), ); - + let aggregate = MadError::AggregateError(AggregateError { errors: vec![error1, error2], }); - + let display = format!("{}", aggregate); - + // Check that it shows multiple errors assert!(display.contains("2 component errors occurred")); assert!(display.contains("[Component 1]")); assert!(display.contains("[Component 2]")); - + // Check that context chains are displayed assert!(display.contains("failed to connect to PostgreSQL")); assert!(display.contains("database connection failed")); @@ -852,7 +874,7 @@ mod tests { let aggregate = AggregateError { errors: vec![error], }; - + let display = format!("{}", aggregate); // Single error should be displayed directly assert!(display.contains("single error")); @@ -864,9 +886,9 @@ mod tests { let error = MadError::Inner( anyhow::anyhow!("root cause") .context("middle layer") - .context("top layer") + .context("top layer"), ); - + // Test that we can access the error chain if let MadError::Inner(ref e) = error { let chain: Vec = e.chain().map(|c| c.to_string()).collect(); @@ -882,13 +904,13 @@ mod tests { #[tokio::test] async fn test_component_error_propagation() { struct FailingComponent; - + #[async_trait::async_trait] impl Component for FailingComponent { fn name(&self) -> Option { Some("test-component".to_string()) } - + async fn run(&self, _cancel: CancellationToken) -> Result<(), MadError> { Err(anyhow::anyhow!("IO error") .context("failed to open file") @@ -896,16 +918,16 @@ mod tests { .into()) } } - + let result = Mad::builder() .add(FailingComponent) .cancellation(Some(std::time::Duration::from_millis(100))) .run() .await; - + assert!(result.is_err()); let error = result.unwrap_err(); - + // Check error display let display = format!("{}", error); assert!(display.contains("component initialization failed")); diff --git a/crates/mad/tests/mod.rs b/crates/mad/tests/mod.rs index f88e1bd..602deea 100644 --- a/crates/mad/tests/mod.rs +++ b/crates/mad/tests/mod.rs @@ -138,6 +138,30 @@ async fn test_can_shutdown_gracefully() -> anyhow::Result<()> { Ok(()) } +#[tokio::test] +#[traced_test] +async fn test_component_panics_shutdowns_cleanly() -> anyhow::Result<()> { + let res = Mad::builder() + .add_fn({ + move |_cancel| async move { + panic!("my inner panic"); + } + }) + .add_fn(|cancel| async move { + cancel.cancelled().await; + + Ok(()) + }) + .run() + .await; + + let err_content = res.unwrap_err().to_string(); + assert!(err_content.contains("component panicked")); + assert!(err_content.contains("my inner panic")); + + Ok(()) +} + #[test] fn test_can_easily_transform_error() -> anyhow::Result<()> { fn fallible() -> anyhow::Result<()> {