feat: replace async-trait with erased box type

Signed-off-by: kjuulh <contact@kjuulh.io>
This commit is contained in:
2026-01-07 11:05:33 +01:00
parent 5e60a272f7
commit f0c90edce9
13 changed files with 93 additions and 74 deletions

12
Cargo.lock generated
View File

@@ -17,17 +17,6 @@ version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "async-trait"
version = "0.1.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.4.0" version = "1.4.0"
@@ -225,7 +214,6 @@ name = "notmad"
version = "0.10.0" version = "0.10.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"futures", "futures",
"futures-util", "futures-util",
"rand", "rand",

View File

@@ -6,7 +6,6 @@ resolver = "2"
version = "0.10.0" version = "0.10.0"
[workspace.dependencies] [workspace.dependencies]
mad = { path = "crates/mad" }
anyhow = { version = "1.0.71" } anyhow = { version = "1.0.71" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }

View File

@@ -10,7 +10,6 @@ readme = "../../README.md"
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
async-trait = "0.1.81"
futures = "0.3.30" futures = "0.3.30"
futures-util = "0.3.30" futures-util = "0.3.30"
rand = "0.9.0" rand = "0.9.0"

View File

@@ -1,16 +1,14 @@
use async_trait::async_trait;
use rand::Rng; use rand::Rng;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::Level; use tracing::Level;
struct WaitServer {} struct WaitServer {}
#[async_trait]
impl notmad::Component for WaitServer { impl notmad::Component for WaitServer {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("WaitServer".into()) Some("WaitServer".into())
} }
async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> {
let millis_wait = rand::thread_rng().gen_range(500..3000); let millis_wait = rand::thread_rng().gen_range(500..3000);
tracing::debug!("waiting: {}ms", millis_wait); tracing::debug!("waiting: {}ms", millis_wait);

View File

@@ -7,7 +7,6 @@
//! - Graceful shutdown with cancellation tokens //! - Graceful shutdown with cancellation tokens
//! - Concurrent component execution //! - Concurrent component execution
use async_trait::async_trait;
use notmad::{Component, Mad, MadError}; use notmad::{Component, Mad, MadError};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
@@ -21,7 +20,6 @@ struct WebServer {
request_count: Arc<AtomicUsize>, request_count: Arc<AtomicUsize>,
} }
#[async_trait]
impl Component for WebServer { impl Component for WebServer {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some(format!("web-server-{}", self.port)) Some(format!("web-server-{}", self.port))
@@ -81,7 +79,6 @@ struct JobProcessor {
processing_interval: Duration, processing_interval: Duration,
} }
#[async_trait]
impl Component for JobProcessor { impl Component for JobProcessor {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some(format!("job-processor-{}", self.queue_name)) Some(format!("job-processor-{}", self.queue_name))
@@ -139,7 +136,6 @@ struct HealthChecker {
check_interval: Duration, check_interval: Duration,
} }
#[async_trait]
impl Component for HealthChecker { impl Component for HealthChecker {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("health-checker".to_string()) Some("health-checker".to_string())
@@ -181,7 +177,6 @@ struct FailingComponent {
fail_after: Duration, fail_after: Duration,
} }
#[async_trait]
impl Component for FailingComponent { impl Component for FailingComponent {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("failing-component".to_string()) Some("failing-component".to_string())
@@ -209,7 +204,6 @@ impl Component for FailingComponent {
/// Debug component that logs system status periodically. /// Debug component that logs system status periodically.
struct DebugComponent; struct DebugComponent;
#[async_trait]
impl Component for DebugComponent { impl Component for DebugComponent {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("debug-component".to_string()) Some("debug-component".to_string())

View File

@@ -1,16 +1,14 @@
use async_trait::async_trait;
use rand::Rng; use rand::Rng;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::Level; use tracing::Level;
struct ErrorServer {} struct ErrorServer {}
#[async_trait]
impl notmad::Component for ErrorServer { impl notmad::Component for ErrorServer {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("ErrorServer".into()) Some("ErrorServer".into())
} }
async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> {
let millis_wait = rand::thread_rng().gen_range(500..3000); let millis_wait = rand::thread_rng().gen_range(500..3000);
tracing::debug!("waiting: {}ms", millis_wait); tracing::debug!("waiting: {}ms", millis_wait);

View File

@@ -1,10 +1,8 @@
use async_trait::async_trait;
use rand::Rng; use rand::Rng;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::Level; use tracing::Level;
struct WaitServer {} struct WaitServer {}
#[async_trait]
impl notmad::Component for WaitServer { impl notmad::Component for WaitServer {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("WaitServer".into()) Some("WaitServer".into())

View File

@@ -3,7 +3,6 @@
//! This example shows how to run a web server, queue processor, and //! This example shows how to run a web server, queue processor, and
//! scheduled task together, with graceful shutdown on Ctrl+C. //! scheduled task together, with graceful shutdown on Ctrl+C.
use async_trait::async_trait;
use notmad::{Component, Mad, MadError}; use notmad::{Component, Mad, MadError};
use tokio::time::{Duration, interval}; use tokio::time::{Duration, interval};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
@@ -16,7 +15,6 @@ struct WebServer {
port: u16, port: u16,
} }
#[async_trait]
impl Component for WebServer { impl Component for WebServer {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some(format!("WebServer:{}", self.port)) Some(format!("WebServer:{}", self.port))
@@ -70,7 +68,6 @@ struct QueueProcessor {
queue_name: String, queue_name: String,
} }
#[async_trait]
impl Component for QueueProcessor { impl Component for QueueProcessor {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some(format!("QueueProcessor:{}", self.queue_name)) Some(format!("QueueProcessor:{}", self.queue_name))
@@ -116,7 +113,6 @@ struct ScheduledTask {
interval_secs: u64, interval_secs: u64,
} }
#[async_trait]
impl Component for ScheduledTask { impl Component for ScheduledTask {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some(format!("ScheduledTask:{}", self.task_name)) Some(format!("ScheduledTask:{}", self.task_name))

View File

@@ -1,11 +1,9 @@
use async_trait::async_trait;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
struct NestedErrorComponent { struct NestedErrorComponent {
name: String, name: String,
} }
#[async_trait]
impl notmad::Component for NestedErrorComponent { impl notmad::Component for NestedErrorComponent {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some(self.name.clone()) Some(self.name.clone())
@@ -28,7 +26,6 @@ impl notmad::Component for NestedErrorComponent {
struct AnotherFailingComponent; struct AnotherFailingComponent;
#[async_trait]
impl notmad::Component for AnotherFailingComponent { impl notmad::Component for AnotherFailingComponent {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("another-component".into()) Some("another-component".into())

View File

@@ -1,16 +1,14 @@
use async_trait::async_trait;
use rand::Rng; use rand::Rng;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::Level; use tracing::Level;
struct WaitServer {} struct WaitServer {}
#[async_trait]
impl notmad::Component for WaitServer { impl notmad::Component for WaitServer {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("WaitServer".into()) Some("WaitServer".into())
} }
async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> {
let millis_wait = rand::thread_rng().gen_range(500..3000); let millis_wait = rand::thread_rng().gen_range(500..3000);
tracing::debug!("waiting: {}ms", millis_wait); tracing::debug!("waiting: {}ms", millis_wait);
@@ -23,7 +21,6 @@ impl notmad::Component for WaitServer {
} }
struct RespectCancel {} struct RespectCancel {}
#[async_trait]
impl notmad::Component for RespectCancel { impl notmad::Component for RespectCancel {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("RespectCancel".into()) Some("RespectCancel".into())
@@ -38,13 +35,12 @@ impl notmad::Component for RespectCancel {
} }
struct NeverStopServer {} struct NeverStopServer {}
#[async_trait]
impl notmad::Component for NeverStopServer { impl notmad::Component for NeverStopServer {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("NeverStopServer".into()) Some("NeverStopServer".into())
} }
async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> {
// Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely
tokio::time::sleep(std::time::Duration::from_millis(999999999)).await; tokio::time::sleep(std::time::Duration::from_millis(999999999)).await;

View File

@@ -77,11 +77,8 @@
use futures::stream::FuturesUnordered; use futures::stream::FuturesUnordered;
use futures_util::StreamExt; use futures_util::StreamExt;
use std::{error::Error, fmt::Display, sync::Arc}; use std::{error::Error, fmt::Display, pin::Pin, sync::Arc};
use tokio::{ use tokio::signal::unix::{SignalKind, signal};
signal::unix::{SignalKind, signal},
task::JoinError,
};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
@@ -229,7 +226,7 @@ impl Display for AggregateError {
/// # } /// # }
/// ``` /// ```
pub struct Mad { pub struct Mad {
components: Vec<Arc<dyn Component + Send + Sync + 'static>>, components: Vec<SharedComponent>,
should_cancel: Option<std::time::Duration>, should_cancel: Option<std::time::Duration>,
} }
@@ -397,7 +394,7 @@ impl Mad {
/// # Arguments /// # Arguments
/// ///
/// * `should_cancel` - Duration to wait after cancellation before forcing shutdown. /// * `should_cancel` - Duration to wait after cancellation before forcing shutdown.
/// Pass `None` to wait indefinitely. /// Pass `None` to wait indefinitely.
/// ///
/// # Example /// # Example
/// ///
@@ -669,8 +666,7 @@ async fn signal_unix_terminate() {
/// } /// }
/// } /// }
/// ``` /// ```
#[async_trait::async_trait] pub trait Component: Send + Sync + 'static {
pub trait Component {
/// Returns an optional name for the component. /// Returns an optional name for the component.
/// ///
/// This name is used in logging and error messages to identify /// This name is used in logging and error messages to identify
@@ -698,8 +694,8 @@ pub trait Component {
/// ///
/// If setup fails with an error other than `SetupNotDefined`, /// If setup fails with an error other than `SetupNotDefined`,
/// the entire application will stop before any components start running. /// the entire application will stop before any components start running.
async fn setup(&self) -> Result<(), MadError> { fn setup(&self) -> impl Future<Output = Result<(), MadError>> + Send + '_ {
Err(MadError::SetupNotDefined) async { Err(MadError::SetupNotDefined) }
} }
/// Main execution phase of the component. /// Main execution phase of the component.
@@ -721,7 +717,10 @@ pub trait Component {
/// # Errors /// # Errors
/// ///
/// Any error returned will trigger shutdown of all other components. /// Any error returned will trigger shutdown of all other components.
async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError>; fn run(
&self,
cancellation_token: CancellationToken,
) -> impl Future<Output = Result<(), MadError>> + Send + '_;
/// Optional cleanup phase called after the component stops. /// Optional cleanup phase called after the component stops.
/// ///
@@ -738,8 +737,73 @@ pub trait Component {
/// ///
/// Errors during close are logged but don't prevent other components /// Errors during close are logged but don't prevent other components
/// from closing. /// from closing.
fn close(&self) -> impl Future<Output = Result<(), MadError>> + Send + '_ {
async { Err(MadError::CloseNotDefined) }
}
}
trait AsyncComponent: Send + Sync + 'static {
fn name_async(&self) -> Option<String>;
fn setup_async(&self) -> Pin<Box<dyn Future<Output = Result<(), MadError>> + Send + '_>>;
fn run_async(
&self,
cancellation_token: CancellationToken,
) -> Pin<Box<dyn Future<Output = Result<(), MadError>> + Send + '_>>;
fn close_async(&self) -> Pin<Box<dyn Future<Output = Result<(), MadError>> + Send + '_>>;
}
impl<E: Component> AsyncComponent for E {
#[inline(always)]
fn name_async(&self) -> Option<String> {
self.name()
}
#[inline(always)]
fn setup_async(&self) -> Pin<Box<dyn Future<Output = Result<(), MadError>> + Send + '_>> {
Box::pin(self.setup())
}
#[inline(always)]
fn run_async(
&self,
cancellation_token: CancellationToken,
) -> Pin<Box<dyn Future<Output = Result<(), MadError>> + Send + '_>> {
Box::pin(self.run(cancellation_token))
}
#[inline(always)]
fn close_async(&self) -> Pin<Box<dyn Future<Output = Result<(), MadError>> + Send + '_>> {
Box::pin(self.close())
}
}
#[derive(Clone)]
pub struct SharedComponent {
component: Arc<dyn AsyncComponent + Send + Sync + 'static>,
}
impl SharedComponent {
#[inline(always)]
pub fn name(&self) -> Option<String> {
self.component.name_async()
}
#[inline(always)]
async fn setup(&self) -> Result<(), MadError> {
self.component.setup_async().await
}
#[inline(always)]
async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> {
self.component.run_async(cancellation_token).await
}
#[inline(always)]
async fn close(&self) -> Result<(), MadError> { async fn close(&self) -> Result<(), MadError> {
Err(MadError::CloseNotDefined) self.component.close_async().await
} }
} }
@@ -769,12 +833,14 @@ pub trait Component {
/// ``` /// ```
pub trait IntoComponent { pub trait IntoComponent {
/// Converts self into an Arc-wrapped component. /// Converts self into an Arc-wrapped component.
fn into_component(self) -> Arc<dyn Component + Send + Sync + 'static>; fn into_component(self) -> SharedComponent;
} }
impl<T: Component + Send + Sync + 'static> IntoComponent for T { impl<T: Component> IntoComponent for T {
fn into_component(self) -> Arc<dyn Component + Send + Sync + 'static> { fn into_component(self) -> SharedComponent {
Arc::new(self) SharedComponent {
component: Arc::new(self),
}
} }
} }
@@ -798,7 +864,6 @@ where
} }
} }
#[async_trait::async_trait]
impl<F, Fut> Component for ClosureComponent<F, Fut> impl<F, Fut> Component for ClosureComponent<F, Fut>
where where
F: Fn(CancellationToken) -> Fut + Send + Sync + 'static, F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
@@ -812,7 +877,6 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use anyhow::Context;
#[test] #[test]
fn test_error_chaining_display() { fn test_error_chaining_display() {
@@ -909,7 +973,6 @@ mod tests {
async fn test_component_error_propagation() { async fn test_component_error_propagation() {
struct FailingComponent; struct FailingComponent;
#[async_trait::async_trait]
impl Component for FailingComponent { impl Component for FailingComponent {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("test-component".to_string()) Some("test-component".to_string())

View File

@@ -4,19 +4,15 @@
//! without performing any work. Useful for keeping the application alive //! without performing any work. Useful for keeping the application alive
//! or as placeholders in conditional component loading. //! or as placeholders in conditional component loading.
use std::sync::Arc;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::{Component, MadError}; use crate::{Component, IntoComponent, MadError, SharedComponent};
/// A default waiter component that panics if run. /// A default waiter component that panics if run.
/// ///
/// This is used internally as a placeholder that should never /// This is used internally as a placeholder that should never
/// actually be executed. /// actually be executed.
pub struct DefaultWaiter {} pub struct DefaultWaiter;
#[async_trait]
impl Component for DefaultWaiter { impl Component for DefaultWaiter {
async fn run(&self, _cancellation_token: CancellationToken) -> Result<(), MadError> { async fn run(&self, _cancellation_token: CancellationToken) -> Result<(), MadError> {
panic!("should never be called"); panic!("should never be called");
@@ -38,13 +34,13 @@ impl Component for DefaultWaiter {
/// let waiter = Waiter::new(service.into_component()); /// let waiter = Waiter::new(service.into_component());
/// ``` /// ```
pub struct Waiter { pub struct Waiter {
comp: Arc<dyn Component + Send + Sync + 'static>, comp: SharedComponent,
} }
impl Default for Waiter { impl Default for Waiter {
fn default() -> Self { fn default() -> Self {
Self { Self {
comp: Arc::new(DefaultWaiter {}), comp: DefaultWaiter {}.into_component(),
} }
} }
} }
@@ -54,12 +50,11 @@ impl Waiter {
/// ///
/// The wrapped component's name will be used (prefixed with "waiter/"), /// The wrapped component's name will be used (prefixed with "waiter/"),
/// but its run method will not be called. /// but its run method will not be called.
pub fn new(c: Arc<dyn Component + Send + Sync + 'static>) -> Self { pub fn new(c: SharedComponent) -> Self {
Self { comp: c } Self { comp: c }
} }
} }
#[async_trait]
impl Component for Waiter { impl Component for Waiter {
/// Returns the name of the waiter, prefixed with "waiter/". /// Returns the name of the waiter, prefixed with "waiter/".
/// ///

View File

@@ -1,6 +1,5 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait;
use notmad::{Component, Mad, MadError}; use notmad::{Component, Mad, MadError};
use rand::Rng; use rand::Rng;
use tokio::sync::Mutex; use tokio::sync::Mutex;
@@ -9,13 +8,12 @@ use tracing_test::traced_test;
struct NeverEndingRun {} struct NeverEndingRun {}
#[async_trait]
impl Component for NeverEndingRun { impl Component for NeverEndingRun {
fn name(&self) -> Option<String> { fn name(&self) -> Option<String> {
Some("NeverEndingRun".into()) Some("NeverEndingRun".into())
} }
async fn run(&self, cancellation: CancellationToken) -> Result<(), notmad::MadError> { async fn run(&self, _cancellation: CancellationToken) -> Result<(), notmad::MadError> {
let millis_wait = rand::thread_rng().gen_range(50..1000); let millis_wait = rand::thread_rng().gen_range(50..1000);
tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await; tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await;