112
crates/forage-core/src/auth/mod.rs
Normal file
112
crates/forage-core/src/auth/mod.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
mod validation;
|
||||
|
||||
pub use validation::{validate_email, validate_password, validate_username};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Tokens returned by forest-server after login/register.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AuthTokens {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub expires_in_seconds: i64,
|
||||
}
|
||||
|
||||
/// Minimal user info from forest-server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub user_id: String,
|
||||
pub username: String,
|
||||
pub emails: Vec<UserEmail>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserEmail {
|
||||
pub email: String,
|
||||
pub verified: bool,
|
||||
}
|
||||
|
||||
/// A personal access token (metadata only, no raw key).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonalAccessToken {
|
||||
pub token_id: String,
|
||||
pub name: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub created_at: Option<String>,
|
||||
pub last_used: Option<String>,
|
||||
pub expires_at: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of creating a PAT - includes the raw key shown once.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CreatedToken {
|
||||
pub token: PersonalAccessToken,
|
||||
pub raw_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
pub enum AuthError {
|
||||
#[error("invalid credentials")]
|
||||
InvalidCredentials,
|
||||
|
||||
#[error("already exists: {0}")]
|
||||
AlreadyExists(String),
|
||||
|
||||
#[error("not authenticated")]
|
||||
NotAuthenticated,
|
||||
|
||||
#[error("token expired")]
|
||||
TokenExpired,
|
||||
|
||||
#[error("forest-server unavailable: {0}")]
|
||||
Unavailable(String),
|
||||
|
||||
#[error("{0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
/// Trait for communicating with forest-server's UsersService.
|
||||
/// Object-safe via async_trait so we can use `Arc<dyn ForestAuth>`.
|
||||
#[async_trait::async_trait]
|
||||
pub trait ForestAuth: Send + Sync {
|
||||
async fn register(
|
||||
&self,
|
||||
username: &str,
|
||||
email: &str,
|
||||
password: &str,
|
||||
) -> Result<AuthTokens, AuthError>;
|
||||
|
||||
async fn login(
|
||||
&self,
|
||||
identifier: &str,
|
||||
password: &str,
|
||||
) -> Result<AuthTokens, AuthError>;
|
||||
|
||||
async fn refresh_token(
|
||||
&self,
|
||||
refresh_token: &str,
|
||||
) -> Result<AuthTokens, AuthError>;
|
||||
|
||||
async fn logout(&self, refresh_token: &str) -> Result<(), AuthError>;
|
||||
|
||||
async fn get_user(&self, access_token: &str) -> Result<User, AuthError>;
|
||||
|
||||
async fn list_tokens(
|
||||
&self,
|
||||
access_token: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<PersonalAccessToken>, AuthError>;
|
||||
|
||||
async fn create_token(
|
||||
&self,
|
||||
access_token: &str,
|
||||
user_id: &str,
|
||||
name: &str,
|
||||
) -> Result<CreatedToken, AuthError>;
|
||||
|
||||
async fn delete_token(
|
||||
&self,
|
||||
access_token: &str,
|
||||
token_id: &str,
|
||||
) -> Result<(), AuthError>;
|
||||
}
|
||||
120
crates/forage-core/src/auth/validation.rs
Normal file
120
crates/forage-core/src/auth/validation.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct ValidationError(pub String);
|
||||
|
||||
pub fn validate_email(email: &str) -> Result<(), ValidationError> {
|
||||
if email.is_empty() {
|
||||
return Err(ValidationError("Email is required".into()));
|
||||
}
|
||||
if !email.contains('@') || !email.contains('.') {
|
||||
return Err(ValidationError("Invalid email format".into()));
|
||||
}
|
||||
if email.len() > 254 {
|
||||
return Err(ValidationError("Email too long".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_password(password: &str) -> Result<(), ValidationError> {
|
||||
if password.is_empty() {
|
||||
return Err(ValidationError("Password is required".into()));
|
||||
}
|
||||
if password.len() < 12 {
|
||||
return Err(ValidationError(
|
||||
"Password must be at least 12 characters".into(),
|
||||
));
|
||||
}
|
||||
if password.len() > 1024 {
|
||||
return Err(ValidationError("Password too long".into()));
|
||||
}
|
||||
if !password.chars().any(|c| c.is_uppercase()) {
|
||||
return Err(ValidationError(
|
||||
"Password must contain at least one uppercase letter".into(),
|
||||
));
|
||||
}
|
||||
if !password.chars().any(|c| c.is_lowercase()) {
|
||||
return Err(ValidationError(
|
||||
"Password must contain at least one lowercase letter".into(),
|
||||
));
|
||||
}
|
||||
if !password.chars().any(|c| c.is_ascii_digit()) {
|
||||
return Err(ValidationError(
|
||||
"Password must contain at least one digit".into(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_username(username: &str) -> Result<(), ValidationError> {
|
||||
if username.is_empty() {
|
||||
return Err(ValidationError("Username is required".into()));
|
||||
}
|
||||
if username.len() < 3 {
|
||||
return Err(ValidationError(
|
||||
"Username must be at least 3 characters".into(),
|
||||
));
|
||||
}
|
||||
if username.len() > 64 {
|
||||
return Err(ValidationError("Username too long".into()));
|
||||
}
|
||||
if !username
|
||||
.chars()
|
||||
.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
|
||||
{
|
||||
return Err(ValidationError(
|
||||
"Username can only contain letters, numbers, hyphens, and underscores".into(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn valid_email() {
|
||||
assert!(validate_email("user@example.com").is_ok());
|
||||
assert!(validate_email("a@b.c").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_email() {
|
||||
assert!(validate_email("").is_err());
|
||||
assert!(validate_email("noat").is_err());
|
||||
assert!(validate_email("no@dot").is_err());
|
||||
assert!(validate_email(&format!("{}@b.c", "a".repeat(251))).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_password() {
|
||||
assert!(validate_password("SecurePass123").is_ok());
|
||||
assert!(validate_password("MyLongPassphrase1").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_password() {
|
||||
assert!(validate_password("").is_err());
|
||||
assert!(validate_password("short").is_err());
|
||||
assert!(validate_password("12345678901").is_err()); // 11 chars
|
||||
assert!(validate_password(&"a".repeat(1025)).is_err());
|
||||
assert!(validate_password("alllowercase1").is_err()); // no uppercase
|
||||
assert!(validate_password("ALLUPPERCASE1").is_err()); // no lowercase
|
||||
assert!(validate_password("NoDigitsHere!").is_err()); // no digit
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_username() {
|
||||
assert!(validate_username("alice").is_ok());
|
||||
assert!(validate_username("bob-123").is_ok());
|
||||
assert!(validate_username("foo_bar").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_username() {
|
||||
assert!(validate_username("").is_err());
|
||||
assert!(validate_username("ab").is_err());
|
||||
assert!(validate_username("has spaces").is_err());
|
||||
assert!(validate_username("has@symbol").is_err());
|
||||
assert!(validate_username(&"a".repeat(65)).is_err());
|
||||
}
|
||||
}
|
||||
1
crates/forage-core/src/billing/mod.rs
Normal file
1
crates/forage-core/src/billing/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
// Billing and pricing logic - usage tracking, plan management.
|
||||
1
crates/forage-core/src/deployments/mod.rs
Normal file
1
crates/forage-core/src/deployments/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
// Deployment orchestration logic - managing deployment lifecycle.
|
||||
6
crates/forage-core/src/lib.rs
Normal file
6
crates/forage-core/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod auth;
|
||||
pub mod session;
|
||||
pub mod platform;
|
||||
pub mod registry;
|
||||
pub mod deployments;
|
||||
pub mod billing;
|
||||
101
crates/forage-core/src/platform/mod.rs
Normal file
101
crates/forage-core/src/platform/mod.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Validate that a slug (org name, project name) is safe for use in URLs and templates.
|
||||
/// Allows lowercase alphanumeric, hyphens, max 64 chars. Must not be empty.
|
||||
pub fn validate_slug(s: &str) -> bool {
|
||||
!s.is_empty()
|
||||
&& s.len() <= 64
|
||||
&& s.chars()
|
||||
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
|
||||
&& !s.starts_with('-')
|
||||
&& !s.ends_with('-')
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Organisation {
|
||||
pub organisation_id: String,
|
||||
pub name: String,
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Artifact {
|
||||
pub artifact_id: String,
|
||||
pub slug: String,
|
||||
pub context: ArtifactContext,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ArtifactContext {
|
||||
pub title: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, thiserror::Error)]
|
||||
pub enum PlatformError {
|
||||
#[error("not authenticated")]
|
||||
NotAuthenticated,
|
||||
|
||||
#[error("not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("service unavailable: {0}")]
|
||||
Unavailable(String),
|
||||
|
||||
#[error("{0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
/// Trait for platform data from forest-server (organisations, projects, artifacts).
|
||||
/// Separate from `ForestAuth` which handles identity.
|
||||
#[async_trait::async_trait]
|
||||
pub trait ForestPlatform: Send + Sync {
|
||||
async fn list_my_organisations(
|
||||
&self,
|
||||
access_token: &str,
|
||||
) -> Result<Vec<Organisation>, PlatformError>;
|
||||
|
||||
async fn list_projects(
|
||||
&self,
|
||||
access_token: &str,
|
||||
organisation: &str,
|
||||
) -> Result<Vec<String>, PlatformError>;
|
||||
|
||||
async fn list_artifacts(
|
||||
&self,
|
||||
access_token: &str,
|
||||
organisation: &str,
|
||||
project: &str,
|
||||
) -> Result<Vec<Artifact>, PlatformError>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn valid_slugs() {
|
||||
assert!(validate_slug("my-org"));
|
||||
assert!(validate_slug("a"));
|
||||
assert!(validate_slug("abc123"));
|
||||
assert!(validate_slug("my-cool-project-2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_slugs() {
|
||||
assert!(!validate_slug(""));
|
||||
assert!(!validate_slug("-starts-with-dash"));
|
||||
assert!(!validate_slug("ends-with-dash-"));
|
||||
assert!(!validate_slug("UPPERCASE"));
|
||||
assert!(!validate_slug("has spaces"));
|
||||
assert!(!validate_slug("has_underscores"));
|
||||
assert!(!validate_slug("has.dots"));
|
||||
assert!(!validate_slug(&"a".repeat(65)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_length_slug_is_valid() {
|
||||
assert!(validate_slug(&"a".repeat(64)));
|
||||
}
|
||||
}
|
||||
1
crates/forage-core/src/registry/mod.rs
Normal file
1
crates/forage-core/src/registry/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
// Component registry logic - discovering, resolving, and managing forest components.
|
||||
260
crates/forage-core/src/session/mod.rs
Normal file
260
crates/forage-core/src/session/mod.rs
Normal file
@@ -0,0 +1,260 @@
|
||||
mod store;
|
||||
|
||||
pub use store::InMemorySessionStore;
|
||||
|
||||
use crate::auth::UserEmail;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Opaque session identifier. 32 bytes of cryptographic randomness, base64url-encoded.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct SessionId(String);
|
||||
|
||||
impl SessionId {
|
||||
pub fn generate() -> Self {
|
||||
use rand::Rng;
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::rng().fill(&mut bytes);
|
||||
Self(base64url_encode(&bytes))
|
||||
}
|
||||
|
||||
/// Construct from a raw cookie value. No validation - it's just a lookup key.
|
||||
pub fn from_raw(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SessionId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&self.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn base64url_encode(bytes: &[u8]) -> String {
|
||||
use std::fmt::Write;
|
||||
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
let mut out = String::with_capacity((bytes.len() * 4).div_ceil(3));
|
||||
for chunk in bytes.chunks(3) {
|
||||
let n = match chunk.len() {
|
||||
3 => (chunk[0] as u32) << 16 | (chunk[1] as u32) << 8 | chunk[2] as u32,
|
||||
2 => (chunk[0] as u32) << 16 | (chunk[1] as u32) << 8,
|
||||
1 => (chunk[0] as u32) << 16,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let _ = out.write_char(CHARS[((n >> 18) & 0x3F) as usize] as char);
|
||||
let _ = out.write_char(CHARS[((n >> 12) & 0x3F) as usize] as char);
|
||||
if chunk.len() > 1 {
|
||||
let _ = out.write_char(CHARS[((n >> 6) & 0x3F) as usize] as char);
|
||||
}
|
||||
if chunk.len() > 2 {
|
||||
let _ = out.write_char(CHARS[(n & 0x3F) as usize] as char);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Cached user info stored in the session to avoid repeated gRPC calls.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CachedUser {
|
||||
pub user_id: String,
|
||||
pub username: String,
|
||||
pub emails: Vec<UserEmail>,
|
||||
#[serde(default)]
|
||||
pub orgs: Vec<CachedOrg>,
|
||||
}
|
||||
|
||||
/// Cached organisation membership.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CachedOrg {
|
||||
pub name: String,
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
/// Generate a CSRF token (16 bytes of randomness, base64url-encoded).
|
||||
pub fn generate_csrf_token() -> String {
|
||||
use rand::Rng;
|
||||
let mut bytes = [0u8; 16];
|
||||
rand::rng().fill(&mut bytes);
|
||||
base64url_encode(&bytes)
|
||||
}
|
||||
|
||||
/// Server-side session data. Never exposed to the browser.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionData {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub access_expires_at: DateTime<Utc>,
|
||||
pub user: Option<CachedUser>,
|
||||
pub csrf_token: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_seen_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl SessionData {
|
||||
/// Whether the access token is expired or will expire within the given margin.
|
||||
pub fn is_access_expired(&self, margin: chrono::Duration) -> bool {
|
||||
Utc::now() + margin >= self.access_expires_at
|
||||
}
|
||||
|
||||
/// Whether the access token needs refreshing (expired or within 60s of expiry).
|
||||
pub fn needs_refresh(&self) -> bool {
|
||||
self.is_access_expired(chrono::Duration::seconds(60))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SessionError {
|
||||
#[error("session store error: {0}")]
|
||||
Store(String),
|
||||
}
|
||||
|
||||
/// Trait for session persistence. Swappable between in-memory, Redis, Postgres.
|
||||
#[async_trait::async_trait]
|
||||
pub trait SessionStore: Send + Sync {
|
||||
async fn create(&self, data: SessionData) -> Result<SessionId, SessionError>;
|
||||
async fn get(&self, id: &SessionId) -> Result<Option<SessionData>, SessionError>;
|
||||
async fn update(&self, id: &SessionId, data: SessionData) -> Result<(), SessionError>;
|
||||
async fn delete(&self, id: &SessionId) -> Result<(), SessionError>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[test]
|
||||
fn session_id_generates_unique_ids() {
|
||||
let ids: HashSet<String> = (0..1000).map(|_| SessionId::generate().0).collect();
|
||||
assert_eq!(ids.len(), 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_id_is_base64url_safe() {
|
||||
for _ in 0..100 {
|
||||
let id = SessionId::generate();
|
||||
let s = id.as_str();
|
||||
assert!(
|
||||
s.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
|
||||
"invalid chars in session id: {s}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_id_has_sufficient_length() {
|
||||
// 32 bytes -> ~43 base64url chars
|
||||
let id = SessionId::generate();
|
||||
assert!(id.as_str().len() >= 42, "session id too short: {}", id.as_str().len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_data_not_expired() {
|
||||
let data = SessionData {
|
||||
access_token: "tok".into(),
|
||||
refresh_token: "ref".into(),
|
||||
csrf_token: "test-csrf".into(),
|
||||
access_expires_at: Utc::now() + chrono::Duration::hours(1),
|
||||
user: None,
|
||||
created_at: Utc::now(),
|
||||
last_seen_at: Utc::now(),
|
||||
};
|
||||
assert!(!data.is_access_expired(chrono::Duration::zero()));
|
||||
assert!(!data.needs_refresh());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_data_expired() {
|
||||
let data = SessionData {
|
||||
access_token: "tok".into(),
|
||||
refresh_token: "ref".into(),
|
||||
csrf_token: "test-csrf".into(),
|
||||
access_expires_at: Utc::now() - chrono::Duration::seconds(1),
|
||||
user: None,
|
||||
created_at: Utc::now(),
|
||||
last_seen_at: Utc::now(),
|
||||
};
|
||||
assert!(data.is_access_expired(chrono::Duration::zero()));
|
||||
assert!(data.needs_refresh());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_data_needs_refresh_within_margin() {
|
||||
let data = SessionData {
|
||||
access_token: "tok".into(),
|
||||
refresh_token: "ref".into(),
|
||||
csrf_token: "test-csrf".into(),
|
||||
access_expires_at: Utc::now() + chrono::Duration::seconds(30),
|
||||
user: None,
|
||||
created_at: Utc::now(),
|
||||
last_seen_at: Utc::now(),
|
||||
};
|
||||
// Not expired yet, but within 60s margin
|
||||
assert!(!data.is_access_expired(chrono::Duration::zero()));
|
||||
assert!(data.needs_refresh());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_store_create_and_get() {
|
||||
let store = InMemorySessionStore::new();
|
||||
let data = make_session_data();
|
||||
let id = store.create(data.clone()).await.unwrap();
|
||||
let retrieved = store.get(&id).await.unwrap().expect("session should exist");
|
||||
assert_eq!(retrieved.access_token, data.access_token);
|
||||
assert_eq!(retrieved.refresh_token, data.refresh_token);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_store_get_nonexistent_returns_none() {
|
||||
let store = InMemorySessionStore::new();
|
||||
let id = SessionId::generate();
|
||||
assert!(store.get(&id).await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_store_update() {
|
||||
let store = InMemorySessionStore::new();
|
||||
let data = make_session_data();
|
||||
let id = store.create(data).await.unwrap();
|
||||
|
||||
let mut updated = make_session_data();
|
||||
updated.access_token = "new-access".into();
|
||||
store.update(&id, updated).await.unwrap();
|
||||
|
||||
let retrieved = store.get(&id).await.unwrap().unwrap();
|
||||
assert_eq!(retrieved.access_token, "new-access");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_store_delete() {
|
||||
let store = InMemorySessionStore::new();
|
||||
let data = make_session_data();
|
||||
let id = store.create(data).await.unwrap();
|
||||
store.delete(&id).await.unwrap();
|
||||
assert!(store.get(&id).await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn in_memory_store_delete_nonexistent_is_ok() {
|
||||
let store = InMemorySessionStore::new();
|
||||
let id = SessionId::generate();
|
||||
// Should not error
|
||||
store.delete(&id).await.unwrap();
|
||||
}
|
||||
|
||||
fn make_session_data() -> SessionData {
|
||||
SessionData {
|
||||
access_token: "test-access".into(),
|
||||
refresh_token: "test-refresh".into(),
|
||||
csrf_token: "test-csrf".into(),
|
||||
access_expires_at: Utc::now() + chrono::Duration::hours(1),
|
||||
user: None,
|
||||
created_at: Utc::now(),
|
||||
last_seen_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
66
crates/forage-core/src/session/store.rs
Normal file
66
crates/forage-core/src/session/store.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
|
||||
use super::{SessionData, SessionError, SessionId, SessionStore};
|
||||
|
||||
/// In-memory session store. Suitable for single-instance deployments.
|
||||
/// Sessions are lost on server restart.
|
||||
pub struct InMemorySessionStore {
|
||||
sessions: RwLock<HashMap<SessionId, SessionData>>,
|
||||
max_inactive: Duration,
|
||||
}
|
||||
|
||||
impl Default for InMemorySessionStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl InMemorySessionStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
max_inactive: Duration::days(30),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove sessions inactive for longer than `max_inactive`.
|
||||
pub fn reap_expired(&self) {
|
||||
let cutoff = Utc::now() - self.max_inactive;
|
||||
let mut sessions = self.sessions.write().unwrap();
|
||||
sessions.retain(|_, data| data.last_seen_at > cutoff);
|
||||
}
|
||||
|
||||
pub fn session_count(&self) -> usize {
|
||||
self.sessions.read().unwrap().len()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SessionStore for InMemorySessionStore {
|
||||
async fn create(&self, data: SessionData) -> Result<SessionId, SessionError> {
|
||||
let id = SessionId::generate();
|
||||
let mut sessions = self.sessions.write().unwrap();
|
||||
sessions.insert(id.clone(), data);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn get(&self, id: &SessionId) -> Result<Option<SessionData>, SessionError> {
|
||||
let sessions = self.sessions.read().unwrap();
|
||||
Ok(sessions.get(id).cloned())
|
||||
}
|
||||
|
||||
async fn update(&self, id: &SessionId, data: SessionData) -> Result<(), SessionError> {
|
||||
let mut sessions = self.sessions.write().unwrap();
|
||||
sessions.insert(id.clone(), data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete(&self, id: &SessionId) -> Result<(), SessionError> {
|
||||
let mut sessions = self.sessions.write().unwrap();
|
||||
sessions.remove(id);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user