Compare commits

...

8 Commits

Author SHA1 Message Date
e937501b35 chore(deps): update rust crate tracing-subscriber to v0.3.23 2026-03-14 01:58:10 +00:00
5e1cd2b1e7 feat: add weight 2026-03-06 12:55:10 +01:00
0fa906a8cf chore(deps): update rust crate uuid to v1.22.0 2026-03-06 01:48:18 +00:00
f2efd1ba59 feat: use default provider 2026-03-05 21:41:15 +01:00
cad8c0e307 feat: add providers 2026-03-05 21:29:06 +01:00
126776f389 feat: add tls 2026-03-05 21:11:09 +01:00
4977cb0485 feat: ship without sqlx 2026-03-05 18:56:49 +01:00
0f24a41435 feat: replace macros 2026-03-05 17:45:07 +01:00
13 changed files with 1497 additions and 1312 deletions

1572
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -323,6 +323,7 @@ async fn run_app<
) -> anyhow::Result<()>
where
TOperator::Specifications: Send + Sync,
<B as Backend>::Error: Send + Sync + 'static,
{
// Spawn refresh task
let app_clone = app.clone();

View File

@@ -24,13 +24,15 @@ serde_json = "1.0.148"
sha2 = "0.10.9"
tokio-util = "0.7.18"
sqlx = { version = "0.8.6", optional = true, features = [
"chrono",
"json",
"postgres",
"runtime-tokio",
"uuid",
tokio-postgres = { version = "0.7", optional = true, features = [
"with-uuid-1",
"with-serde_json-1",
"with-chrono-0_4",
] }
chrono = { version = "0.4", optional = true }
tokio-postgres-rustls = { version = "0.13", optional = true }
rustls = { version = "0.23", optional = true, default-features = false }
rustls-native-certs = { version = "0.8", optional = true }
[dev-dependencies]
@@ -39,4 +41,7 @@ tracing-test = { version = "0.2.5", features = ["no-env-filter"] }
[features]
default = []
postgres = ["dep:sqlx"]
postgres = ["dep:tokio-postgres", "dep:chrono"]
postgres-tls = ["postgres", "dep:tokio-postgres-rustls", "dep:rustls", "dep:rustls-native-certs"]
postgres-tls-ring = ["postgres-tls", "rustls/ring"]
postgres-tls-aws-lc-rs = ["postgres-tls", "rustls/aws_lc_rs"]

View File

@@ -1,21 +0,0 @@
-- Add migration script here
create table manifests (
id UUID PRIMARY KEY NOT NULL,
generation BIGINT NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
status JSONB NOT NULL,
manifest_content JSONB NOT NULL,
manifest_hash BYTEA NOT NULL,
created TIMESTAMPTZ NOT NULL,
updated TIMESTAMPTZ NOT NULL,
lease_owner_id UUID,
lease_last_updated TIMESTAMPTZ
);
CREATE UNIQUE INDEX idx_manifest_name ON manifests(name, kind);

View File

@@ -6,6 +6,7 @@ use crate::{
Specification,
control_plane::backing_store::in_process::BackingStoreInProcess,
manifests::{Manifest, ManifestState, WorkerId},
operator_state::ClusterStats,
};
pub mod in_process;
@@ -96,4 +97,14 @@ pub trait BackingStoreEdge<T: Specification>: Send + Sync + Clone {
&self,
manifest: &ManifestState<T>,
) -> impl std::future::Future<Output = anyhow::Result<()>> + Send;
/// Returns cluster-wide weight distribution statistics.
/// Returns None when cluster stats are not available (disables fair-share).
fn get_cluster_stats(
&self,
worker_id: &Uuid,
) -> impl std::future::Future<Output = anyhow::Result<Option<ClusterStats>>> + Send {
let _ = worker_id;
async { Ok(None) }
}
}

View File

@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};
use jiff::ToSpan;
use sha2::{Digest, Sha256};
@@ -11,6 +11,7 @@ use crate::{
Manifest, ManifestChangeEvent, ManifestChangeEventType, ManifestLease, ManifestState,
ManifestStatus, ManifestStatusState, WorkerId,
},
operator_state::ClusterStats,
};
#[derive(Clone, Default)]
@@ -64,12 +65,13 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStoreInProcess<T> {
.find(|m| m.manifest.name == manifest_state.manifest.name)
{
Some(manifest) => {
let mut manifest_state = manifest_state.clone();
if let Some(lease) = manifest_state.lease.as_mut() {
lease.last_seen = jiff::Timestamp::now();
if manifest.generation != manifest_state.generation {
anyhow::bail!("failed to update lease, generation mismatch");
}
manifest.lease = manifest_state.lease
if let Some(lease) = manifest.lease.as_mut() {
lease.last_seen = jiff::Timestamp::now();
}
}
None => anyhow::bail!("manifest is not found"),
}
@@ -90,12 +92,20 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStoreInProcess<T> {
.find(|m| m.manifest.name == manifest_state.manifest.name)
{
Some(manifest) => {
let mut manifest_state = manifest_state.clone();
manifest_state.lease = Some(ManifestLease {
// CAS: only acquire if generation matches (prevents race conditions)
if manifest.generation != manifest_state.generation {
anyhow::bail!(
"failed to acquire lease: generation mismatch (expected {}, got {})",
manifest_state.generation,
manifest.generation
);
}
manifest.lease = Some(ManifestLease {
owner: *worker_id,
last_seen: jiff::Timestamp::now(),
});
manifest.lease = manifest_state.lease
manifest.generation += 1;
}
None => anyhow::bail!("manifest is not found"),
}
@@ -183,6 +193,37 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStoreInProcess<T> {
Ok(())
}
/// Computes cluster-wide weight statistics from the in-process store.
///
/// A lease counts as active when its `last_seen` is within the last
/// 10 seconds. NOTE(review): the postgres backend uses a 30-second
/// window for the same purpose — confirm the difference is intentional.
///
/// Always returns `Some`: this backend can always compute stats.
async fn get_cluster_stats(
    &self,
    worker_id: &WorkerId,
) -> anyhow::Result<Option<ClusterStats>> {
    // Freshness cutoff: leases last seen at or before this instant are
    // stale and excluded. (Renamed from `now`, which was misleading —
    // the value is now() minus 10 seconds, not the current time.)
    let cutoff = jiff::Timestamp::now().checked_sub(10.second())?;
    let manifests = self.manifests.read().await;
    let mut total_weight = 0u64;
    let mut my_weight = 0u64;
    let mut active_workers = HashSet::new();
    for m in manifests.iter() {
        if let Some(lease) = &m.lease {
            // Only fresh leases contribute to the totals.
            if lease.last_seen > cutoff {
                let w = m.manifest.spec.weight();
                total_weight += w;
                active_workers.insert(lease.owner);
                if &lease.owner == worker_id {
                    my_weight += w;
                }
            }
        }
    }
    Ok(Some(ClusterStats {
        total_weight,
        active_workers: active_workers.len(),
        my_weight,
    }))
}
}
impl<T: Specification> BackingStoreInProcess<T> {

View File

@@ -1,227 +1,260 @@
use std::marker::PhantomData;
use std::{marker::PhantomData, sync::Arc};
use anyhow::Context;
use chrono::{DateTime, Utc};
use jiff::Timestamp;
use sha2::Digest;
use sqlx::PgPool;
use tokio_postgres::{Client, Row};
use crate::{
Specification,
manifests::{
Manifest, ManifestLease, ManifestState, ManifestStatus, ManifestStatusState, WorkerId,
},
operator_state::ClusterStats,
stores::BackingStoreEdge,
};
/// Maps a single `manifests` table row into a typed `ManifestState<T>`.
///
/// NOTE(review): `Row::get` panics on a missing column or a type mismatch
/// rather than returning an error — the column list of every SELECT feeding
/// this function must stay in sync with the names used here.
fn row_to_manifest_state<T: Specification>(row: &Row) -> anyhow::Result<ManifestState<T>> {
    let manifest_content: serde_json::Value = row.get("manifest_content");
    let status: serde_json::Value = row.get("status");
    let manifest_hash: Vec<u8> = row.get("manifest_hash");
    let generation: i64 = row.get("generation");
    // TIMESTAMPTZ columns come back as chrono values and are converted to
    // jiff timestamps below — millisecond precision only.
    let created: DateTime<Utc> = row.get("created");
    let updated: DateTime<Utc> = row.get("updated");
    let lease_owner_id: Option<uuid::Uuid> = row.get("lease_owner_id");
    let lease_last_updated: Option<DateTime<Utc>> = row.get("lease_last_updated");
    // Deserialize the JSONB manifest payload into the caller's spec type.
    let content: Manifest<T> = serde_json::from_value(manifest_content)?;
    Ok(ManifestState {
        manifest: content,
        manifest_hash,
        // Stored as signed BIGINT; assumed never negative — TODO confirm.
        generation: generation as u64,
        status: serde_json::from_value(status)?,
        created: Timestamp::from_millisecond(created.timestamp_millis())?,
        updated: Timestamp::from_millisecond(updated.timestamp_millis())?,
        // A lease exists only when both columns are set; a half-set pair
        // (not expected to occur) is treated as "no lease".
        lease: match (lease_owner_id, lease_last_updated) {
            (Some(owner_id), Some(last_updated)) => Some(ManifestLease {
                owner: owner_id,
                last_seen: Timestamp::from_millisecond(last_updated.timestamp_millis())?,
            }),
            _ => None,
        },
    })
}
#[derive(Clone)]
pub struct BackingStorePostgres<T: Specification> {
pool: PgPool,
client: Arc<Client>,
_marker: PhantomData<T>,
}
impl<T: Specification> BackingStorePostgres<T> {
pub(crate) async fn new(database_url: &str) -> anyhow::Result<Self> {
tracing::debug!("connecting to postgres database");
let pool = sqlx::PgPool::connect(database_url)
.await
.context("failed to connect to database")?;
let client = Self::connect(database_url).await?;
tracing::debug!("migrating database");
sqlx::migrate!("migrations/postgres/")
.run(&pool)
client
.batch_execute(
r#"
CREATE TABLE IF NOT EXISTS manifests (
id UUID PRIMARY KEY NOT NULL,
generation BIGINT NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
status JSONB NOT NULL,
manifest_content JSONB NOT NULL,
manifest_hash BYTEA NOT NULL,
created TIMESTAMPTZ NOT NULL,
updated TIMESTAMPTZ NOT NULL,
lease_owner_id UUID,
lease_last_updated TIMESTAMPTZ,
weight BIGINT NOT NULL DEFAULT 1
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_manifest_name ON manifests(name, kind);
-- Migration for existing tables
ALTER TABLE manifests ADD COLUMN IF NOT EXISTS weight BIGINT NOT NULL DEFAULT 1;
"#,
)
.await
.context("failed to migrate")?;
Ok(Self {
_marker: PhantomData,
pool,
client: Arc::new(client),
})
}
#[cfg(not(feature = "postgres-tls"))]
async fn connect(database_url: &str) -> anyhow::Result<Client> {
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
.await
.context("failed to connect to database")?;
tokio::spawn(async move {
if let Err(e) = connection.await {
tracing::error!("postgres connection error: {e}");
}
});
Ok(client)
}
#[cfg(feature = "postgres-tls")]
async fn connect(database_url: &str) -> anyhow::Result<Client> {
let native_certs = rustls_native_certs::load_native_certs();
if !native_certs.errors.is_empty() {
tracing::warn!("errors loading some native certs: {:?}", native_certs.errors);
}
anyhow::ensure!(!native_certs.certs.is_empty(), "no native TLS certificates found");
let mut root_store = rustls::RootCertStore::empty();
for cert in native_certs.certs {
root_store
.add(cert)
.context("failed to add root certificate")?;
}
let provider = rustls::crypto::CryptoProvider::get_default()
.cloned()
.unwrap_or_else(|| {
#[cfg(feature = "postgres-tls-aws-lc-rs")]
{
std::sync::Arc::new(rustls::crypto::aws_lc_rs::default_provider())
}
#[cfg(all(feature = "postgres-tls-ring", not(feature = "postgres-tls-aws-lc-rs")))]
{
std::sync::Arc::new(rustls::crypto::ring::default_provider())
}
#[cfg(not(any(feature = "postgres-tls-ring", feature = "postgres-tls-aws-lc-rs")))]
{
compile_error!(
"enable either `postgres-tls-ring` or `postgres-tls-aws-lc-rs` feature"
);
}
});
let tls_config = rustls::ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.context("failed to configure TLS protocol versions")?
.with_root_certificates(root_store)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
let (client, connection) = tokio_postgres::connect(database_url, tls)
.await
.context("failed to connect to database")?;
tokio::spawn(async move {
if let Err(e) = connection.await {
tracing::error!("postgres connection error: {e}");
}
});
Ok(client)
}
}
impl<T: Specification> BackingStoreEdge<T> for BackingStorePostgres<T> {
async fn get_owned_and_potential_leases(
&self,
worker_id: &uuid::Uuid,
) -> anyhow::Result<Vec<crate::manifests::ManifestState<T>>> {
let recs = sqlx::query!(
r#"
) -> anyhow::Result<Vec<ManifestState<T>>> {
let rows = self
.client
.query(
r#"
SELECT
id,
generation,
name,
kind,
status,
manifest_content,
manifest_hash,
created,
updated,
lease_owner_id,
lease_last_updated
FROM
manifests
id, generation, name, kind, status,
manifest_content, manifest_hash,
created, updated,
lease_owner_id, lease_last_updated
FROM manifests
WHERE
lease_last_updated < now() - INTERVAL '30 seconds'
OR (lease_owner_id = $1 AND lease_last_updated > now() - INTERVAL '15 seconds')
OR lease_owner_id IS NULL
"#,
worker_id
)
.fetch_all(&self.pool)
.await?;
"#,
&[worker_id],
)
.await?;
recs.into_iter()
.map(|r| {
let content: Manifest<T> = serde_json::from_value(r.manifest_content)?;
Ok(ManifestState {
manifest: content,
manifest_hash: r.manifest_hash,
generation: r.generation as u64,
status: serde_json::from_value(r.status)?,
created: Timestamp::from_millisecond(r.created.timestamp_millis())?,
updated: Timestamp::from_millisecond(r.updated.timestamp_millis())?,
lease: {
match (r.lease_owner_id, r.lease_last_updated) {
(Some(owner_id), Some(last_updated)) => Some(ManifestLease {
owner: owner_id,
last_seen: Timestamp::from_millisecond(
last_updated.timestamp_millis(),
)?,
}),
(_, _) => None,
}
},
})
})
rows.iter()
.map(row_to_manifest_state)
.collect::<anyhow::Result<Vec<_>>>()
}
async fn get_manifests(&self) -> anyhow::Result<Vec<crate::manifests::ManifestState<T>>> {
let recs = sqlx::query!(
r#"
async fn get_manifests(&self) -> anyhow::Result<Vec<ManifestState<T>>> {
let rows = self
.client
.query(
r#"
SELECT
id,
generation,
name,
kind,
status,
manifest_content,
manifest_hash,
created,
updated,
lease_owner_id,
lease_last_updated
FROM
manifests
"#
)
.fetch_all(&self.pool)
.await
.context("failed to get manifests from database")?;
id, generation, name, kind, status,
manifest_content, manifest_hash,
created, updated,
lease_owner_id, lease_last_updated
FROM manifests
"#,
&[],
)
.await
.context("failed to get manifests from database")?;
recs.into_iter()
.map(|r| {
let content: Manifest<T> = serde_json::from_value(r.manifest_content)?;
Ok(ManifestState {
manifest: content,
manifest_hash: r.manifest_hash,
generation: r.generation as u64,
status: serde_json::from_value(r.status)?,
created: Timestamp::from_millisecond(r.created.timestamp_millis())?,
updated: Timestamp::from_millisecond(r.updated.timestamp_millis())?,
lease: {
match (r.lease_owner_id, r.lease_last_updated) {
(Some(owner_id), Some(last_updated)) => Some(ManifestLease {
owner: owner_id,
last_seen: Timestamp::from_millisecond(
last_updated.timestamp_millis(),
)?,
}),
(_, _) => None,
}
},
})
})
rows.iter()
.map(row_to_manifest_state)
.collect::<anyhow::Result<Vec<_>>>()
}
async fn get(&self, name: &str) -> anyhow::Result<Option<ManifestState<T>>> {
let rec = sqlx::query!(
r#"
let row = self
.client
.query_opt(
r#"
SELECT
id,
generation,
name,
kind,
status,
manifest_content,
manifest_hash,
created,
updated,
lease_owner_id,
lease_last_updated
FROM
manifests
WHERE
name = $1
"#,
name
)
.fetch_optional(&self.pool)
.await
.context("failed to get")?;
id, generation, name, kind, status,
manifest_content, manifest_hash,
created, updated,
lease_owner_id, lease_last_updated
FROM manifests
WHERE name = $1
"#,
&[&name],
)
.await
.context("failed to get")?;
let Some(rec) = rec else { return Ok(None) };
let Some(row) = row else { return Ok(None) };
let content: Manifest<T> = serde_json::from_value(rec.manifest_content)?;
Ok(Some(ManifestState {
manifest: content,
manifest_hash: rec.manifest_hash,
generation: rec.generation as u64,
status: serde_json::from_value(rec.status)?,
created: Timestamp::from_millisecond(rec.created.timestamp_millis())?,
updated: Timestamp::from_millisecond(rec.updated.timestamp_millis())?,
lease: {
match (rec.lease_owner_id, rec.lease_last_updated) {
(Some(owner_id), Some(last_updated)) => Some(ManifestLease {
owner: owner_id,
last_seen: Timestamp::from_millisecond(last_updated.timestamp_millis())?,
}),
(_, _) => None,
}
},
}))
Ok(Some(row_to_manifest_state(&row)?))
}
async fn update_lease(
&self,
manifest_state: &crate::manifests::ManifestState<T>,
) -> anyhow::Result<()> {
let resp = sqlx::query!(
r#"
async fn update_lease(&self, manifest_state: &ManifestState<T>) -> anyhow::Result<()> {
let rows = self
.client
.execute(
r#"
UPDATE manifests
SET
lease_last_updated = now()
SET lease_last_updated = now()
WHERE
name = $1
AND kind = $2
AND generation = $3
-- AND owner_id = $4
"#,
manifest_state.manifest.name,
manifest_state.manifest.spec.kind(),
manifest_state.generation as i64,
// worker_id,
)
.execute(&self.pool)
.await
.context("failed to update lease")?;
"#,
&[
&manifest_state.manifest.name,
&manifest_state.manifest.spec.kind(),
&(manifest_state.generation as i64),
],
)
.await
.context("failed to update lease")?;
if resp.rows_affected() == 0 {
if rows == 0 {
anyhow::bail!("failed to update lease, the host is no longer the owner")
}
@@ -235,10 +268,12 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStorePostgres<T> {
) -> anyhow::Result<()> {
let name = &manifest_state.manifest.name;
let kind = manifest_state.manifest.spec.kind();
let generation = manifest_state.generation;
let generation = manifest_state.generation as i64;
let resp = sqlx::query!(
r#"
let rows = self
.client
.execute(
r#"
UPDATE manifests
SET
lease_owner_id = $4,
@@ -248,30 +283,24 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStorePostgres<T> {
name = $1
AND kind = $2
AND generation = $3
"#,
name,
kind,
generation as i64,
worker_id
)
.execute(&self.pool)
.await
.context("failed to acquire lease")?;
"#,
&[&name, &kind, &generation, worker_id],
)
.await
.context("failed to acquire lease")?;
if resp.rows_affected() == 0 {
if rows == 0 {
anyhow::bail!("failed to acquire lease: {}/{}@{}", kind, name, generation);
}
// TODO: maybe we should update fence as well
// manifest_state.generation = generation + 1;
Ok(())
}
async fn upsert_manifest(&self, manifest: crate::manifests::Manifest<T>) -> anyhow::Result<()> {
async fn upsert_manifest(&self, manifest: Manifest<T>) -> anyhow::Result<()> {
let id = uuid::Uuid::now_v7();
let name = &manifest.name;
let kind = manifest.spec.kind();
let weight = manifest.spec.weight() as i64;
let content = serde_json::to_value(&manifest)?;
let hash = &sha2::Sha256::digest(serde_json::to_vec(&content)?)[..];
let status = serde_json::to_value(ManifestStatus {
@@ -280,61 +309,39 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStorePostgres<T> {
changes: vec![],
})?;
sqlx::query!(
r#"
INSERT INTO manifests (
id,
generation,
name,
kind,
status,
manifest_content,
manifest_hash,
lease_owner_id,
lease_last_updated,
created,
updated
) VALUES (
$1,
0,
$2,
$3,
$4,
$5,
$6,
NULL,
NULL,
now(),
now()
)
ON CONFLICT (name, kind) DO UPDATE
SET
manifest_content = $5,
updated = now()
"#,
id,
name,
kind,
status,
content,
hash
)
.execute(&self.pool)
.await
.context("failed to upsert manifest")?;
self.client
.execute(
r#"
INSERT INTO manifests (
id, generation, name, kind, status,
manifest_content, manifest_hash,
lease_owner_id, lease_last_updated,
created, updated, weight
) VALUES (
$1, 0, $2, $3, $4, $5, $6, NULL, NULL, now(), now(), $7
)
ON CONFLICT (name, kind) DO UPDATE
SET
manifest_content = $5,
updated = now(),
weight = $7
"#,
&[&id, &name, &kind, &status, &content, &hash, &weight],
)
.await
.context("failed to upsert manifest")?;
Ok(())
}
async fn update_state(
&self,
manifest: &crate::manifests::ManifestState<T>,
) -> anyhow::Result<()> {
let generation = manifest.generation;
async fn update_state(&self, manifest: &ManifestState<T>) -> anyhow::Result<()> {
let generation = manifest.generation as i64;
let status = serde_json::to_value(&manifest.status)?;
let resp = sqlx::query!(
r#"
let rows = self
.client
.execute(
r#"
UPDATE manifests
SET
generation = $3 + 1,
@@ -344,17 +351,18 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStorePostgres<T> {
name = $1
AND kind = $2
AND generation = $3
"#,
manifest.manifest.name,
manifest.manifest.spec.kind(),
generation as i32,
status
)
.execute(&self.pool)
.await
.context("failed to update state")?;
"#,
&[
&manifest.manifest.name,
&manifest.manifest.spec.kind(),
&generation,
&status,
],
)
.await
.context("failed to update state")?;
if resp.rows_affected() == 0 {
if rows == 0 {
anyhow::bail!("failed to update state")
}
@@ -368,10 +376,12 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStorePostgres<T> {
) -> anyhow::Result<()> {
let name = &manifest.manifest.name;
let kind = manifest.manifest.spec.kind();
let generation = manifest.generation;
let generation = manifest.generation as i64;
let resp = sqlx::query!(
r#"
let rows = self
.client
.execute(
r#"
UPDATE manifests
SET
lease_owner_id = NULL,
@@ -381,20 +391,48 @@ impl<T: Specification> BackingStoreEdge<T> for BackingStorePostgres<T> {
AND kind = $2
AND generation = $3
AND lease_owner_id = $4
"#,
name,
kind,
generation as i64,
worker_id,
)
.execute(&self.pool)
.await
.context("failed to update lease")?;
"#,
&[&name, &kind, &generation, worker_id],
)
.await
.context("failed to update lease")?;
if resp.rows_affected() == 0 {
if rows == 0 {
anyhow::bail!("failed to delete lease, the host is no longer the owner")
}
Ok(())
}
/// Computes cluster-wide weight statistics with a single aggregate query.
///
/// Only leases refreshed within the last 30 seconds count as active,
/// matching the 30-second staleness window used when selecting
/// lease candidates. Always returns `Some` for this backend.
async fn get_cluster_stats(
    &self,
    worker_id: &uuid::Uuid,
) -> anyhow::Result<Option<ClusterStats>> {
    // COALESCE guards the empty-result case, where SUM would be NULL.
    let row = self
        .client
        .query_one(
            r#"
            SELECT
                COALESCE(SUM(weight), 0)::BIGINT AS total_weight,
                COUNT(DISTINCT lease_owner_id)::BIGINT AS active_workers,
                COALESCE(SUM(CASE WHEN lease_owner_id = $1 THEN weight ELSE 0 END), 0)::BIGINT AS my_weight
            FROM manifests
            WHERE lease_owner_id IS NOT NULL
                AND lease_last_updated > now() - INTERVAL '30 seconds'
            "#,
            &[worker_id],
        )
        .await
        .context("failed to get cluster stats")?;
    let total_weight: i64 = row.get("total_weight");
    let active_workers: i64 = row.get("active_workers");
    let my_weight: i64 = row.get("my_weight");
    // Aggregates are cast to BIGINT in SQL; values are assumed
    // non-negative, so the unsigned casts below are lossless.
    Ok(Some(ClusterStats {
        total_weight: total_weight as u64,
        active_workers: active_workers as usize,
        my_weight: my_weight as u64,
    }))
}
}

View File

@@ -2,10 +2,13 @@ use anyhow::Context;
use jiff::Timestamp;
use tokio_util::sync::CancellationToken;
use rand::seq::SliceRandom;
use crate::{
Operator, OperatorState,
Operator, OperatorState, Specification,
control_plane::backing_store::{BackingStore, BackingStoreEdge},
manifests::{Action, ManifestName, ManifestState, ManifestStatusState, WorkerId},
operator_state::{ClusterStats, RebalancePolicy},
reconcile_queue::ReconcileQueue,
};
@@ -102,6 +105,10 @@ impl<T: Operator, TStore: BackingStoreEdge<T::Specifications>> Reconciler<T, TSt
}
/// Single sync iteration - check for manifests, acquire leases, enqueue work.
/// Implements 3-layer capacity management:
/// 1. Hard capacity limit (max_capacity)
/// 2. Cluster-aware fair share (get_cluster_stats)
/// 3. Voluntary shedding (RebalancePolicy::FairShare)
async fn sync_once(&self) -> anyhow::Result<()> {
let manifests = self
.store
@@ -110,37 +117,93 @@ impl<T: Operator, TStore: BackingStoreEdge<T::Specifications>> Reconciler<T, TSt
tracing::trace!(manifests = manifests.len(), "sync once manifests");
for manifest_state in manifests {
// Partition into owned and unowned
let (mut owned, mut unowned): (Vec<_>, Vec<_>) =
manifests.into_iter().partition(|m| {
matches!(&m.lease, Some(lease) if lease.owner == self.worker_id)
});
let mut owned_weight: u64 = owned.iter().map(|m| m.manifest.spec.weight()).sum();
// Layer 2: Fetch cluster stats for fair share calculation
let cluster_stats = self
.store
.get_cluster_stats(&self.worker_id)
.await
.unwrap_or_else(|e| {
tracing::warn!(error = %e, "failed to get cluster stats, proceeding without fair share");
None
});
// Layer 3: Voluntary shedding (before heartbeat, so shed manifests
// don't get their leases refreshed)
if let Err(e) = self
.maybe_shed_manifests(&mut owned, &mut owned_weight, &cluster_stats)
.await
{
tracing::warn!(error = %e, "failed during manifest shedding");
}
// Heartbeat and check changes for remaining owned manifests
for manifest_state in &owned {
let manifest_name = manifest_state.manifest.name.clone();
match &manifest_state.lease {
Some(lease) if lease.owner == self.worker_id => {
tracing::trace!("updating lease");
tracing::trace!(%manifest_name, "updating lease");
if let Err(e) = self.store.update_lease(manifest_state).await {
// Generation mismatch means someone else took the lease — expected contention
tracing::trace!(error = %e, %manifest_name, "failed to update lease, likely lost");
continue;
}
// We own the lease, update it
self.store
.update_lease(&manifest_state)
.await
.context("update lease")?;
if self.needs_reconciliation(manifest_state) {
self.reconcile_queue.enqueue(manifest_name).await;
}
}
// Check if there are unhandled changes
if self.needs_reconciliation(&manifest_state) {
self.reconcile_queue.enqueue(manifest_name).await;
}
}
_ => {
tracing::trace!("acquiring lease");
// Shuffle unowned to prevent thundering herd
unowned.shuffle(&mut rand::rng());
// No lease, try to acquire
self.store
.acquire_lease(&manifest_state, &self.worker_id)
.await
.context("acquire lease")?;
// Layers 1+2: Compute effective capacity limit
let effective_limit = self.effective_capacity_limit(owned_weight, &cluster_stats);
// Enqueue for reconciliation
self.reconcile_queue.enqueue(manifest_name).await;
tracing::trace!(
owned_weight,
?effective_limit,
unowned_count = unowned.len(),
"capacity check before acquisition"
);
// Acquire unowned manifests, respecting capacity
for manifest_state in unowned {
let manifest_name = manifest_state.manifest.name.clone();
let manifest_weight = manifest_state.manifest.spec.weight();
if let Some(limit) = effective_limit {
if owned_weight + manifest_weight > limit {
tracing::trace!(
owned_weight,
manifest_weight,
limit,
%manifest_name,
"skipping acquisition: would exceed capacity"
);
continue;
}
}
tracing::trace!(%manifest_name, "acquiring lease");
if let Err(e) = self
.store
.acquire_lease(&manifest_state, &self.worker_id)
.await
{
// CAS failure means another worker grabbed it first — expected contention
tracing::trace!(error = %e, %manifest_name, "failed to acquire lease, likely contention");
continue;
}
owned_weight += manifest_weight;
self.reconcile_queue.enqueue(manifest_name).await;
}
Ok(())
@@ -204,6 +267,152 @@ impl<T: Operator, TStore: BackingStoreEdge<T::Specifications>> Reconciler<T, TSt
false
}
/// Compute the effective capacity limit for this worker.
///
/// The limit is the minimum of the configured hard cap (`max_capacity`)
/// and the cluster fair share plus headroom, when cluster stats are
/// available. Returns None if there is no capacity limit at all.
///
/// Note: the fair-share term applies even when `RebalancePolicy::Disabled`
/// (with zero headroom), preserving the original acquisition behavior.
fn effective_capacity_limit(
    &self,
    _owned_weight: u64,
    cluster_stats: &Option<ClusterStats>,
) -> Option<u64> {
    let max_capacity = self.operator.config().max_capacity;
    let headroom = match &self.operator.config().rebalance_policy {
        RebalancePolicy::FairShare { headroom } => *headroom,
        RebalancePolicy::Disabled => 0,
    };
    // Fair-share limit derived from cluster stats, when available.
    // (Previously duplicated across two match arms.)
    let fair_limit = cluster_stats.as_ref().and_then(|stats| {
        if stats.active_workers == 0 {
            return None;
        }
        // Account for cold start: if we have no weight, we're not yet
        // counted in active_workers, so add ourselves to the divisor.
        let effective_workers = if stats.my_weight == 0 {
            stats.active_workers + 1
        } else {
            stats.active_workers
        };
        let fair_share = stats.total_weight / effective_workers as u64;
        Some(fair_share.saturating_add(headroom))
    });
    // Combine the two limits; either may be absent.
    match (max_capacity, fair_limit) {
        (Some(max_cap), Some(fair)) => Some(max_cap.min(fair)),
        (Some(max_cap), None) => Some(max_cap),
        (None, fair) => fair,
    }
}
/// Voluntarily shed manifests when owned weight exceeds fair share + headroom.
///
/// `owned` is the list of manifests whose lease this worker holds;
/// `owned_weight` is kept in sync as leases are released. No-op unless
/// the policy is `FairShare`, cluster stats are available, and the owned
/// weight exceeds `fair_share + headroom`.
///
/// NOTE(review): when `delete_lease` fails, the loop `continue`s (skipping
/// `on_lease_lost` and the weight decrement) but the manifest is still
/// removed from `owned` below, so its lease will not be heartbeated this
/// cycle — confirm that letting the lease expire is the intended fallback.
async fn maybe_shed_manifests(
    &self,
    owned: &mut Vec<ManifestState<T::Specifications>>,
    owned_weight: &mut u64,
    cluster_stats: &Option<ClusterStats>,
) -> anyhow::Result<()> {
    let headroom = match &self.operator.config().rebalance_policy {
        RebalancePolicy::FairShare { headroom } => *headroom,
        RebalancePolicy::Disabled => return Ok(()),
    };
    let Some(stats) = cluster_stats else {
        return Ok(());
    };
    if stats.active_workers == 0 {
        return Ok(());
    }
    // Target = fair share of total cluster weight plus the anti-thrash
    // headroom; saturating math guards degenerate u64 values.
    let fair_share = stats.total_weight / stats.active_workers as u64;
    let target = fair_share.saturating_add(headroom);
    if *owned_weight <= target {
        return Ok(());
    }
    tracing::info!(
        owned_weight = *owned_weight,
        target,
        fair_share,
        headroom,
        active_workers = stats.active_workers,
        "owned weight exceeds target, shedding manifests"
    );
    // Sort by weight descending — shed fewer, larger manifests first
    owned.sort_by(|a, b| b.manifest.spec.weight().cmp(&a.manifest.spec.weight()));
    // Pick manifests to shed that bring us closest to target
    let mut to_shed: Vec<usize> = Vec::new();
    let mut projected_weight = *owned_weight;
    for (i, manifest) in owned.iter().enumerate() {
        if projected_weight <= target {
            break;
        }
        let w = manifest.manifest.spec.weight();
        let after_shed = projected_weight.saturating_sub(w);
        let overshoot = projected_weight.saturating_sub(target);
        let undershoot = target.saturating_sub(after_shed);
        // Shed if we'd still be at or above target, or if overshoot is worse than undershoot
        if after_shed >= target || overshoot > undershoot {
            to_shed.push(i);
            projected_weight = after_shed;
        }
    }
    // Execute sheds in reverse order to maintain index validity
    for &idx in to_shed.iter().rev() {
        let manifest = &owned[idx];
        let w = manifest.manifest.spec.weight();
        tracing::info!(
            manifest = %manifest.manifest.name,
            weight = w,
            "shedding manifest for rebalancing"
        );
        if let Err(e) = self.store.delete_lease(manifest, &self.worker_id).await {
            tracing::warn!(
                error = %e,
                manifest = %manifest.manifest.name,
                "failed to delete lease during shedding"
            );
            continue;
        }
        // Best-effort teardown hook; a failure here is logged but does not
        // abort the shed — the lease is already released.
        if let Err(_e) = self.operator.on_lease_lost(manifest).await {
            tracing::warn!(
                manifest = %manifest.manifest.name,
                "on_lease_lost failed during shedding"
            );
        }
        *owned_weight -= w;
    }
    // Remove shed manifests from the owned list
    for &idx in to_shed.iter().rev() {
        owned.remove(idx);
    }
    tracing::info!(
        owned_weight = *owned_weight,
        target,
        shed_count = to_shed.len(),
        "shedding complete"
    );
    Ok(())
}
/// Process the reconciliation queue.
/// Takes items from the queue and reconciles them, re-enqueuing with delay if needed.
async fn process_queue(&self, cancellation_token: &CancellationToken) -> anyhow::Result<()> {
@@ -243,7 +452,7 @@ impl<T: Operator, TStore: BackingStoreEdge<T::Specifications>> Reconciler<T, TSt
match &manifest.lease {
Some(lease) if lease.owner == self.worker_id => {}
_ => {
tracing::debug!(%manifest_name, "we don't own the lease, shutting down owned resources");
tracing::trace!(%manifest_name, "we don't own the lease, shutting down owned resources");
self.operator
.on_lease_lost(&manifest)

View File

@@ -4,6 +4,13 @@ use crate::manifests::{Action, ManifestState};
/// A user-defined manifest specification that the control plane can
/// store, serialize, and schedule across workers.
pub trait Specification: Clone + Serialize + DeserializeOwned + Send + Sync + 'static {
    /// Stable identifier for this specification type; stored alongside the
    /// manifest name (e.g. as the `kind` column in the postgres backend).
    fn kind(&self) -> &'static str;
    /// Returns the weight of this specification for capacity management.
    /// Higher weight means this manifest consumes more of a worker's capacity budget.
    /// Default is 1.
    fn weight(&self) -> u64 {
        1
    }
}
#[allow(dead_code, unused_variables)]

View File

@@ -67,6 +67,10 @@ pub struct OperatorConfig {
/// Interval at which all manifests are re-enqueued for reconciliation.
/// Default is 5 minutes.
pub resync_interval: Duration,
/// Maximum total weight this worker will manage. None means unlimited.
pub max_capacity: Option<u64>,
/// Policy for active rebalancing of manifests across workers.
pub rebalance_policy: RebalancePolicy,
}
impl Default for OperatorConfig {
@@ -77,6 +81,8 @@ impl Default for OperatorConfig {
},
reconcile_on: Default::default(),
resync_interval: Duration::from_secs(5 * 60),
max_capacity: None,
rebalance_policy: RebalancePolicy::default(),
}
}
}
@@ -103,3 +109,34 @@ impl Default for BackoffPolicy {
}
}
}
/// Policy for active rebalancing of manifests across workers.
///
/// The default policy is [`RebalancePolicy::Disabled`], expressed via
/// `#[derive(Default)]` + `#[default]` (stable since Rust 1.62; this
/// crate targets edition 2024) instead of a hand-written impl.
#[derive(Clone, Debug, Default)]
pub enum RebalancePolicy {
    /// No active rebalancing. Workers only limit acquisition via max_capacity.
    #[default]
    Disabled,
    /// Fair-share rebalancing. When a worker's owned weight exceeds
    /// fair_share + headroom, it voluntarily releases excess manifests.
    FairShare {
        /// Extra weight budget above fair share before shedding begins.
        /// Prevents thrashing when weights don't divide evenly.
        headroom: u64,
    },
}
/// Statistics about the cluster's current weight distribution.
///
/// All fields are plain integers, so the type additionally derives
/// `Copy`, `PartialEq`, `Eq`, and `Default` (all-zero stats) — cheap to
/// pass by value and easy to compare in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ClusterStats {
    /// Total weight of all manifests with active leases across the cluster.
    pub total_weight: u64,
    /// Number of workers with active leases.
    pub active_workers: usize,
    /// Total weight of manifests owned by the requesting worker.
    pub my_weight: u64,
}

View File

@@ -29,7 +29,6 @@ async fn main() -> anyhow::Result<()> {
EnvFilter::from_default_env()
.add_directive("nocontrol=trace".parse().unwrap())
.add_directive("postgres_backend=trace".parse().unwrap())
.add_directive("sqlx=warn".parse().unwrap())
.add_directive("debug".parse().unwrap()),
)
.with_file(false)

View File

@@ -0,0 +1,16 @@
[package]
name = "rebalancing-stress"
version = "0.1.0"
edition = "2024"
publish = false
[dependencies]
nocontrol.workspace = true
anyhow.workspace = true
tokio.workspace = true
serde.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
uuid.workspace = true
tokio-util = { version = "0.7", features = ["rt"] }

View File

@@ -0,0 +1,268 @@
//! Stress test for weight-based rebalancing.
//!
//! Simulates multiple workers sharing an in-process backing store.
//! Manifests have varying weights. Workers have capacity limits and
//! use FairShare rebalancing to redistribute work as nodes join/leave.
//!
//! Run with: RUST_LOG=info cargo run -p rebalancing-stress
use std::time::Duration;
use nocontrol::{
ControlPlane, Operator, OperatorConfig, OperatorState, RebalancePolicy, Specification,
manifests::{Action, Manifest, ManifestMetadata, ManifestState},
stores::{BackingStore, BackingStoreEdge},
};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing_subscriber::EnvFilter;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.with_target(false)
.without_time()
.init();
let store = BackingStore::in_process();
// Create 20 manifests with varying weights (total weight = 110)
let manifests = vec![
("heavy-job-1", 10),
("heavy-job-2", 10),
("heavy-job-3", 10),
("medium-job-1", 5),
("medium-job-2", 5),
("medium-job-3", 5),
("medium-job-4", 5),
("medium-job-5", 5),
("medium-job-6", 5),
("light-job-1", 1),
("light-job-2", 1),
("light-job-3", 1),
("light-job-4", 1),
("light-job-5", 1),
("light-job-6", 1),
("light-job-7", 1),
("light-job-8", 1),
("light-job-9", 1),
("light-job-10", 1),
("tiny-job-1", 0),
];
let total_weight: u64 = manifests.iter().map(|(_, w)| *w).sum();
tracing::info!(
manifest_count = manifests.len(),
total_weight,
"creating manifests"
);
// Insert all manifests into the shared store using a temporary control plane
let seed_operator = OperatorState::new(StressOperator);
let seed_cp = ControlPlane::new(seed_operator, store.clone());
for (name, weight) in &manifests {
seed_cp
.add_manifest(Manifest {
name: name.to_string(),
metadata: ManifestMetadata {},
spec: WeightedJob {
weight: *weight,
name: name.to_string(),
},
})
.await?;
}
let cancellation = CancellationToken::new();
// --- Phase 1: Start 2 workers ---
tracing::info!("=== PHASE 1: Starting 2 workers (capacity=60 each, headroom=5) ===");
let worker1 = spawn_worker("worker-1", store.clone(), 60, 5, cancellation.child_token());
let worker2 = spawn_worker("worker-2", store.clone(), 60, 5, cancellation.child_token());
// Let them stabilize
tokio::time::sleep(Duration::from_secs(15)).await;
print_distribution(&seed_cp).await;
// --- Phase 2: Add a 3rd worker ---
tracing::info!("=== PHASE 2: Adding worker-3 ===");
let worker3 = spawn_worker("worker-3", store.clone(), 60, 5, cancellation.child_token());
// Let rebalancing happen
tokio::time::sleep(Duration::from_secs(25)).await;
print_distribution(&seed_cp).await;
// --- Phase 3: Add a 4th worker with low capacity ---
tracing::info!("=== PHASE 3: Adding worker-4 (capacity=15) ===");
let worker4 = spawn_worker("worker-4", store.clone(), 15, 2, cancellation.child_token());
tokio::time::sleep(Duration::from_secs(25)).await;
print_distribution(&seed_cp).await;
// --- Phase 4: Kill worker-1, observe redistribution ---
tracing::info!("=== PHASE 4: Killing worker-1, observing redistribution ===");
worker1.cancel();
// Wait for lease expiry (10s in-process) + sync cycles
tokio::time::sleep(Duration::from_secs(25)).await;
print_distribution(&seed_cp).await;
// Cleanup
tracing::info!("=== DONE: Shutting down all workers ===");
cancellation.cancel();
worker2.cancel();
worker3.cancel();
worker4.cancel();
// Give workers time to shut down gracefully
tokio::time::sleep(Duration::from_secs(2)).await;
Ok(())
}
/// Launches one worker task running its own control plane over `store`.
///
/// Returns a token that cancels only this worker; the worker also stops
/// when the supplied `cancellation` (the global shutdown signal) fires.
fn spawn_worker<TStore: BackingStoreEdge<WeightedJob> + 'static>(
    name: &'static str,
    store: BackingStore<WeightedJob, TStore>,
    max_capacity: u64,
    headroom: u64,
    cancellation: CancellationToken,
) -> CancellationToken {
    let local_cancel = CancellationToken::new();
    let local = local_cancel.clone();
    tokio::spawn(async move {
        let config = OperatorConfig {
            max_capacity: Some(max_capacity),
            rebalance_policy: RebalancePolicy::FairShare { headroom },
            resync_interval: Duration::from_secs(60),
            ..Default::default()
        };
        let operator = OperatorState::new_with_config(StressOperator, config);
        let cp = ControlPlane::new(operator, store);
        tracing::info!(%name, max_capacity, headroom, "worker started");

        // Merge the per-worker and global shutdown signals into one token:
        // whichever fires first cancels `merged`, which stops the control plane.
        let merged = CancellationToken::new();
        let merged_child = merged.child_token();
        let forwarder = async move {
            tokio::select! {
                _ = local.cancelled() => {}
                _ = cancellation.cancelled() => {}
            }
            merged.cancel();
        };
        tokio::spawn(forwarder);

        if let Err(e) = cp.execute_with_cancellation(merged_child).await {
            tracing::error!(%name, error = %e, "worker failed");
        }
        tracing::info!(%name, "worker stopped");
    });
    local_cancel
}
/// Logs the current distribution of manifests across workers: per-worker
/// count and total weight, any unowned manifests, and a summary line.
///
/// Best-effort diagnostic: a failed `get_manifests` is treated as an
/// empty cluster rather than an error.
async fn print_distribution<TOperator, TStore>(cp: &ControlPlane<TOperator, TStore>)
where
    TOperator: Operator<Specifications = WeightedJob>,
    TStore: BackingStoreEdge<WeightedJob>,
{
    let manifests = cp.get_manifests().await.unwrap_or_default();

    // Aggregate (manifest count, total weight) per owning worker.
    let mut by_worker: std::collections::HashMap<String, (usize, u64)> =
        std::collections::HashMap::new();
    let mut unowned = Vec::new();
    for m in &manifests {
        let w = m.manifest.spec.weight;
        match &m.lease {
            Some(lease) => {
                // `to_string()` over `format!("{}", ..)` — same output, clearer intent.
                let entry = by_worker.entry(lease.owner.to_string()).or_default();
                entry.0 += 1;
                entry.1 += w;
            }
            None => {
                unowned.push(m.manifest.name.as_str());
            }
        }
    }

    tracing::info!("--- Distribution ---");
    let mut workers: Vec<_> = by_worker.into_iter().collect();
    // Keys are unique (one entry per worker), so an unstable sort on borrowed
    // ids gives the same order as the stable sort without cloning each key.
    workers.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    for (worker_id, (count, weight)) in &workers {
        tracing::info!(worker = %worker_id, count, weight, "");
    }
    if !unowned.is_empty() {
        tracing::info!(count = unowned.len(), "unowned manifests");
    }
    let total_owned_weight: u64 = workers.iter().map(|(_, (_, w))| w).sum();
    tracing::info!(
        total_owned_weight,
        total_manifests = manifests.len(),
        workers = workers.len(),
        "summary"
    );
    tracing::info!("--------------------");
}
// --- Operator and Specification ---
/// Operator used by the stress test: reconciles `WeightedJob` manifests by
/// sleeping proportionally to their weight (see the `Operator` impl below).
#[derive(Clone)]
struct StressOperator;
impl Operator for StressOperator {
    type Specifications = WeightedJob;
    type Error = anyhow::Error;

    /// Pretends to do work proportional to the job's weight (10ms per unit),
    /// then asks to be requeued in five seconds.
    async fn reconcile(
        &self,
        manifest: &mut ManifestState<WeightedJob>,
    ) -> Result<Action, Self::Error> {
        let busy_for = Duration::from_millis(manifest.manifest.spec.weight * 10);
        tokio::time::sleep(busy_for).await;
        Ok(Action::Requeue(Duration::from_secs(5)))
    }

    /// Logs the loss of ownership; the stress test needs no real cleanup.
    async fn on_lease_lost(
        &self,
        manifest: &ManifestState<WeightedJob>,
    ) -> Result<(), Self::Error> {
        tracing::debug!(
            manifest = %manifest.manifest.name,
            "lease lost, cleaning up"
        );
        Ok(())
    }
}
/// Specification for a job whose scheduling cost is expressed as a weight.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WeightedJob {
    /// Job name; in this stress test it mirrors the manifest name.
    pub name: String,
    /// Relative scheduling cost; heavier jobs consume more worker capacity.
    pub weight: u64,
}
impl Specification for WeightedJob {
    /// Stable type tag identifying this specification kind.
    fn kind(&self) -> &'static str {
        "weighted-job"
    }
    /// Scheduling weight reported to the control plane for fair-share rebalancing.
    fn weight(&self) -> u64 {
        self.weight
    }
}