diff --git a/Cargo.lock b/Cargo.lock index 4947587..ddd6a46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,17 +17,6 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" -[[package]] -name = "async-trait" -version = "0.1.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -225,7 +214,6 @@ name = "notmad" version = "0.10.0" dependencies = [ "anyhow", - "async-trait", "futures", "futures-util", "rand", diff --git a/Cargo.toml b/Cargo.toml index 922d8c7..0def370 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ resolver = "2" version = "0.10.0" [workspace.dependencies] -mad = { path = "crates/mad" } anyhow = { version = "1.0.71" } tokio = { version = "1", features = ["full"] } diff --git a/crates/mad/Cargo.toml b/crates/mad/Cargo.toml index 757e9f6..ec2f9d7 100644 --- a/crates/mad/Cargo.toml +++ b/crates/mad/Cargo.toml @@ -10,7 +10,6 @@ readme = "../../README.md" [dependencies] anyhow.workspace = true -async-trait = "0.1.81" futures = "0.3.30" futures-util = "0.3.30" rand = "0.9.0" diff --git a/crates/mad/examples/basic/main.rs b/crates/mad/examples/basic/main.rs index f75e5ba..667c74b 100644 --- a/crates/mad/examples/basic/main.rs +++ b/crates/mad/examples/basic/main.rs @@ -1,16 +1,14 @@ -use async_trait::async_trait; use rand::Rng; use tokio_util::sync::CancellationToken; use tracing::Level; struct WaitServer {} -#[async_trait] impl notmad::Component for WaitServer { fn name(&self) -> Option { Some("WaitServer".into()) } - async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { + async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> { let millis_wait = rand::thread_rng().gen_range(500..3000); tracing::debug!("waiting: {}ms", millis_wait); diff --git a/crates/mad/examples/comprehensive/main.rs b/crates/mad/examples/comprehensive/main.rs index 0fa40cb..1befb5f 100644 --- a/crates/mad/examples/comprehensive/main.rs +++ b/crates/mad/examples/comprehensive/main.rs @@ -7,7 +7,6 @@ //! - Graceful shutdown with cancellation tokens //! - Concurrent component execution -use async_trait::async_trait; use notmad::{Component, Mad, MadError}; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -21,7 +20,6 @@ struct WebServer { request_count: Arc, } -#[async_trait] impl Component for WebServer { fn name(&self) -> Option { Some(format!("web-server-{}", self.port)) @@ -81,7 +79,6 @@ struct JobProcessor { processing_interval: Duration, } -#[async_trait] impl Component for JobProcessor { fn name(&self) -> Option { Some(format!("job-processor-{}", self.queue_name)) @@ -139,7 +136,6 @@ struct HealthChecker { check_interval: Duration, } -#[async_trait] impl Component for HealthChecker { fn name(&self) -> Option { Some("health-checker".to_string()) @@ -181,7 +177,6 @@ struct FailingComponent { fail_after: Duration, } -#[async_trait] impl Component for FailingComponent { fn name(&self) -> Option { Some("failing-component".to_string()) @@ -209,7 +204,6 @@ impl Component for FailingComponent { /// Debug component that logs system status periodically. struct DebugComponent; -#[async_trait] impl Component for DebugComponent { fn name(&self) -> Option { Some("debug-component".to_string()) diff --git a/crates/mad/examples/error_log/main.rs b/crates/mad/examples/error_log/main.rs index f60496d..2213694 100644 --- a/crates/mad/examples/error_log/main.rs +++ b/crates/mad/examples/error_log/main.rs @@ -1,16 +1,14 @@ -use async_trait::async_trait; use rand::Rng; use tokio_util::sync::CancellationToken; use tracing::Level; struct ErrorServer {} -#[async_trait] impl notmad::Component for ErrorServer { fn name(&self) -> Option { Some("ErrorServer".into()) } - async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { + async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> { let millis_wait = rand::thread_rng().gen_range(500..3000); tracing::debug!("waiting: {}ms", millis_wait); diff --git a/crates/mad/examples/fn/main.rs b/crates/mad/examples/fn/main.rs index a59c176..fa1c0ed 100644 --- a/crates/mad/examples/fn/main.rs +++ b/crates/mad/examples/fn/main.rs @@ -1,10 +1,8 @@ -use async_trait::async_trait; use rand::Rng; use tokio_util::sync::CancellationToken; use tracing::Level; struct WaitServer {} -#[async_trait] impl notmad::Component for WaitServer { fn name(&self) -> Option { Some("WaitServer".into()) diff --git a/crates/mad/examples/multi_service/main.rs b/crates/mad/examples/multi_service/main.rs index b9b1222..2cf4161 100644 --- a/crates/mad/examples/multi_service/main.rs +++ b/crates/mad/examples/multi_service/main.rs @@ -3,7 +3,6 @@ //! This example shows how to run a web server, queue processor, and //! scheduled task together, with graceful shutdown on Ctrl+C. -use async_trait::async_trait; use notmad::{Component, Mad, MadError}; use tokio::time::{Duration, interval}; use tokio_util::sync::CancellationToken; @@ -16,7 +15,6 @@ struct WebServer { port: u16, } -#[async_trait] impl Component for WebServer { fn name(&self) -> Option { Some(format!("WebServer:{}", self.port)) @@ -70,7 +68,6 @@ struct QueueProcessor { queue_name: String, } -#[async_trait] impl Component for QueueProcessor { fn name(&self) -> Option { Some(format!("QueueProcessor:{}", self.queue_name)) @@ -116,7 +113,6 @@ struct ScheduledTask { interval_secs: u64, } -#[async_trait] impl Component for ScheduledTask { fn name(&self) -> Option { Some(format!("ScheduledTask:{}", self.task_name)) diff --git a/crates/mad/examples/nested_errors/main.rs b/crates/mad/examples/nested_errors/main.rs index ba4c166..6ab5f21 100644 --- a/crates/mad/examples/nested_errors/main.rs +++ b/crates/mad/examples/nested_errors/main.rs @@ -1,11 +1,9 @@ -use async_trait::async_trait; use tokio_util::sync::CancellationToken; struct NestedErrorComponent { name: String, } -#[async_trait] impl notmad::Component for NestedErrorComponent { fn name(&self) -> Option { Some(self.name.clone()) @@ -28,7 +26,6 @@ impl notmad::Component for NestedErrorComponent { struct AnotherFailingComponent; -#[async_trait] impl notmad::Component for AnotherFailingComponent { fn name(&self) -> Option { Some("another-component".into()) diff --git a/crates/mad/examples/signals/main.rs b/crates/mad/examples/signals/main.rs index ed9c086..1fdd0e8 100644 --- a/crates/mad/examples/signals/main.rs +++ b/crates/mad/examples/signals/main.rs @@ -1,16 +1,14 @@ -use async_trait::async_trait; use rand::Rng; use tokio_util::sync::CancellationToken; use tracing::Level; struct WaitServer {} -#[async_trait] impl notmad::Component for WaitServer { fn name(&self) -> Option { Some("WaitServer".into()) } - async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { + async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> { let millis_wait = rand::thread_rng().gen_range(500..3000); tracing::debug!("waiting: {}ms", millis_wait); @@ -23,7 +21,6 @@ impl notmad::Component for WaitServer { } struct RespectCancel {} -#[async_trait] impl notmad::Component for RespectCancel { fn name(&self) -> Option { Some("RespectCancel".into()) @@ -38,13 +35,12 @@ impl notmad::Component for RespectCancel { } struct NeverStopServer {} -#[async_trait] impl notmad::Component for NeverStopServer { fn name(&self) -> Option { Some("NeverStopServer".into()) } - async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { + async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> { // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely tokio::time::sleep(std::time::Duration::from_millis(999999999)).await; diff --git a/crates/mad/src/lib.rs b/crates/mad/src/lib.rs index 2dbef9c..f8e4640 100644 --- a/crates/mad/src/lib.rs +++ b/crates/mad/src/lib.rs @@ -77,11 +77,8 @@ use futures::stream::FuturesUnordered; use futures_util::StreamExt; -use std::{error::Error, fmt::Display, sync::Arc}; -use tokio::{ - signal::unix::{SignalKind, signal}, - task::JoinError, -}; +use std::{error::Error, fmt::Display, pin::Pin, sync::Arc}; +use tokio::signal::unix::{SignalKind, signal}; use tokio_util::sync::CancellationToken; @@ -229,7 +226,7 @@ impl Display for AggregateError { /// # } /// ``` pub struct Mad { - components: Vec>, + components: Vec, should_cancel: Option, } @@ -397,7 +394,7 @@ impl Mad { /// # Arguments /// /// * `should_cancel` - Duration to wait after cancellation before forcing shutdown. - /// Pass `None` to wait indefinitely. + /// Pass `None` to wait indefinitely. /// /// # Example /// @@ -669,8 +666,7 @@ async fn signal_unix_terminate() { /// } /// } /// ``` -#[async_trait::async_trait] -pub trait Component { +pub trait Component: Send + Sync + 'static { /// Returns an optional name for the component. /// /// This name is used in logging and error messages to identify @@ -698,8 +694,8 @@ pub trait Component { /// /// If setup fails with an error other than `SetupNotDefined`, /// the entire application will stop before any components start running. - async fn setup(&self) -> Result<(), MadError> { - Err(MadError::SetupNotDefined) + fn setup(&self) -> impl Future> + Send + '_ { + async { Err(MadError::SetupNotDefined) } } /// Main execution phase of the component. @@ -721,7 +717,10 @@ pub trait Component { /// # Errors /// /// Any error returned will trigger shutdown of all other components. - async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError>; + fn run( + &self, + cancellation_token: CancellationToken, + ) -> impl Future> + Send + '_; /// Optional cleanup phase called after the component stops. /// @@ -738,8 +737,73 @@ pub trait Component { /// /// Errors during close are logged but don't prevent other components /// from closing. + fn close(&self) -> impl Future> + Send + '_ { + async { Err(MadError::CloseNotDefined) } + } +} + +trait AsyncComponent: Send + Sync + 'static { + fn name_async(&self) -> Option; + + fn setup_async(&self) -> Pin> + Send + '_>>; + + fn run_async( + &self, + cancellation_token: CancellationToken, + ) -> Pin> + Send + '_>>; + + fn close_async(&self) -> Pin> + Send + '_>>; +} + +impl AsyncComponent for E { + #[inline(always)] + fn name_async(&self) -> Option { + self.name() + } + + #[inline(always)] + fn setup_async(&self) -> Pin> + Send + '_>> { + Box::pin(self.setup()) + } + + #[inline(always)] + fn run_async( + &self, + cancellation_token: CancellationToken, + ) -> Pin> + Send + '_>> { + Box::pin(self.run(cancellation_token)) + } + + #[inline(always)] + fn close_async(&self) -> Pin> + Send + '_>> { + Box::pin(self.close()) + } +} + +#[derive(Clone)] +pub struct SharedComponent { + component: Arc, +} + +impl SharedComponent { + #[inline(always)] + pub fn name(&self) -> Option { + self.component.name_async() + } + + #[inline(always)] + async fn setup(&self) -> Result<(), MadError> { + self.component.setup_async().await + } + + #[inline(always)] + async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> { + self.component.run_async(cancellation_token).await + } + + #[inline(always)] async fn close(&self) -> Result<(), MadError> { - Err(MadError::CloseNotDefined) + self.component.close_async().await } } @@ -769,12 +833,14 @@ pub trait Component { /// ``` pub trait IntoComponent { /// Converts self into an Arc-wrapped component. - fn into_component(self) -> Arc; + fn into_component(self) -> SharedComponent; } -impl IntoComponent for T { - fn into_component(self) -> Arc { - Arc::new(self) +impl IntoComponent for T { + fn into_component(self) -> SharedComponent { + SharedComponent { + component: Arc::new(self), + } } } @@ -798,7 +864,6 @@ where } } -#[async_trait::async_trait] impl Component for ClosureComponent where F: Fn(CancellationToken) -> Fut + Send + Sync + 'static, @@ -812,7 +877,6 @@ where #[cfg(test)] mod tests { use super::*; - use anyhow::Context; #[test] fn test_error_chaining_display() { @@ -909,7 +973,6 @@ mod tests { async fn test_component_error_propagation() { struct FailingComponent; - #[async_trait::async_trait] impl Component for FailingComponent { fn name(&self) -> Option { Some("test-component".to_string()) diff --git a/crates/mad/src/waiter.rs b/crates/mad/src/waiter.rs index 5426326..d6d4df1 100644 --- a/crates/mad/src/waiter.rs +++ b/crates/mad/src/waiter.rs @@ -4,19 +4,15 @@ //! without performing any work. Useful for keeping the application alive //! or as placeholders in conditional component loading. -use std::sync::Arc; - -use async_trait::async_trait; use tokio_util::sync::CancellationToken; -use crate::{Component, MadError}; +use crate::{Component, IntoComponent, MadError, SharedComponent}; /// A default waiter component that panics if run. /// /// This is used internally as a placeholder that should never /// actually be executed. -pub struct DefaultWaiter {} -#[async_trait] +pub struct DefaultWaiter; impl Component for DefaultWaiter { async fn run(&self, _cancellation_token: CancellationToken) -> Result<(), MadError> { panic!("should never be called"); @@ -38,13 +34,13 @@ impl Component for DefaultWaiter { /// let waiter = Waiter::new(service.into_component()); /// ``` pub struct Waiter { - comp: Arc, + comp: SharedComponent, } impl Default for Waiter { fn default() -> Self { Self { - comp: Arc::new(DefaultWaiter {}), + comp: DefaultWaiter {}.into_component(), } } } @@ -54,12 +50,11 @@ impl Waiter { /// /// The wrapped component's name will be used (prefixed with "waiter/"), /// but its run method will not be called. - pub fn new(c: Arc) -> Self { + pub fn new(c: SharedComponent) -> Self { Self { comp: c } } } -#[async_trait] impl Component for Waiter { /// Returns the name of the waiter, prefixed with "waiter/". /// diff --git a/crates/mad/tests/mod.rs b/crates/mad/tests/mod.rs index 602deea..525c342 100644 --- a/crates/mad/tests/mod.rs +++ b/crates/mad/tests/mod.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use async_trait::async_trait; use notmad::{Component, Mad, MadError}; use rand::Rng; use tokio::sync::Mutex; @@ -9,13 +8,12 @@ use tracing_test::traced_test; struct NeverEndingRun {} -#[async_trait] impl Component for NeverEndingRun { fn name(&self) -> Option { Some("NeverEndingRun".into()) } - async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { + async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> { let millis_wait = rand::thread_rng().gen_range(50..1000); tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await;