Our Personal Data Server from scratch! tranquil.farm
oauth atproto pds rust postgresql objectstorage fun

fix: smaller docker img #19

Merged — opened by lewis.moe, targeting main from fix/smaller-docker-img

Especially when building with the new `--no-default-features` option, the Docker image (and the regular binary, for that matter) will be noticeably smaller.

Labels

None yet.

Assignees

None yet.

Participants 1
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3mee377z4w422
+823 -735
Diff #0
+7 -76
Cargo.lock
··· 2744 2744 "tower-service", 2745 2745 ] 2746 2746 2747 - [[package]] 2748 - name = "hyper-tls" 2749 - version = "0.6.0" 2750 - source = "registry+https://github.com/rust-lang/crates.io-index" 2751 - checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" 2752 - dependencies = [ 2753 - "bytes", 2754 - "http-body-util", 2755 - "hyper 1.8.1", 2756 - "hyper-util", 2757 - "native-tls", 2758 - "tokio", 2759 - "tokio-native-tls", 2760 - "tower-service", 2761 - ] 2762 - 2763 2747 [[package]] 2764 2748 name = "hyper-util" 2765 2749 version = "0.1.19" ··· 3585 3569 "web-time", 3586 3570 ] 3587 3571 3588 - [[package]] 3589 - name = "native-tls" 3590 - version = "0.2.14" 3591 - source = "registry+https://github.com/rust-lang/crates.io-index" 3592 - checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" 3593 - dependencies = [ 3594 - "libc", 3595 - "log", 3596 - "openssl", 3597 - "openssl-probe", 3598 - "openssl-sys", 3599 - "schannel", 3600 - "security-framework 2.11.1", 3601 - "security-framework-sys", 3602 - "tempfile", 3603 - ] 3604 - 3605 3572 [[package]] 3606 3573 name = "nom" 3607 3574 version = "7.1.3" ··· 4498 4465 "http-body-util", 4499 4466 "hyper 1.8.1", 4500 4467 "hyper-rustls 0.27.7", 4501 - "hyper-tls", 4502 4468 "hyper-util", 4503 4469 "js-sys", 4504 4470 "log", 4505 4471 "mime", 4506 - "native-tls", 4507 4472 "percent-encoding", 4508 4473 "pin-project-lite", 4509 4474 "quinn", ··· 4514 4479 "serde_urlencoded", 4515 4480 "sync_wrapper", 4516 4481 "tokio", 4517 - "tokio-native-tls", 4518 4482 "tokio-rustls 0.26.4", 4519 4483 "tower", 4520 4484 "tower-http", ··· 4661 4625 "openssl-probe", 4662 4626 "rustls-pki-types", 4663 4627 "schannel", 4664 - "security-framework 3.5.1", 4628 + "security-framework", 4665 4629 ] 4666 4630 4667 4631 [[package]] ··· 4800 4764 "zeroize", 4801 4765 ] 4802 4766 4803 - [[package]] 4804 - name = "security-framework" 4805 - version = "2.11.1" 4806 - source = 
"registry+https://github.com/rust-lang/crates.io-index" 4807 - checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" 4808 - dependencies = [ 4809 - "bitflags", 4810 - "core-foundation 0.9.4", 4811 - "core-foundation-sys", 4812 - "libc", 4813 - "security-framework-sys", 4814 - ] 4815 - 4816 4767 [[package]] 4817 4768 name = "security-framework" 4818 4769 version = "3.5.1" ··· 5525 5476 "libc", 5526 5477 ] 5527 5478 5528 - [[package]] 5529 - name = "tempfile" 5530 - version = "3.23.0" 5531 - source = "registry+https://github.com/rust-lang/crates.io-index" 5532 - checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" 5533 - dependencies = [ 5534 - "fastrand", 5535 - "getrandom 0.3.4", 5536 - "once_cell", 5537 - "rustix", 5538 - "windows-sys 0.61.2", 5539 - ] 5540 - 5541 5479 [[package]] 5542 5480 name = "testcontainers" 5543 5481 version = "0.26.2" ··· 5709 5647 "syn 2.0.111", 5710 5648 ] 5711 5649 5712 - [[package]] 5713 - name = "tokio-native-tls" 5714 - version = "0.3.1" 5715 - source = "registry+https://github.com/rust-lang/crates.io-index" 5716 - checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" 5717 - dependencies = [ 5718 - "native-tls", 5719 - "tokio", 5720 - ] 5721 - 5722 5650 [[package]] 5723 5651 name = "tokio-rustls" 5724 5652 version = "0.24.1" ··· 5758 5686 dependencies = [ 5759 5687 "futures-util", 5760 5688 "log", 5761 - "native-tls", 5689 + "rustls 0.23.35", 5690 + "rustls-pki-types", 5762 5691 "tokio", 5763 - "tokio-native-tls", 5692 + "tokio-rustls 0.26.4", 5764 5693 "tungstenite", 5694 + "webpki-roots 0.26.11", 5765 5695 ] 5766 5696 5767 5697 [[package]] ··· 6274 6204 "http 1.4.0", 6275 6205 "httparse", 6276 6206 "log", 6277 - "native-tls", 6278 6207 "rand 0.9.2", 6208 + "rustls 0.23.35", 6209 + "rustls-pki-types", 6279 6210 "sha1", 6280 6211 "thiserror 2.0.17", 6281 6212 "utf-8",
+8 -3
Cargo.toml
··· 81 81 rand = "0.8" 82 82 redis = { version = "1.0", features = ["tokio-comp", "connection-manager"] } 83 83 regex = "1" 84 - reqwest = { version = "0.12", features = ["json"] } 84 + reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-webpki-roots", "http2", "charset", "macos-system-configuration"] } 85 85 serde = { version = "1.0", features = ["derive"] } 86 86 serde_bytes = "0.11" 87 87 serde_ipld_dagcbor = "0.6" ··· 91 91 sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "uuid", "chrono", "json"] } 92 92 subtle = "2.5" 93 93 thiserror = "2.0" 94 - tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process"] } 94 + tokio = { version = "1.48", features = ["macros", "rt-multi-thread", "time", "signal", "process", "io-util", "fs"] } 95 95 tokio-util = "0.7.18" 96 - tokio-tungstenite = { version = "0.28", features = ["native-tls"] } 96 + tokio-tungstenite = { version = "0.28", features = ["rustls-tls-webpki-roots"] } 97 97 totp-rs = { version = "5", features = ["qr"] } 98 98 tower = "0.5" 99 99 tower-http = { version = "0.6", features = ["cors"] } ··· 111 111 testcontainers = "0.26" 112 112 testcontainers-modules = { version = "0.14", features = ["postgres"] } 113 113 wiremock = "0.6" 114 + 115 + [profile.release] 116 + lto = "thin" 117 + strip = true 118 + codegen-units = 1
+7 -2
Dockerfile
··· 1 1 FROM rust:1.92-alpine AS builder 2 - RUN apk add --no-cache ca-certificates openssl openssl-dev openssl-libs-static pkgconfig musl-dev 2 + RUN apk add --no-cache ca-certificates musl-dev pkgconfig openssl-dev openssl-libs-static 3 3 WORKDIR /app 4 + ARG SLIM="false" 4 5 COPY Cargo.toml Cargo.lock ./ 5 6 COPY crates ./crates 6 7 COPY .sqlx ./.sqlx 7 8 COPY migrations ./crates/tranquil-pds/migrations 8 9 RUN --mount=type=cache,target=/usr/local/cargo/registry \ 9 10 --mount=type=cache,target=/app/target \ 10 - SQLX_OFFLINE=true cargo build --release -p tranquil-pds && \ 11 + if [ "$SLIM" = "true" ]; then \ 12 + SQLX_OFFLINE=true cargo build --release -p tranquil-pds --no-default-features; \ 13 + else \ 14 + SQLX_OFFLINE=true cargo build --release -p tranquil-pds; \ 15 + fi && \ 11 16 cp target/release/tranquil-pds /tmp/tranquil-pds 12 17 13 18 FROM alpine:3.23
+5 -1
crates/tranquil-cache/Cargo.toml
··· 4 4 edition.workspace = true 5 5 license.workspace = true 6 6 7 + [features] 8 + default = [] 9 + valkey = ["dep:redis"] 10 + 7 11 [dependencies] 8 12 tranquil-infra = { workspace = true } 9 13 tranquil-ripple = { workspace = true } 10 14 11 15 async-trait = { workspace = true } 12 16 base64 = { workspace = true } 13 - redis = { workspace = true } 17 + redis = { workspace = true, optional = true } 14 18 tokio-util = { workspace = true } 15 19 tracing = { workspace = true }
+127 -105
crates/tranquil-cache/src/lib.rs
··· 1 1 pub use tranquil_infra::{Cache, CacheError, DistributedRateLimiter}; 2 2 3 3 use async_trait::async_trait; 4 - use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 5 4 use std::sync::Arc; 6 5 use std::time::Duration; 7 6 8 - #[derive(Clone)] 9 - pub struct ValkeyCache { 10 - conn: redis::aio::ConnectionManager, 11 - } 7 + #[cfg(feature = "valkey")] 8 + mod valkey { 9 + use super::*; 10 + use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; 11 + 12 + #[derive(Clone)] 13 + pub struct ValkeyCache { 14 + conn: redis::aio::ConnectionManager, 15 + } 16 + 17 + impl ValkeyCache { 18 + pub async fn new(url: &str) -> Result<Self, CacheError> { 19 + let client = 20 + redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?; 21 + let manager = client 22 + .get_connection_manager() 23 + .await 24 + .map_err(|e| CacheError::Connection(e.to_string()))?; 25 + Ok(Self { conn: manager }) 26 + } 12 27 13 - impl ValkeyCache { 14 - pub async fn new(url: &str) -> Result<Self, CacheError> { 15 - let client = redis::Client::open(url).map_err(|e| CacheError::Connection(e.to_string()))?; 16 - let manager = client 17 - .get_connection_manager() 18 - .await 19 - .map_err(|e| CacheError::Connection(e.to_string()))?; 20 - Ok(Self { conn: manager }) 28 + pub fn connection(&self) -> redis::aio::ConnectionManager { 29 + self.conn.clone() 30 + } 21 31 } 22 32 23 - pub fn connection(&self) -> redis::aio::ConnectionManager { 24 - self.conn.clone() 33 + #[async_trait] 34 + impl Cache for ValkeyCache { 35 + async fn get(&self, key: &str) -> Option<String> { 36 + let mut conn = self.conn.clone(); 37 + redis::cmd("GET") 38 + .arg(key) 39 + .query_async::<Option<String>>(&mut conn) 40 + .await 41 + .ok() 42 + .flatten() 43 + } 44 + 45 + async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 46 + let mut conn = self.conn.clone(); 47 + redis::cmd("SET") 48 + .arg(key) 49 + .arg(value) 50 + .arg("PX") 51 + 
.arg(ttl.as_millis().min(i64::MAX as u128) as i64) 52 + .query_async::<()>(&mut conn) 53 + .await 54 + .map_err(|e| CacheError::Connection(e.to_string())) 55 + } 56 + 57 + async fn delete(&self, key: &str) -> Result<(), CacheError> { 58 + let mut conn = self.conn.clone(); 59 + redis::cmd("DEL") 60 + .arg(key) 61 + .query_async::<()>(&mut conn) 62 + .await 63 + .map_err(|e| CacheError::Connection(e.to_string())) 64 + } 65 + 66 + async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> { 67 + self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) 68 + } 69 + 70 + async fn set_bytes( 71 + &self, 72 + key: &str, 73 + value: &[u8], 74 + ttl: Duration, 75 + ) -> Result<(), CacheError> { 76 + let encoded = BASE64.encode(value); 77 + self.set(key, &encoded, ttl).await 78 + } 25 79 } 26 - } 27 80 28 - #[async_trait] 29 - impl Cache for ValkeyCache { 30 - async fn get(&self, key: &str) -> Option<String> { 31 - let mut conn = self.conn.clone(); 32 - redis::cmd("GET") 33 - .arg(key) 34 - .query_async::<Option<String>>(&mut conn) 35 - .await 36 - .ok() 37 - .flatten() 38 - } 39 - 40 - async fn set(&self, key: &str, value: &str, ttl: Duration) -> Result<(), CacheError> { 41 - let mut conn = self.conn.clone(); 42 - redis::cmd("SET") 43 - .arg(key) 44 - .arg(value) 45 - .arg("PX") 46 - .arg(ttl.as_millis().min(i64::MAX as u128) as i64) 47 - .query_async::<()>(&mut conn) 48 - .await 49 - .map_err(|e| CacheError::Connection(e.to_string())) 50 - } 51 - 52 - async fn delete(&self, key: &str) -> Result<(), CacheError> { 53 - let mut conn = self.conn.clone(); 54 - redis::cmd("DEL") 55 - .arg(key) 56 - .query_async::<()>(&mut conn) 57 - .await 58 - .map_err(|e| CacheError::Connection(e.to_string())) 59 - } 60 - 61 - async fn get_bytes(&self, key: &str) -> Option<Vec<u8>> { 62 - self.get(key).await.and_then(|s| BASE64.decode(&s).ok()) 63 - } 64 - 65 - async fn set_bytes(&self, key: &str, value: &[u8], ttl: Duration) -> Result<(), CacheError> { 66 - let encoded = BASE64.encode(value); 67 
- self.set(key, &encoded, ttl).await 81 + #[derive(Clone)] 82 + pub struct RedisRateLimiter { 83 + conn: redis::aio::ConnectionManager, 84 + } 85 + 86 + impl RedisRateLimiter { 87 + pub fn new(conn: redis::aio::ConnectionManager) -> Self { 88 + Self { conn } 89 + } 90 + } 91 + 92 + #[async_trait] 93 + impl DistributedRateLimiter for RedisRateLimiter { 94 + async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 95 + let mut conn = self.conn.clone(); 96 + let full_key = format!("rl:{}", key); 97 + let window_secs = window_ms.div_ceil(1000).max(1) as i64; 98 + let result: Result<i64, _> = redis::Script::new( 99 + r"local c = redis.call('INCR', KEYS[1]) 100 + if c == 1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end 101 + if redis.call('TTL', KEYS[1]) == -1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end 102 + return c", 103 + ) 104 + .key(&full_key) 105 + .arg(window_secs) 106 + .invoke_async(&mut conn) 107 + .await; 108 + match result { 109 + Ok(count) => count <= limit as i64, 110 + Err(e) => { 111 + tracing::warn!(error = %e, "redis rate limit script failed, allowing request"); 112 + true 113 + } 114 + } 115 + } 116 + 117 + async fn peek_rate_limit_count(&self, key: &str, _window_ms: u64) -> u64 { 118 + let mut conn = self.conn.clone(); 119 + let full_key = format!("rl:{}", key); 120 + redis::cmd("GET") 121 + .arg(&full_key) 122 + .query_async::<Option<u64>>(&mut conn) 123 + .await 124 + .ok() 125 + .flatten() 126 + .unwrap_or(0) 127 + } 68 128 } 69 129 } 70 130 131 + #[cfg(feature = "valkey")] 132 + pub use valkey::{RedisRateLimiter, ValkeyCache}; 133 + 71 134 pub struct NoOpCache; 72 135 73 136 #[async_trait] ··· 97 160 } 98 161 } 99 162 100 - #[derive(Clone)] 101 - pub struct RedisRateLimiter { 102 - conn: redis::aio::ConnectionManager, 103 - } 104 - 105 - impl RedisRateLimiter { 106 - pub fn new(conn: redis::aio::ConnectionManager) -> Self { 107 - Self { conn } 108 - } 109 - } 110 - 111 - #[async_trait] 112 - impl 
DistributedRateLimiter for RedisRateLimiter { 113 - async fn check_rate_limit(&self, key: &str, limit: u32, window_ms: u64) -> bool { 114 - let mut conn = self.conn.clone(); 115 - let full_key = format!("rl:{}", key); 116 - let window_secs = window_ms.div_ceil(1000).max(1) as i64; 117 - let result: Result<i64, _> = redis::Script::new( 118 - r"local c = redis.call('INCR', KEYS[1]) 119 - if c == 1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end 120 - if redis.call('TTL', KEYS[1]) == -1 then redis.call('EXPIRE', KEYS[1], ARGV[1]) end 121 - return c" 122 - ) 123 - .key(&full_key) 124 - .arg(window_secs) 125 - .invoke_async(&mut conn) 126 - .await; 127 - match result { 128 - Ok(count) => count <= limit as i64, 129 - Err(e) => { 130 - tracing::warn!(error = %e, "redis rate limit script failed, allowing request"); 131 - true 132 - } 133 - } 134 - } 135 - 136 - async fn peek_rate_limit_count(&self, key: &str, _window_ms: u64) -> u64 { 137 - let mut conn = self.conn.clone(); 138 - let full_key = format!("rl:{}", key); 139 - redis::cmd("GET") 140 - .arg(&full_key) 141 - .query_async::<Option<u64>>(&mut conn) 142 - .await 143 - .ok() 144 - .flatten() 145 - .unwrap_or(0) 146 - } 147 - } 148 - 149 163 pub struct NoOpRateLimiter; 150 164 151 165 #[async_trait] ··· 158 172 pub async fn create_cache( 159 173 shutdown: tokio_util::sync::CancellationToken, 160 174 ) -> (Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>) { 175 + #[cfg(feature = "valkey")] 161 176 if let Ok(url) = std::env::var("VALKEY_URL") { 162 177 match ValkeyCache::new(&url).await { 163 178 Ok(cache) => { ··· 171 186 } 172 187 } 173 188 189 + #[cfg(not(feature = "valkey"))] 190 + if std::env::var("VALKEY_URL").is_ok() { 191 + tracing::warn!( 192 + "VALKEY_URL is set but binary was compiled without valkey feature. using ripple." 193 + ); 194 + } 195 + 174 196 match tranquil_ripple::RippleConfig::from_env() { 175 197 Ok(config) => { 176 198 let peer_count = config.seed_peers.len();
+7 -4
crates/tranquil-pds/Cargo.toml
··· 21 21 async-trait = { workspace = true } 22 22 backon = { workspace = true } 23 23 anyhow = { workspace = true } 24 - aws-config = { workspace = true } 25 - aws-sdk-s3 = { workspace = true } 26 24 axum = { workspace = true } 27 25 base32 = { workspace = true } 28 26 base64 = { workspace = true } ··· 55 53 multihash = { workspace = true } 56 54 p256 = { workspace = true } 57 55 rand = { workspace = true } 58 - redis = { workspace = true } 56 + redis = { workspace = true, optional = true } 59 57 regex = { workspace = true } 60 58 reqwest = { workspace = true } 61 59 serde = { workspace = true } ··· 79 77 uuid = { workspace = true } 80 78 webauthn-rs = { workspace = true } 81 79 zip = { workspace = true } 80 + aws-config = { workspace = true, optional = true } 81 + aws-sdk-s3 = { workspace = true, optional = true } 82 82 83 83 [features] 84 + default = ["s3", "valkey"] 84 85 external-infra = [] 85 - s3-storage = [] 86 + s3-storage = ["tranquil-storage/s3", "dep:aws-config", "dep:aws-sdk-s3"] 87 + s3 = ["s3-storage"] 88 + valkey = ["tranquil-cache/valkey", "dep:redis"] 86 89 87 90 [dev-dependencies] 88 91 ciborium = { workspace = true }
+8 -6
crates/tranquil-pds/src/api/repo/blob.rs
··· 162 162 if let Err(e) = state.blob_store.copy(&temp_key, &storage_key).await { 163 163 let _ = state.blob_store.delete(&temp_key).await; 164 164 if let Err(db_err) = state.blob_repo.delete_blob_by_cid(&cid_link).await { 165 - error!("Failed to clean up orphaned blob record after copy failure: {:?}", db_err); 165 + error!( 166 + "Failed to clean up orphaned blob record after copy failure: {:?}", 167 + db_err 168 + ); 166 169 } 167 170 error!("Failed to copy blob to final location: {:?}", e); 168 171 return Err(ApiError::InternalError(Some("Failed to store blob".into()))); ··· 170 173 171 174 let _ = state.blob_store.delete(&temp_key).await; 172 175 173 - if let Some(ref controller) = controller_did { 174 - if let Err(e) = state 176 + if let Some(ref controller) = controller_did 177 + && let Err(e) = state 175 178 .delegation_repo 176 179 .log_delegation_action( 177 180 &did, ··· 187 190 None, 188 191 ) 189 192 .await 190 - { 191 - warn!("Failed to log delegation action for blob upload: {:?}", e); 192 - } 193 + { 194 + warn!("Failed to log delegation action for blob upload: {:?}", e); 193 195 } 194 196 195 197 Ok(Json(json!({
+4 -2
crates/tranquil-pds/src/cache/mod.rs
··· 1 1 pub use tranquil_cache::{ 2 - Cache, CacheError, DistributedRateLimiter, NoOpCache, NoOpRateLimiter, RedisRateLimiter, 3 - ValkeyCache, create_cache, 2 + Cache, CacheError, DistributedRateLimiter, NoOpCache, NoOpRateLimiter, create_cache, 4 3 }; 4 + 5 + #[cfg(feature = "valkey")] 6 + pub use tranquil_cache::{RedisRateLimiter, ValkeyCache};
+6 -3
crates/tranquil-pds/src/storage/mod.rs
··· 1 1 pub use tranquil_storage::{ 2 - BackupStorage, BlobStorage, FilesystemBackupStorage, FilesystemBlobStorage, S3BackupStorage, 3 - S3BlobStorage, StorageError, StreamUploadResult, backup_interval_secs, backup_retention_count, 4 - create_backup_storage, create_blob_storage, 2 + BackupStorage, BlobStorage, FilesystemBackupStorage, FilesystemBlobStorage, StorageError, 3 + StreamUploadResult, backup_interval_secs, backup_retention_count, create_backup_storage, 4 + create_blob_storage, 5 5 }; 6 + 7 + #[cfg(feature = "s3-storage")] 8 + pub use tranquil_storage::{S3BackupStorage, S3BlobStorage};
+4 -5
crates/tranquil-pds/tests/common/mod.rs
··· 330 330 ); 331 331 std::env::set_var("S3_ENDPOINT", &s3_endpoint); 332 332 } else { 333 - let process_dir = std::env::temp_dir().join(format!( 334 - "tranquil-pds-test-{}", 335 - std::process::id() 336 - )); 333 + let process_dir = 334 + std::env::temp_dir().join(format!("tranquil-pds-test-{}", std::process::id())); 337 335 let blob_path = process_dir.join("blobs"); 338 336 let backup_path = process_dir.join("backups"); 339 337 std::fs::create_dir_all(&blob_path).expect("Failed to create blob directory"); ··· 715 713 716 714 #[cfg(not(feature = "external-infra"))] 717 715 async fn setup_cluster_testcontainers() -> String { 718 - let temp_dir = std::env::temp_dir().join(format!("tranquil-pds-cluster-{}", uuid::Uuid::new_v4())); 716 + let temp_dir = 717 + std::env::temp_dir().join(format!("tranquil-pds-cluster-{}", uuid::Uuid::new_v4())); 719 718 let blob_path = temp_dir.join("blobs"); 720 719 let backup_path = temp_dir.join("backups"); 721 720 std::fs::create_dir_all(&blob_path).expect("Failed to create blob temp directory");
+2
crates/tranquil-pds/tests/rate_limit.rs
··· 118 118 ); 119 119 } 120 120 121 + #[cfg(feature = "valkey")] 121 122 #[tokio::test] 122 123 async fn test_valkey_connection() { 123 124 if std::env::var("VALKEY_URL").is_err() { ··· 156 157 .expect("DEL failed"); 157 158 } 158 159 160 + #[cfg(feature = "valkey")] 159 161 #[tokio::test] 160 162 async fn test_distributed_rate_limiter_directly() { 161 163 if std::env::var("VALKEY_URL").is_err() {
+44 -34
crates/tranquil-pds/tests/ripple_cluster.rs
··· 45 45 assert!(nodes.len() >= 3, "expected at least 3 cluster nodes"); 46 46 47 47 let client = common::client(); 48 - let results: Vec<_> = futures::future::join_all( 49 - nodes.iter().map(|node| { 50 - let client = client.clone(); 51 - let url = node.url.clone(); 52 - async move { 53 - client 54 - .get(format!("{url}/xrpc/com.atproto.server.describeServer")) 55 - .send() 56 - .await 57 - } 58 - }) 59 - ).await; 48 + let results: Vec<_> = futures::future::join_all(nodes.iter().map(|node| { 49 + let client = client.clone(); 50 + let url = node.url.clone(); 51 + async move { 52 + client 53 + .get(format!("{url}/xrpc/com.atproto.server.describeServer")) 54 + .send() 55 + .await 56 + } 57 + })) 58 + .await; 60 59 61 60 results.iter().enumerate().for_each(|(i, result)| { 62 - let resp = result.as_ref().unwrap_or_else(|e| panic!("node {i} unreachable: {e}")); 61 + let resp = result 62 + .as_ref() 63 + .unwrap_or_else(|e| panic!("node {i} unreachable: {e}")); 63 64 assert_eq!( 64 65 resp.status(), 65 66 StatusCode::OK, ··· 91 92 assert_eq!(create_res.status(), StatusCode::OK); 92 93 let body: serde_json::Value = create_res.json().await.expect("invalid json"); 93 94 let did = body["did"].as_str().expect("no did").to_string(); 94 - let access_jwt = body["accessJwt"].as_str().expect("no accessJwt").to_string(); 95 + let access_jwt = body["accessJwt"] 96 + .as_str() 97 + .expect("no accessJwt") 98 + .to_string(); 95 99 96 100 let pool = common::get_test_db_pool().await; 97 101 let body_text: String = sqlx::query_scalar!( ··· 132 136 133 137 let token = match confirm_res.status() { 134 138 StatusCode::OK => { 135 - let confirm_body: serde_json::Value = 136 - confirm_res.json().await.expect("invalid json from confirmSignup"); 139 + let confirm_body: serde_json::Value = confirm_res 140 + .json() 141 + .await 142 + .expect("invalid json from confirmSignup"); 137 143 confirm_body["accessJwt"] 138 144 .as_str() 139 145 .unwrap_or(&access_jwt) ··· 164 170 async fn 
cache_convergence() { 165 171 let nodes = common::cluster().await; 166 172 167 - let cache_a = nodes[0] 168 - .cache 169 - .as_ref() 170 - .expect("node 0 should have a cache"); 171 - let cache_b = nodes[1] 172 - .cache 173 - .as_ref() 174 - .expect("node 1 should have a cache"); 173 + let cache_a = nodes[0].cache.as_ref().expect("node 0 should have a cache"); 174 + let cache_b = nodes[1].cache.as_ref().expect("node 1 should have a cache"); 175 175 176 176 let test_key = format!("ripple-test-{}", uuid::Uuid::new_v4()); 177 177 let test_value = "converged-value"; ··· 407 407 }) 408 408 .await; 409 409 410 - let spot_checks: Vec<Option<String>> = futures::future::join_all( 411 - [0, 99, 250, 499].iter().map(|&i| { 410 + let spot_checks: Vec<Option<String>> = 411 + futures::future::join_all([0, 99, 250, 499].iter().map(|&i| { 412 412 let c = cache_2.clone(); 413 413 let p = prefix.clone(); 414 414 async move { c.get(&format!("{p}-{i}")).await } 415 - }), 416 - ) 417 - .await; 415 + })) 416 + .await; 418 417 419 418 spot_checks.iter().enumerate().for_each(|(idx, val)| { 420 419 assert!( ··· 620 619 assert_eq!(create_res.status(), StatusCode::OK, "createAccount non-200"); 621 620 let body: serde_json::Value = create_res.json().await.expect("invalid json"); 622 621 let did = body["did"].as_str().expect("no did").to_string(); 623 - let access_jwt = body["accessJwt"].as_str().expect("no accessJwt").to_string(); 622 + let access_jwt = body["accessJwt"] 623 + .as_str() 624 + .expect("no accessJwt") 625 + .to_string(); 624 626 625 627 let pool = common::get_test_db_pool().await; 626 628 let body_text: String = sqlx::query_scalar!( ··· 654 656 655 657 let token = match confirm_res.status() { 656 658 StatusCode::OK => { 657 - let confirm_body: serde_json::Value = 658 - confirm_res.json().await.expect("invalid json from confirmSignup"); 659 + let confirm_body: serde_json::Value = confirm_res 660 + .json() 661 + .await 662 + .expect("invalid json from confirmSignup"); 659 663 
confirm_body["accessJwt"] 660 664 .as_str() 661 665 .unwrap_or(&access_jwt) ··· 744 748 let cache_0 = cache_for(nodes, 0); 745 749 746 750 let fake_handle = format!("cached-{}.test", uuid::Uuid::new_v4().simple()); 747 - let fake_did = format!("did:plc:cached{}", &uuid::Uuid::new_v4().simple().to_string()[..16]); 751 + let fake_did = format!( 752 + "did:plc:cached{}", 753 + &uuid::Uuid::new_v4().simple().to_string()[..16] 754 + ); 748 755 749 756 cache_0 750 757 .set( ··· 795 802 let cache_1 = cache_for(nodes, 1); 796 803 797 804 let fake_handle = format!("deltest-{}.test", uuid::Uuid::new_v4().simple()); 798 - let fake_did = format!("did:plc:del{}", &uuid::Uuid::new_v4().simple().to_string()[..16]); 805 + let fake_did = format!( 806 + "did:plc:del{}", 807 + &uuid::Uuid::new_v4().simple().to_string()[..16] 808 + ); 799 809 let cache_key = format!("handle:{fake_handle}"); 800 810 801 811 cache_0
+5 -1
crates/tranquil-pds/tests/sync_conformance.rs
··· 190 190 assert!(takendown_repo.is_some(), "Takendown repo should be in list"); 191 191 let repo = takendown_repo.unwrap(); 192 192 assert_eq!(repo["active"], false, "repo should be inactive: {:?}", repo); 193 - assert_eq!(repo["status"], "takendown", "repo status should be takendown: {:?}", repo); 193 + assert_eq!( 194 + repo["status"], "takendown", 195 + "repo status should be takendown: {:?}", 196 + repo 197 + ); 194 198 } 195 199 196 200 #[tokio::test]
+8 -2
crates/tranquil-pds/tests/whole_story.rs
··· 1400 1400 .iter() 1401 1401 .enumerate() 1402 1402 .flat_map(|(i, (follower_did, follower_jwt))| { 1403 - users.iter().enumerate() 1403 + users 1404 + .iter() 1405 + .enumerate() 1404 1406 .filter(move |(j, _)| *j != i) 1405 1407 .map(|(_, (followee_did, _))| { 1406 - (follower_did.clone(), follower_jwt.clone(), followee_did.clone()) 1408 + ( 1409 + follower_did.clone(), 1410 + follower_jwt.clone(), 1411 + followee_did.clone(), 1412 + ) 1407 1413 }) 1408 1414 .collect::<Vec<_>>() 1409 1415 })
+5 -1
crates/tranquil-ripple/src/config.rs
··· 19 19 match raw.parse::<T>() { 20 20 Ok(v) => Some(v), 21 21 Err(_) => { 22 - tracing::warn!(var = var_name, value = raw, "invalid env var value, using default"); 22 + tracing::warn!( 23 + var = var_name, 24 + value = raw, 25 + "invalid env var value, using default" 26 + ); 23 27 None 24 28 } 25 29 }
+2 -2
crates/tranquil-ripple/src/crdt/delta.rs
··· 1 - use super::lww_map::LwwDelta; 2 1 use super::g_counter::GCounterDelta; 2 + use super::lww_map::LwwDelta; 3 3 use serde::{Deserialize, Serialize}; 4 4 5 5 const SCHEMA_VERSION: u8 = 1; ··· 21 21 pub fn is_empty(&self) -> bool { 22 22 self.cache_delta 23 23 .as_ref() 24 - .map_or(true, |d| d.entries.is_empty()) 24 + .is_none_or(|d| d.entries.is_empty()) 25 25 && self.rate_limit_deltas.is_empty() 26 26 } 27 27
+8 -7
crates/tranquil-ripple/src/crdt/g_counter.rs
··· 142 142 self.dirty 143 143 .iter() 144 144 .filter_map(|key| { 145 - self.counters 146 - .get(key) 147 - .map(|counter| GCounterDelta { 148 - key: key.clone(), 149 - counter: counter.clone(), 150 - }) 145 + self.counters.get(key).map(|counter| GCounterDelta { 146 + key: key.clone(), 147 + counter: counter.clone(), 148 + }) 151 149 }) 152 150 .collect() 153 151 } ··· 167 165 return 0; 168 166 } 169 167 match self.counters.get(key) { 170 - Some(counter) if counter.window_start_ms == Self::aligned_window_start(now_wall_ms, window_ms) => { 168 + Some(counter) 169 + if counter.window_start_ms 170 + == Self::aligned_window_start(now_wall_ms, window_ms) => 171 + { 171 172 counter.total() 172 173 } 173 174 _ => 0,
+18 -1
crates/tranquil-ripple/src/crdt/lww_map.rs
··· 124 124 estimated_bytes: usize, 125 125 } 126 126 127 + impl Default for LwwMap { 128 + fn default() -> Self { 129 + Self::new() 130 + } 131 + } 132 + 127 133 impl LwwMap { 128 134 pub fn new() -> Self { 129 135 Self { ··· 141 147 entry.value.clone() 142 148 } 143 149 144 - pub fn set(&mut self, key: String, value: Vec<u8>, timestamp: HlcTimestamp, ttl_ms: u64, wall_ms_now: u64) { 150 + pub fn set( 151 + &mut self, 152 + key: String, 153 + value: Vec<u8>, 154 + timestamp: HlcTimestamp, 155 + ttl_ms: u64, 156 + wall_ms_now: u64, 157 + ) { 145 158 let entry = LwwEntry { 146 159 created_at_wall_ms: wall_ms_now, 147 160 value: Some(value), ··· 248 261 self.entries.len() 249 262 } 250 263 264 + pub fn is_empty(&self) -> bool { 265 + self.entries.is_empty() 266 + } 267 + 251 268 fn remove_estimated_bytes(&mut self, key: &str) { 252 269 if let Some(existing) = self.entries.get(key) { 253 270 let size = existing.entry_byte_size(key);
+83 -67
crates/tranquil-ripple/src/crdt/mod.rs
··· 1 1 pub mod delta; 2 + pub mod g_counter; 2 3 pub mod hlc; 3 4 pub mod lww_map; 4 - pub mod g_counter; 5 5 6 6 use crate::config::fnv1a; 7 7 use delta::CrdtDelta; 8 + use g_counter::RateLimitStore; 8 9 use hlc::{Hlc, HlcTimestamp}; 9 10 use lww_map::{LwwDelta, LwwMap}; 10 - use g_counter::RateLimitStore; 11 11 use parking_lot::{Mutex, RwLock}; 12 12 use std::time::{SystemTime, UNIX_EPOCH}; 13 13 ··· 45 45 let shards: Vec<RwLock<CrdtShard>> = (0..SHARD_COUNT) 46 46 .map(|_| RwLock::new(CrdtShard::new(node_id))) 47 47 .collect(); 48 - let promotions: Vec<Mutex<Vec<String>>> = (0..SHARD_COUNT) 49 - .map(|_| Mutex::new(Vec::new())) 50 - .collect(); 48 + let promotions: Vec<Mutex<Vec<String>>> = 49 + (0..SHARD_COUNT).map(|_| Mutex::new(Vec::new())).collect(); 51 50 Self { 52 51 hlc: Mutex::new(Hlc::new(node_id)), 53 52 shards: shards.into_boxed_slice(), ··· 136 135 137 136 let cache_delta = match cache_entries.is_empty() { 138 137 true => None, 139 - false => Some(LwwDelta { entries: cache_entries }), 138 + false => Some(LwwDelta { 139 + entries: cache_entries, 140 + }), 140 141 }; 141 142 142 143 CrdtDelta { ··· 159 160 }) 160 161 .unwrap_or_default(); 161 162 162 - let mut max_ts_per_shard: Vec<Option<HlcTimestamp>> = (0..self.shards.len()) 163 - .map(|_| None) 164 - .collect(); 163 + let mut max_ts_per_shard: Vec<Option<HlcTimestamp>> = 164 + (0..self.shards.len()).map(|_| None).collect(); 165 165 166 166 cache_entries_by_shard.iter().for_each(|&(shard_idx, ts)| { 167 167 let slot = &mut max_ts_per_shard[shard_idx]; ··· 177 177 .map(|d| (d.key.as_str(), &d.counter)) 178 178 .collect(); 179 179 180 - let mut shard_rl_keys: Vec<Vec<&str>> = (0..self.shards.len()) 181 - .map(|_| Vec::new()) 182 - .collect(); 180 + let mut shard_rl_keys: Vec<Vec<&str>> = 181 + (0..self.shards.len()).map(|_| Vec::new()).collect(); 183 182 rl_index.keys().for_each(|&key| { 184 183 shard_rl_keys[self.shard_for(key)].push(key); 185 184 }); 186 185 187 - 
self.shards.iter().enumerate().for_each(|(idx, shard_lock)| { 188 - let has_cache_update = max_ts_per_shard[idx].is_some(); 189 - let has_rl_keys = !shard_rl_keys[idx].is_empty(); 190 - if !has_cache_update && !has_rl_keys { 191 - return; 192 - } 193 - let mut shard = shard_lock.write(); 194 - if let Some(max_ts) = max_ts_per_shard[idx] { 195 - shard.last_broadcast_ts = max_ts; 196 - } 197 - shard_rl_keys[idx].iter().for_each(|&key| { 198 - let still_matches = shard 199 - .rate_limits 200 - .peek_dirty_counter(key) 201 - .zip(rl_index.get(key)) 202 - .is_some_and(|(current, committed)| { 203 - current.window_start_ms == committed.window_start_ms 204 - && current.total() == committed.total() 205 - }); 206 - if still_matches { 207 - shard.rate_limits.clear_single_dirty(key); 186 + self.shards 187 + .iter() 188 + .enumerate() 189 + .for_each(|(idx, shard_lock)| { 190 + let has_cache_update = max_ts_per_shard[idx].is_some(); 191 + let has_rl_keys = !shard_rl_keys[idx].is_empty(); 192 + if !has_cache_update && !has_rl_keys { 193 + return; 208 194 } 195 + let mut shard = shard_lock.write(); 196 + if let Some(max_ts) = max_ts_per_shard[idx] { 197 + shard.last_broadcast_ts = max_ts; 198 + } 199 + shard_rl_keys[idx].iter().for_each(|&key| { 200 + let still_matches = shard 201 + .rate_limits 202 + .peek_dirty_counter(key) 203 + .zip(rl_index.get(key)) 204 + .is_some_and(|(current, committed)| { 205 + current.window_start_ms == committed.window_start_ms 206 + && current.total() == committed.total() 207 + }); 208 + if still_matches { 209 + shard.rate_limits.clear_single_dirty(key); 210 + } 211 + }); 209 212 }); 210 - }); 211 213 } 212 214 213 215 pub fn merge_delta(&self, delta: &CrdtDelta) -> bool { ··· 219 221 return false; 220 222 } 221 223 222 - if let Some(ref cache_delta) = delta.cache_delta { 223 - if let Some(max_ts) = cache_delta.entries.iter().map(|(_, e)| e.timestamp).max() { 224 - let _ = self.hlc.lock().receive(max_ts); 225 - } 224 + if let Some(ref cache_delta) = 
delta.cache_delta 225 + && let Some(max_ts) = cache_delta.entries.iter().map(|(_, e)| e.timestamp).max() 226 + { 227 + let _ = self.hlc.lock().receive(max_ts); 226 228 } 227 229 228 230 let mut changed = false; ··· 235 237 entries_by_shard[self.shard_for(key)].push((key.clone(), entry.clone())); 236 238 }); 237 239 238 - entries_by_shard.into_iter().enumerate().for_each(|(idx, entries)| { 239 - if entries.is_empty() { 240 - return; 241 - } 242 - let mut shard = self.shards[idx].write(); 243 - entries.into_iter().for_each(|(key, entry)| { 244 - if shard.cache.merge_entry(key, entry) { 245 - changed = true; 240 + entries_by_shard 241 + .into_iter() 242 + .enumerate() 243 + .for_each(|(idx, entries)| { 244 + if entries.is_empty() { 245 + return; 246 246 } 247 + let mut shard = self.shards[idx].write(); 248 + entries.into_iter().for_each(|(key, entry)| { 249 + if shard.cache.merge_entry(key, entry) { 250 + changed = true; 251 + } 252 + }); 247 253 }); 248 - }); 249 254 } 250 255 251 256 if !delta.rate_limit_deltas.is_empty() { ··· 256 261 rl_by_shard[self.shard_for(&rd.key)].push((rd.key.clone(), &rd.counter)); 257 262 }); 258 263 259 - rl_by_shard.into_iter().enumerate().for_each(|(idx, entries)| { 260 - if entries.is_empty() { 261 - return; 262 - } 263 - let mut shard = self.shards[idx].write(); 264 - entries.into_iter().for_each(|(key, counter)| { 265 - if shard.rate_limits.merge_counter(key, counter) { 266 - changed = true; 264 + rl_by_shard 265 + .into_iter() 266 + .enumerate() 267 + .for_each(|(idx, entries)| { 268 + if entries.is_empty() { 269 + return; 267 270 } 271 + let mut shard = self.shards[idx].write(); 272 + entries.into_iter().for_each(|(key, counter)| { 273 + if shard.rate_limits.merge_counter(key, counter) { 274 + changed = true; 275 + } 276 + }); 268 277 }); 269 - }); 270 278 } 271 279 272 280 changed ··· 274 282 275 283 pub fn run_maintenance(&self) { 276 284 let now = Self::wall_ms_now(); 277 - self.shards.iter().enumerate().for_each(|(idx, 
shard_lock)| { 278 - let pending: Vec<String> = self.promotions[idx].lock().drain(..).collect(); 279 - let mut shard = shard_lock.write(); 280 - pending.iter().for_each(|key| shard.cache.touch(key)); 281 - shard.cache.gc_tombstones(now); 282 - shard.cache.gc_expired(now); 283 - shard.rate_limits.gc_expired(now); 284 - }); 285 + self.shards 286 + .iter() 287 + .enumerate() 288 + .for_each(|(idx, shard_lock)| { 289 + let pending: Vec<String> = self.promotions[idx].lock().drain(..).collect(); 290 + let mut shard = shard_lock.write(); 291 + pending.iter().for_each(|key| shard.cache.touch(key)); 292 + shard.cache.gc_tombstones(now); 293 + shard.cache.gc_expired(now); 294 + shard.rate_limits.gc_expired(now); 295 + }); 285 296 } 286 297 287 298 pub fn peek_full_state(&self) -> CrdtDelta { ··· 297 308 298 309 let cache_delta = match cache_entries.is_empty() { 299 310 true => None, 300 - false => Some(LwwDelta { entries: cache_entries }), 311 + false => Some(LwwDelta { 312 + entries: cache_entries, 313 + }), 301 314 }; 302 315 303 316 CrdtDelta { ··· 327 340 .iter() 328 341 .map(|s| { 329 342 let shard = s.read(); 330 - shard.cache.estimated_bytes().saturating_add(shard.rate_limits.estimated_bytes()) 343 + shard 344 + .cache 345 + .estimated_bytes() 346 + .saturating_add(shard.rate_limits.estimated_bytes()) 331 347 }) 332 348 .fold(0usize, usize::saturating_add) 333 349 } ··· 335 351 pub fn evict_lru_round_robin(&self, start_shard: usize) -> Option<(usize, usize)> { 336 352 (0..self.shards.len()).find_map(|offset| { 337 353 let idx = (start_shard + offset) & self.shard_mask; 338 - let has_entries = self.shards[idx].read().cache.len() > 0; 354 + let has_entries = !self.shards[idx].read().cache.is_empty(); 339 355 match has_entries { 340 356 true => { 341 357 let mut shard = self.shards[idx].write();
+7 -6
crates/tranquil-ripple/src/engine.rs
··· 17 17 pub async fn start( 18 18 config: RippleConfig, 19 19 shutdown: CancellationToken, 20 - ) -> Result<(Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>, SocketAddr), RippleStartError> { 20 + ) -> Result<(Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>, SocketAddr), RippleStartError> 21 + { 21 22 let store = Arc::new(ShardedCrdtStore::new(config.machine_id)); 22 23 23 - let (transport, incoming_rx) = Transport::bind(config.bind_addr, config.machine_id, shutdown.clone()) 24 - .await 25 - .map_err(|e| RippleStartError::Bind(e.to_string()))?; 24 + let (transport, incoming_rx) = 25 + Transport::bind(config.bind_addr, config.machine_id, shutdown.clone()) 26 + .await 27 + .map_err(|e| RippleStartError::Bind(e.to_string()))?; 26 28 27 29 let transport = Arc::new(transport); 28 30 ··· 79 81 }); 80 82 81 83 let cache: Arc<dyn Cache> = Arc::new(RippleCache::new(store.clone())); 82 - let rate_limiter: Arc<dyn DistributedRateLimiter> = 83 - Arc::new(RippleRateLimiter::new(store)); 84 + let rate_limiter: Arc<dyn DistributedRateLimiter> = Arc::new(RippleRateLimiter::new(store)); 84 85 85 86 metrics::describe_metrics(); 86 87
+15 -21
crates/tranquil-ripple/src/eviction.rs
··· 39 39 let mut remaining = total_bytes; 40 40 let mut next_shard: usize = self.next_shard.load(std::sync::atomic::Ordering::Relaxed); 41 41 let mut evicted: usize = 0; 42 - (0..batch_size).try_for_each(|_| { 43 - match remaining > max_bytes { 44 - true => { 45 - match store.evict_lru_round_robin(next_shard) { 46 - Some((ns, freed)) => { 47 - next_shard = ns; 48 - remaining = remaining.saturating_sub(freed); 49 - evicted += 1; 50 - Ok(()) 51 - } 52 - None => Err(()), 42 + (0..batch_size) 43 + .try_for_each(|_| match remaining > max_bytes { 44 + true => match store.evict_lru_round_robin(next_shard) { 45 + Some((ns, freed)) => { 46 + next_shard = ns; 47 + remaining = remaining.saturating_sub(freed); 48 + evicted += 1; 49 + Ok(()) 53 50 } 54 - } 51 + None => Err(()), 52 + }, 55 53 false => Err(()), 56 - } 57 - }) 58 - .ok(); 59 - self.next_shard.store(next_shard, std::sync::atomic::Ordering::Relaxed); 54 + }) 55 + .ok(); 56 + self.next_shard 57 + .store(next_shard, std::sync::atomic::Ordering::Relaxed); 60 58 if evicted > 0 { 61 59 metrics::record_evictions(evicted); 62 60 let cache_bytes_after = store.cache_estimated_bytes(); ··· 92 90 let store = ShardedCrdtStore::new(1); 93 91 let budget = MemoryBudget::new(100); 94 92 (0..50).for_each(|i| { 95 - store.cache_set( 96 - format!("key-{i}"), 97 - vec![0u8; 64], 98 - 60_000, 99 - ); 93 + store.cache_set(format!("key-{i}"), vec![0u8; 64], 60_000); 100 94 }); 101 95 budget.enforce(&store); 102 96 assert!(store.total_estimated_bytes() <= 100);
+29 -23
crates/tranquil-ripple/src/gossip.rs
··· 1 + use crate::crdt::ShardedCrdtStore; 1 2 use crate::crdt::delta::CrdtDelta; 2 3 use crate::crdt::lww_map::LwwDelta; 3 - use crate::crdt::ShardedCrdtStore; 4 4 use crate::metrics; 5 5 use crate::transport::{ChannelTag, IncomingFrame, Transport}; 6 6 use foca::{Config, Foca, Notification, Runtime, Timer}; 7 - use rand::rngs::StdRng; 8 7 use rand::SeedableRng; 8 + use rand::rngs::StdRng; 9 9 use std::collections::HashSet; 10 10 use std::fmt; 11 11 use std::net::SocketAddr; ··· 135 135 } 136 136 137 137 fn send_to(&mut self, to: PeerId, data: &[u8]) { 138 - self.actions 139 - .push(RuntimeAction::SendTo(to, data.to_vec())); 138 + self.actions.push(RuntimeAction::SendTo(to, data.to_vec())); 140 139 } 141 140 142 141 fn submit_after(&mut self, event: Timer<PeerId>, after: Duration) { 143 - self.actions.push(RuntimeAction::ScheduleTimer(event, after)); 142 + self.actions 143 + .push(RuntimeAction::ScheduleTimer(event, after)); 144 144 } 145 145 } 146 146 ··· 151 151 } 152 152 153 153 impl GossipEngine { 154 - pub fn new( 155 - transport: Arc<Transport>, 156 - store: Arc<ShardedCrdtStore>, 157 - local_id: PeerId, 158 - ) -> Self { 154 + pub fn new(transport: Arc<Transport>, store: Arc<ShardedCrdtStore>, local_id: PeerId) -> Self { 159 155 Self { 160 156 transport, 161 157 store, ··· 208 204 } 209 205 }); 210 206 211 - drain_runtime_actions(&mut runtime, &transport, &timer_tx, &mut members, &store, &shutdown); 207 + drain_runtime_actions( 208 + &mut runtime, 209 + &transport, 210 + &timer_tx, 211 + &mut members, 212 + &store, 213 + &shutdown, 214 + ); 212 215 213 - let mut gossip_tick = 214 - tokio::time::interval(Duration::from_millis(gossip_interval_ms)); 216 + let mut gossip_tick = tokio::time::interval(Duration::from_millis(gossip_interval_ms)); 215 217 gossip_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); 216 218 217 219 loop { ··· 401 403 metrics::set_gossip_peers(members.peer_count()); 402 404 let snapshot = store.peek_full_state(); 403 
405 if !snapshot.is_empty() { 404 - chunk_and_serialize(&snapshot).into_iter().for_each(|chunk| { 405 - let t = transport.clone(); 406 - let c = shutdown.clone(); 407 - tokio::spawn(async move { 408 - tokio::select! { 409 - _ = c.cancelled() => {} 410 - _ = t.send(addr, ChannelTag::CrdtSync, &chunk) => {} 411 - } 406 + chunk_and_serialize(&snapshot) 407 + .into_iter() 408 + .for_each(|chunk| { 409 + let t = transport.clone(); 410 + let c = shutdown.clone(); 411 + tokio::spawn(async move { 412 + tokio::select! { 413 + _ = c.cancelled() => {} 414 + _ = t.send(addr, ChannelTag::CrdtSync, &chunk) => {} 415 + } 416 + }); 412 417 }); 413 - }); 414 418 } 415 419 } 416 420 RuntimeAction::MemberDown(addr) => { ··· 449 453 source_node, 450 454 cache_delta: match cache_entries.is_empty() { 451 455 true => None, 452 - false => Some(LwwDelta { entries: cache_entries }), 456 + false => Some(LwwDelta { 457 + entries: cache_entries, 458 + }), 453 459 }, 454 460 rate_limit_deltas: rl_deltas, 455 461 };
+1 -4
crates/tranquil-ripple/src/metrics.rs
··· 13 13 "tranquil_ripple_gossip_peers", 14 14 "Number of active gossip peers" 15 15 ); 16 - metrics::describe_counter!( 17 - "tranquil_ripple_cache_hits_total", 18 - "Total cache read hits" 19 - ); 16 + metrics::describe_counter!("tranquil_ripple_cache_hits_total", "Total cache read hits"); 20 17 metrics::describe_counter!( 21 18 "tranquil_ripple_cache_misses_total", 22 19 "Total cache read misses"
+9 -2
crates/tranquil-ripple/src/transport.rs
··· 222 222 let conn_gen = self.conn_generation.fetch_add(1, Ordering::Relaxed); 223 223 self.connections.lock().insert( 224 224 target, 225 - ConnectionWriter { tx: write_tx.clone(), generation: conn_gen }, 225 + ConnectionWriter { 226 + tx: write_tx.clone(), 227 + generation: conn_gen, 228 + }, 226 229 ); 227 230 if let Some(frame) = encode_frame(tag, data) { 228 231 let _ = write_tx.try_send(frame); ··· 412 415 } 413 416 let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize; 414 417 if len > MAX_FRAME_SIZE { 415 - tracing::warn!(frame_len = len, max = MAX_FRAME_SIZE, "oversized frame, closing connection"); 418 + tracing::warn!( 419 + frame_len = len, 420 + max = MAX_FRAME_SIZE, 421 + "oversized frame, closing connection" 422 + ); 416 423 buf.clear(); 417 424 return DecodeResult::Corrupt; 418 425 }
+52 -44
crates/tranquil-ripple/tests/two_node_convergence.rs
··· 168 168 let val_a = cache_a.get(&key).await.expect("A should have the key"); 169 169 let val_b = cache_b.get(&key).await.expect("B should have the key"); 170 170 171 - assert_eq!(val_a, val_b, "both nodes must agree on the same value after LWW resolution"); 171 + assert_eq!( 172 + val_a, val_b, 173 + "both nodes must agree on the same value after LWW resolution" 174 + ); 172 175 173 176 shutdown.cancel(); 174 177 } ··· 240 243 241 244 tokio::time::sleep(Duration::from_secs(3)).await; 242 245 243 - assert!(cache_a.get(&key).await.is_none(), "A should have expired the key"); 244 - assert!(cache_b.get(&key).await.is_none(), "B should have expired the key"); 246 + assert!( 247 + cache_a.get(&key).await.is_none(), 248 + "A should have expired the key" 249 + ); 250 + assert!( 251 + cache_b.get(&key).await.is_none(), 252 + "B should have expired the key" 253 + ); 245 254 246 255 shutdown.cancel(); 247 256 } ··· 715 724 let shutdown = CancellationToken::new(); 716 725 let ((cache_a, rl_a), (cache_b, rl_b)) = spawn_pair(shutdown.clone()).await; 717 726 718 - let tasks: Vec<tokio::task::JoinHandle<()>> = (0u32..8).map(|task_id| { 719 - let cache = match task_id < 4 { 720 - true => cache_a.clone(), 721 - false => cache_b.clone(), 722 - }; 723 - let rl = match task_id < 4 { 724 - true => rl_a.clone(), 725 - false => rl_b.clone(), 726 - }; 727 - tokio::spawn(async move { 728 - let value = vec![0xABu8; 1024]; 729 - futures::future::join_all((0u32..500).map(|op| { 730 - let cache = cache.clone(); 731 - let rl = rl.clone(); 732 - let value = value.clone(); 733 - async move { 734 - let key_idx = op % 100; 735 - let key = format!("stress-{task_id}-{key_idx}"); 736 - match op % 4 { 737 - 0 | 1 => { 738 - cache 739 - .set_bytes(&key, &value, Duration::from_secs(120)) 740 - .await 741 - .expect("set_bytes failed"); 742 - } 743 - 2 => { 744 - let _ = cache.get(&key).await; 745 - } 746 - _ => { 747 - let _ = rl.check_rate_limit(&key, 1000, 60_000).await; 727 + let tasks: 
Vec<tokio::task::JoinHandle<()>> = (0u32..8) 728 + .map(|task_id| { 729 + let cache = match task_id < 4 { 730 + true => cache_a.clone(), 731 + false => cache_b.clone(), 732 + }; 733 + let rl = match task_id < 4 { 734 + true => rl_a.clone(), 735 + false => rl_b.clone(), 736 + }; 737 + tokio::spawn(async move { 738 + let value = vec![0xABu8; 1024]; 739 + futures::future::join_all((0u32..500).map(|op| { 740 + let cache = cache.clone(); 741 + let rl = rl.clone(); 742 + let value = value.clone(); 743 + async move { 744 + let key_idx = op % 100; 745 + let key = format!("stress-{task_id}-{key_idx}"); 746 + match op % 4 { 747 + 0 | 1 => { 748 + cache 749 + .set_bytes(&key, &value, Duration::from_secs(120)) 750 + .await 751 + .expect("set_bytes failed"); 752 + } 753 + 2 => { 754 + let _ = cache.get(&key).await; 755 + } 756 + _ => { 757 + let _ = rl.check_rate_limit(&key, 1000, 60_000).await; 758 + } 748 759 } 749 760 } 750 - } 751 - })) 752 - .await; 761 + })) 762 + .await; 763 + }) 753 764 }) 754 - }).collect(); 755 - 756 - let results = tokio::time::timeout( 757 - Duration::from_secs(30), 758 - futures::future::join_all(tasks), 759 - ) 760 - .await 761 - .expect("stress test timed out after 30s"); 765 + .collect(); 766 + 767 + let results = tokio::time::timeout(Duration::from_secs(30), futures::future::join_all(tasks)) 768 + .await 769 + .expect("stress test timed out after 30s"); 762 770 763 771 results.into_iter().enumerate().for_each(|(i, r)| { 764 772 r.unwrap_or_else(|e| panic!("task {i} panicked: {e}"));
+6 -2
crates/tranquil-storage/Cargo.toml
··· 4 4 edition.workspace = true 5 5 license.workspace = true 6 6 7 + [features] 8 + default = [] 9 + s3 = ["dep:aws-config", "dep:aws-sdk-s3"] 10 + 7 11 [dependencies] 8 12 tranquil-infra = { workspace = true } 9 13 10 14 async-trait = { workspace = true } 11 - aws-config = { workspace = true } 12 - aws-sdk-s3 = { workspace = true } 15 + aws-config = { workspace = true, optional = true } 16 + aws-sdk-s3 = { workspace = true, optional = true } 13 17 bytes = { workspace = true } 14 18 futures = { workspace = true } 15 19 sha2 = { workspace = true }
+346 -311
crates/tranquil-storage/src/lib.rs
··· 4 4 }; 5 5 6 6 use async_trait::async_trait; 7 - use aws_config::BehaviorVersion; 8 - use aws_config::meta::region::RegionProviderChain; 9 - use aws_sdk_s3::Client; 10 - use aws_sdk_s3::primitives::ByteStream; 11 - use aws_sdk_s3::types::CompletedMultipartUpload; 12 - use aws_sdk_s3::types::CompletedPart; 13 7 use bytes::Bytes; 14 8 use futures::Stream; 15 9 use sha2::{Digest, Sha256}; ··· 17 11 use std::pin::Pin; 18 12 use std::sync::Arc; 19 13 20 - const MIN_PART_SIZE: usize = 5 * 1024 * 1024; 21 14 const EXDEV: i32 = 18; 22 15 const CID_SHARD_PREFIX_LEN: usize = 9; 23 16 ··· 109 102 } 110 103 } 111 104 112 - pub struct S3BlobStorage { 113 - client: Client, 114 - bucket: String, 115 - } 105 + #[cfg(feature = "s3")] 106 + mod s3 { 107 + use super::*; 108 + use aws_config::BehaviorVersion; 109 + use aws_config::meta::region::RegionProviderChain; 110 + use aws_sdk_s3::Client; 111 + use aws_sdk_s3::primitives::ByteStream; 112 + use aws_sdk_s3::types::CompletedMultipartUpload; 113 + use aws_sdk_s3::types::CompletedPart; 116 114 117 - impl S3BlobStorage { 118 - pub async fn new() -> Self { 119 - let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set"); 120 - let client = create_s3_client().await; 121 - Self { client, bucket } 122 - } 115 + const MIN_PART_SIZE: usize = 5 * 1024 * 1024; 123 116 124 - pub async fn with_bucket(bucket: String) -> Self { 125 - let client = create_s3_client().await; 126 - Self { client, bucket } 117 + pub struct S3BlobStorage { 118 + client: Client, 119 + bucket: String, 127 120 } 128 - } 129 121 130 - async fn create_s3_client() -> Client { 131 - let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); 122 + impl S3BlobStorage { 123 + pub async fn new() -> Self { 124 + let bucket = std::env::var("S3_BUCKET").expect("S3_BUCKET must be set"); 125 + let client = create_s3_client().await; 126 + Self { client, bucket } 127 + } 132 128 133 - let config = aws_config::defaults(BehaviorVersion::latest()) 
134 - .region(region_provider) 135 - .load() 136 - .await; 129 + pub async fn with_bucket(bucket: String) -> Self { 130 + let client = create_s3_client().await; 131 + Self { client, bucket } 132 + } 133 + } 137 134 138 - std::env::var("S3_ENDPOINT").ok().map_or_else( 139 - || Client::new(&config), 140 - |endpoint| { 141 - let s3_config = aws_sdk_s3::config::Builder::from(&config) 142 - .endpoint_url(endpoint) 143 - .force_path_style(true) 144 - .build(); 145 - Client::from_conf(s3_config) 146 - }, 147 - ) 148 - } 135 + async fn create_s3_client() -> Client { 136 + let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); 149 137 150 - pub struct S3BackupStorage { 151 - client: Client, 152 - bucket: String, 153 - } 138 + let config = aws_config::defaults(BehaviorVersion::latest()) 139 + .region(region_provider) 140 + .load() 141 + .await; 154 142 155 - impl S3BackupStorage { 156 - pub async fn new() -> Option<Self> { 157 - let bucket = std::env::var("BACKUP_S3_BUCKET").ok()?; 158 - let client = create_s3_client().await; 159 - Some(Self { client, bucket }) 143 + std::env::var("S3_ENDPOINT").ok().map_or_else( 144 + || Client::new(&config), 145 + |endpoint| { 146 + let s3_config = aws_sdk_s3::config::Builder::from(&config) 147 + .endpoint_url(endpoint) 148 + .force_path_style(true) 149 + .build(); 150 + Client::from_conf(s3_config) 151 + }, 152 + ) 160 153 } 161 - } 162 154 163 - #[async_trait] 164 - impl BackupStorage for S3BackupStorage { 165 - async fn put_backup(&self, did: &str, rev: &str, data: &[u8]) -> Result<String, StorageError> { 166 - let key = format!("{}/{}.car", did, rev); 167 - self.client 168 - .put_object() 169 - .bucket(&self.bucket) 170 - .key(&key) 171 - .body(ByteStream::from(Bytes::copy_from_slice(data))) 172 - .send() 173 - .await 174 - .map_err(|e| StorageError::Backend(e.to_string()))?; 175 - 176 - Ok(key) 155 + pub struct S3BackupStorage { 156 + client: Client, 157 + bucket: String, 177 158 } 178 159 179 - async fn 
get_backup(&self, storage_key: &str) -> Result<Bytes, StorageError> { 180 - let resp = self 181 - .client 182 - .get_object() 183 - .bucket(&self.bucket) 184 - .key(storage_key) 185 - .send() 186 - .await 187 - .map_err(|e| StorageError::Backend(e.to_string()))?; 188 - 189 - resp.body 190 - .collect() 191 - .await 192 - .map(|agg| agg.into_bytes()) 193 - .map_err(|e| StorageError::Backend(e.to_string())) 160 + impl S3BackupStorage { 161 + pub async fn new() -> Option<Self> { 162 + let bucket = std::env::var("BACKUP_S3_BUCKET").ok()?; 163 + let client = create_s3_client().await; 164 + Some(Self { client, bucket }) 165 + } 194 166 } 195 167 196 - async fn delete_backup(&self, storage_key: &str) -> Result<(), StorageError> { 197 - self.client 198 - .delete_object() 199 - .bucket(&self.bucket) 200 - .key(storage_key) 201 - .send() 202 - .await 203 - .map_err(|e| StorageError::Backend(e.to_string()))?; 168 + #[async_trait] 169 + impl BackupStorage for S3BackupStorage { 170 + async fn put_backup( 171 + &self, 172 + did: &str, 173 + rev: &str, 174 + data: &[u8], 175 + ) -> Result<String, StorageError> { 176 + let key = format!("{}/{}.car", did, rev); 177 + self.client 178 + .put_object() 179 + .bucket(&self.bucket) 180 + .key(&key) 181 + .body(ByteStream::from(Bytes::copy_from_slice(data))) 182 + .send() 183 + .await 184 + .map_err(|e| StorageError::Backend(e.to_string()))?; 204 185 205 - Ok(()) 206 - } 207 - } 186 + Ok(key) 187 + } 208 188 209 - #[async_trait] 210 - impl BlobStorage for S3BlobStorage { 211 - async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { 212 - self.put_bytes(key, Bytes::copy_from_slice(data)).await 213 - } 189 + async fn get_backup(&self, storage_key: &str) -> Result<Bytes, StorageError> { 190 + let resp = self 191 + .client 192 + .get_object() 193 + .bucket(&self.bucket) 194 + .key(storage_key) 195 + .send() 196 + .await 197 + .map_err(|e| StorageError::Backend(e.to_string()))?; 214 198 215 - async fn put_bytes(&self, key: 
&str, data: Bytes) -> Result<(), StorageError> { 216 - self.client 217 - .put_object() 218 - .bucket(&self.bucket) 219 - .key(key) 220 - .body(ByteStream::from(data)) 221 - .send() 222 - .await 223 - .map_err(|e| StorageError::Backend(e.to_string()))?; 199 + resp.body 200 + .collect() 201 + .await 202 + .map(|agg| agg.into_bytes()) 203 + .map_err(|e| StorageError::Backend(e.to_string())) 204 + } 224 205 225 - Ok(()) 226 - } 206 + async fn delete_backup(&self, storage_key: &str) -> Result<(), StorageError> { 207 + self.client 208 + .delete_object() 209 + .bucket(&self.bucket) 210 + .key(storage_key) 211 + .send() 212 + .await 213 + .map_err(|e| StorageError::Backend(e.to_string()))?; 227 214 228 - async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> { 229 - self.get_bytes(key).await.map(|b| b.to_vec()) 215 + Ok(()) 216 + } 230 217 } 231 218 232 - async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> { 233 - let resp = self 234 - .client 235 - .get_object() 236 - .bucket(&self.bucket) 237 - .key(key) 238 - .send() 239 - .await 240 - .map_err(|e| StorageError::Backend(e.to_string()))?; 219 + #[async_trait] 220 + impl BlobStorage for S3BlobStorage { 221 + async fn put(&self, key: &str, data: &[u8]) -> Result<(), StorageError> { 222 + self.put_bytes(key, Bytes::copy_from_slice(data)).await 223 + } 241 224 242 - resp.body 243 - .collect() 244 - .await 245 - .map(|agg| agg.into_bytes()) 246 - .map_err(|e| StorageError::Backend(e.to_string())) 247 - } 225 + async fn put_bytes(&self, key: &str, data: Bytes) -> Result<(), StorageError> { 226 + self.client 227 + .put_object() 228 + .bucket(&self.bucket) 229 + .key(key) 230 + .body(ByteStream::from(data)) 231 + .send() 232 + .await 233 + .map_err(|e| StorageError::Backend(e.to_string()))?; 248 234 249 - async fn get_head(&self, key: &str, size: usize) -> Result<Bytes, StorageError> { 250 - let range = format!("bytes=0-{}", size.saturating_sub(1)); 251 - let resp = self 252 - .client 253 - .get_object() 
254 - .bucket(&self.bucket) 255 - .key(key) 256 - .range(range) 257 - .send() 258 - .await 259 - .map_err(|e| StorageError::Backend(e.to_string()))?; 235 + Ok(()) 236 + } 260 237 261 - resp.body 262 - .collect() 263 - .await 264 - .map(|agg| agg.into_bytes()) 265 - .map_err(|e| StorageError::Backend(e.to_string())) 266 - } 238 + async fn get(&self, key: &str) -> Result<Vec<u8>, StorageError> { 239 + self.get_bytes(key).await.map(|b| b.to_vec()) 240 + } 267 241 268 - async fn delete(&self, key: &str) -> Result<(), StorageError> { 269 - self.client 270 - .delete_object() 271 - .bucket(&self.bucket) 272 - .key(key) 273 - .send() 274 - .await 275 - .map_err(|e| StorageError::Backend(e.to_string()))?; 242 + async fn get_bytes(&self, key: &str) -> Result<Bytes, StorageError> { 243 + let resp = self 244 + .client 245 + .get_object() 246 + .bucket(&self.bucket) 247 + .key(key) 248 + .send() 249 + .await 250 + .map_err(|e| StorageError::Backend(e.to_string()))?; 276 251 277 - Ok(()) 278 - } 252 + resp.body 253 + .collect() 254 + .await 255 + .map(|agg| agg.into_bytes()) 256 + .map_err(|e| StorageError::Backend(e.to_string())) 257 + } 279 258 280 - async fn put_stream( 281 - &self, 282 - key: &str, 283 - stream: Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>, 284 - ) -> Result<StreamUploadResult, StorageError> { 285 - use futures::StreamExt; 286 - 287 - let create_resp = self 288 - .client 289 - .create_multipart_upload() 290 - .bucket(&self.bucket) 291 - .key(key) 292 - .send() 293 - .await 294 - .map_err(|e| { 295 - StorageError::Backend(format!("Failed to create multipart upload: {}", e)) 296 - })?; 297 - 298 - let upload_id = create_resp 299 - .upload_id() 300 - .ok_or_else(|| StorageError::Backend("No upload ID returned".to_string()))? 
301 - .to_string(); 302 - 303 - let upload_part = |client: &Client, 304 - bucket: &str, 305 - key: &str, 306 - upload_id: &str, 307 - part_num: i32, 308 - data: Vec<u8>| 309 - -> std::pin::Pin< 310 - Box<dyn std::future::Future<Output = Result<CompletedPart, StorageError>> + Send>, 311 - > { 312 - let client = client.clone(); 313 - let bucket = bucket.to_string(); 314 - let key = key.to_string(); 315 - let upload_id = upload_id.to_string(); 316 - Box::pin(async move { 317 - let resp = client 318 - .upload_part() 319 - .bucket(&bucket) 320 - .key(&key) 321 - .upload_id(&upload_id) 322 - .part_number(part_num) 323 - .body(ByteStream::from(data)) 324 - .send() 325 - .await 326 - .map_err(|e| StorageError::Backend(format!("Failed to upload part: {}", e)))?; 327 - 328 - let etag = resp 329 - .e_tag() 330 - .ok_or_else(|| StorageError::Backend("No ETag returned for part".to_string()))? 331 - .to_string(); 332 - 333 - Ok(CompletedPart::builder() 334 - .part_number(part_num) 335 - .e_tag(etag) 336 - .build()) 337 - }) 338 - }; 259 + async fn get_head(&self, key: &str, size: usize) -> Result<Bytes, StorageError> { 260 + let range = format!("bytes=0-{}", size.saturating_sub(1)); 261 + let resp = self 262 + .client 263 + .get_object() 264 + .bucket(&self.bucket) 265 + .key(key) 266 + .range(range) 267 + .send() 268 + .await 269 + .map_err(|e| StorageError::Backend(e.to_string()))?; 339 270 340 - struct UploadState { 341 - hasher: Sha256, 342 - total_size: u64, 343 - part_number: i32, 344 - completed_parts: Vec<CompletedPart>, 345 - buffer: Vec<u8>, 271 + resp.body 272 + .collect() 273 + .await 274 + .map(|agg| agg.into_bytes()) 275 + .map_err(|e| StorageError::Backend(e.to_string())) 346 276 } 347 277 348 - let initial_state = UploadState { 349 - hasher: Sha256::new(), 350 - total_size: 0, 351 - part_number: 1, 352 - completed_parts: Vec::new(), 353 - buffer: Vec::with_capacity(MIN_PART_SIZE), 354 - }; 278 + async fn delete(&self, key: &str) -> Result<(), StorageError> { 279 
+ self.client 280 + .delete_object() 281 + .bucket(&self.bucket) 282 + .key(key) 283 + .send() 284 + .await 285 + .map_err(|e| StorageError::Backend(e.to_string()))?; 286 + 287 + Ok(()) 288 + } 289 + 290 + async fn put_stream( 291 + &self, 292 + key: &str, 293 + stream: Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>, 294 + ) -> Result<StreamUploadResult, StorageError> { 295 + use futures::StreamExt; 355 296 356 - let abort_upload = || async { 357 - let _ = self 297 + let create_resp = self 358 298 .client 359 - .abort_multipart_upload() 299 + .create_multipart_upload() 360 300 .bucket(&self.bucket) 361 301 .key(key) 362 - .upload_id(&upload_id) 363 302 .send() 364 - .await; 365 - }; 303 + .await 304 + .map_err(|e| { 305 + StorageError::Backend(format!("Failed to create multipart upload: {}", e)) 306 + })?; 307 + 308 + let upload_id = create_resp 309 + .upload_id() 310 + .ok_or_else(|| StorageError::Backend("No upload ID returned".to_string()))? 311 + .to_string(); 312 + 313 + let upload_part = |client: &Client, 314 + bucket: &str, 315 + key: &str, 316 + upload_id: &str, 317 + part_num: i32, 318 + data: Vec<u8>| 319 + -> std::pin::Pin< 320 + Box<dyn std::future::Future<Output = Result<CompletedPart, StorageError>> + Send>, 321 + > { 322 + let client = client.clone(); 323 + let bucket = bucket.to_string(); 324 + let key = key.to_string(); 325 + let upload_id = upload_id.to_string(); 326 + Box::pin(async move { 327 + let resp = client 328 + .upload_part() 329 + .bucket(&bucket) 330 + .key(&key) 331 + .upload_id(&upload_id) 332 + .part_number(part_num) 333 + .body(ByteStream::from(data)) 334 + .send() 335 + .await 336 + .map_err(|e| { 337 + StorageError::Backend(format!("Failed to upload part: {}", e)) 338 + })?; 339 + 340 + let etag = resp 341 + .e_tag() 342 + .ok_or_else(|| { 343 + StorageError::Backend("No ETag returned for part".to_string()) 344 + })? 
345 + .to_string(); 346 + 347 + Ok(CompletedPart::builder() 348 + .part_number(part_num) 349 + .e_tag(etag) 350 + .build()) 351 + }) 352 + }; 353 + 354 + struct UploadState { 355 + hasher: Sha256, 356 + total_size: u64, 357 + part_number: i32, 358 + completed_parts: Vec<CompletedPart>, 359 + buffer: Vec<u8>, 360 + } 366 361 367 - let result: Result<UploadState, StorageError> = { 368 - let mut state = initial_state; 369 - 370 - let chunk_results: Vec<Result<Bytes, std::io::Error>> = stream.collect().await; 371 - 372 - for chunk_result in chunk_results { 373 - match chunk_result { 374 - Ok(chunk) => { 375 - state.hasher.update(&chunk); 376 - state.total_size += chunk.len() as u64; 377 - state.buffer.extend_from_slice(&chunk); 378 - 379 - if state.buffer.len() >= MIN_PART_SIZE { 380 - let part_data = std::mem::replace( 381 - &mut state.buffer, 382 - Vec::with_capacity(MIN_PART_SIZE), 383 - ); 384 - let part = upload_part( 385 - &self.client, 386 - &self.bucket, 387 - key, 388 - &upload_id, 389 - state.part_number, 390 - part_data, 391 - ) 392 - .await?; 393 - state.completed_parts.push(part); 394 - state.part_number += 1; 362 + let initial_state = UploadState { 363 + hasher: Sha256::new(), 364 + total_size: 0, 365 + part_number: 1, 366 + completed_parts: Vec::new(), 367 + buffer: Vec::with_capacity(MIN_PART_SIZE), 368 + }; 369 + 370 + let abort_upload = || async { 371 + let _ = self 372 + .client 373 + .abort_multipart_upload() 374 + .bucket(&self.bucket) 375 + .key(key) 376 + .upload_id(&upload_id) 377 + .send() 378 + .await; 379 + }; 380 + 381 + let result: Result<UploadState, StorageError> = { 382 + let mut state = initial_state; 383 + 384 + let chunk_results: Vec<Result<Bytes, std::io::Error>> = stream.collect().await; 385 + 386 + for chunk_result in chunk_results { 387 + match chunk_result { 388 + Ok(chunk) => { 389 + state.hasher.update(&chunk); 390 + state.total_size += chunk.len() as u64; 391 + state.buffer.extend_from_slice(&chunk); 392 + 393 + if 
state.buffer.len() >= MIN_PART_SIZE { 394 + let part_data = std::mem::replace( 395 + &mut state.buffer, 396 + Vec::with_capacity(MIN_PART_SIZE), 397 + ); 398 + let part = upload_part( 399 + &self.client, 400 + &self.bucket, 401 + key, 402 + &upload_id, 403 + state.part_number, 404 + part_data, 405 + ) 406 + .await?; 407 + state.completed_parts.push(part); 408 + state.part_number += 1; 409 + } 410 + } 411 + Err(e) => { 412 + abort_upload().await; 413 + return Err(StorageError::Io(e)); 395 414 } 396 - } 397 - Err(e) => { 398 - abort_upload().await; 399 - return Err(StorageError::Io(e)); 400 415 } 401 416 } 417 + 418 + Ok(state) 419 + }; 420 + 421 + let mut state = result?; 422 + 423 + if !state.buffer.is_empty() { 424 + let part = upload_part( 425 + &self.client, 426 + &self.bucket, 427 + key, 428 + &upload_id, 429 + state.part_number, 430 + std::mem::take(&mut state.buffer), 431 + ) 432 + .await?; 433 + state.completed_parts.push(part); 402 434 } 403 435 404 - Ok(state) 405 - }; 436 + if state.completed_parts.is_empty() { 437 + abort_upload().await; 438 + return Err(StorageError::Other("Empty upload".to_string())); 439 + } 406 440 407 - let mut state = result?; 408 - 409 - if !state.buffer.is_empty() { 410 - let part = upload_part( 411 - &self.client, 412 - &self.bucket, 413 - key, 414 - &upload_id, 415 - state.part_number, 416 - std::mem::take(&mut state.buffer), 417 - ) 418 - .await?; 419 - state.completed_parts.push(part); 420 - } 441 + let completed_upload = CompletedMultipartUpload::builder() 442 + .set_parts(Some(state.completed_parts)) 443 + .build(); 421 444 422 - if state.completed_parts.is_empty() { 423 - abort_upload().await; 424 - return Err(StorageError::Other("Empty upload".to_string())); 445 + self.client 446 + .complete_multipart_upload() 447 + .bucket(&self.bucket) 448 + .key(key) 449 + .upload_id(&upload_id) 450 + .multipart_upload(completed_upload) 451 + .send() 452 + .await 453 + .map_err(|e| { 454 + StorageError::Backend(format!("Failed to 
complete multipart upload: {}", e)) 455 + })?; 456 + 457 + let hash: [u8; 32] = state.hasher.finalize().into(); 458 + Ok(StreamUploadResult { 459 + sha256_hash: hash, 460 + size: state.total_size, 461 + }) 425 462 } 426 463 427 - let completed_upload = CompletedMultipartUpload::builder() 428 - .set_parts(Some(state.completed_parts)) 429 - .build(); 430 - 431 - self.client 432 - .complete_multipart_upload() 433 - .bucket(&self.bucket) 434 - .key(key) 435 - .upload_id(&upload_id) 436 - .multipart_upload(completed_upload) 437 - .send() 438 - .await 439 - .map_err(|e| { 440 - StorageError::Backend(format!("Failed to complete multipart upload: {}", e)) 441 - })?; 464 + async fn copy(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> { 465 + let copy_source = format!("{}/{}", self.bucket, src_key); 442 466 443 - let hash: [u8; 32] = state.hasher.finalize().into(); 444 - Ok(StreamUploadResult { 445 - sha256_hash: hash, 446 - size: state.total_size, 447 - }) 448 - } 449 - 450 - async fn copy(&self, src_key: &str, dst_key: &str) -> Result<(), StorageError> { 451 - let copy_source = format!("{}/{}", self.bucket, src_key); 452 - 453 - self.client 454 - .copy_object() 455 - .bucket(&self.bucket) 456 - .copy_source(&copy_source) 457 - .key(dst_key) 458 - .send() 459 - .await 460 - .map_err(|e| StorageError::Backend(format!("Failed to copy object: {}", e)))?; 467 + self.client 468 + .copy_object() 469 + .bucket(&self.bucket) 470 + .copy_source(&copy_source) 471 + .key(dst_key) 472 + .send() 473 + .await 474 + .map_err(|e| StorageError::Backend(format!("Failed to copy object: {}", e)))?; 461 475 462 - Ok(()) 476 + Ok(()) 477 + } 463 478 } 464 479 } 465 480 481 + #[cfg(feature = "s3")] 482 + pub use s3::{S3BackupStorage, S3BlobStorage}; 483 + 466 484 pub struct FilesystemBlobStorage { 467 485 base_path: PathBuf, 468 486 tmp_path: PathBuf, ··· 686 704 let backend = std::env::var("BLOB_STORAGE_BACKEND").unwrap_or_else(|_| "filesystem".into()); 687 705 688 706 match 
backend.as_str() { 707 + #[cfg(feature = "s3")] 689 708 "s3" => { 690 709 tracing::info!("Initializing S3 blob storage"); 691 710 Arc::new(S3BlobStorage::new().await) 692 711 } 712 + #[cfg(not(feature = "s3"))] 713 + "s3" => { 714 + panic!( 715 + "BLOB_STORAGE_BACKEND=s3 but binary was compiled without s3 feature. \ 716 + Rebuild with --features s3 to enable S3 storage." 717 + ); 718 + } 693 719 _ => { 694 720 tracing::info!("Initializing filesystem blob storage"); 695 721 FilesystemBlobStorage::from_env() ··· 719 745 let backend = std::env::var("BACKUP_STORAGE_BACKEND").unwrap_or_else(|_| "filesystem".into()); 720 746 721 747 match backend.as_str() { 748 + #[cfg(feature = "s3")] 722 749 "s3" => S3BackupStorage::new().await.map_or_else( 723 750 || { 724 751 tracing::error!( ··· 732 759 Some(Arc::new(storage) as Arc<dyn BackupStorage>) 733 760 }, 734 761 ), 762 + #[cfg(not(feature = "s3"))] 763 + "s3" => { 764 + tracing::error!( 765 + "BACKUP_STORAGE_BACKEND=s3 but binary was compiled without s3 feature. \ 766 + Backups will be disabled." 767 + ); 768 + None 769 + } 735 770 _ => FilesystemBackupStorage::from_env().await.map_or_else( 736 771 |e| { 737 772 tracing::error!(

History

1 round 0 comments
sign up or login to add to the discussion
lewis.moe submitted #0
1 commit
expand
fix: smaller docker img
expand 0 comments
pull request successfully merged