this repo has no description
1use super::validation::validate_record; 2use super::write::has_verified_comms_channel; 3use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 4use crate::repo::tracking::TrackingBlockStore; 5use crate::state::AppState; 6use axum::{ 7 Json, 8 extract::State, 9 http::StatusCode, 10 response::{IntoResponse, Response}, 11}; 12use cid::Cid; 13use jacquard::types::{ 14 integer::LimitedU32, 15 string::{Nsid, Tid}, 16}; 17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 18use serde::{Deserialize, Serialize}; 19use serde_json::json; 20use std::str::FromStr; 21use std::sync::Arc; 22use tracing::error; 23 24const MAX_BATCH_WRITES: usize = 200; 25 26#[derive(Deserialize)] 27#[serde(tag = "$type")] 28pub enum WriteOp { 29 #[serde(rename = "com.atproto.repo.applyWrites#create")] 30 Create { 31 collection: String, 32 rkey: Option<String>, 33 value: serde_json::Value, 34 }, 35 #[serde(rename = "com.atproto.repo.applyWrites#update")] 36 Update { 37 collection: String, 38 rkey: String, 39 value: serde_json::Value, 40 }, 41 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 42 Delete { collection: String, rkey: String }, 43} 44 45#[derive(Deserialize)] 46#[serde(rename_all = "camelCase")] 47pub struct ApplyWritesInput { 48 pub repo: String, 49 pub validate: Option<bool>, 50 pub writes: Vec<WriteOp>, 51 pub swap_commit: Option<String>, 52} 53 54#[derive(Serialize)] 55#[serde(tag = "$type")] 56pub enum WriteResult { 57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")] 58 CreateResult { uri: String, cid: String }, 59 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 60 UpdateResult { uri: String, cid: String }, 61 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 62 DeleteResult {}, 63} 64 65#[derive(Serialize)] 66pub struct ApplyWritesOutput { 67 pub commit: CommitInfo, 68 pub results: Vec<WriteResult>, 69} 70 71#[derive(Serialize)] 72pub struct CommitInfo { 73 pub cid: String, 74 pub rev: String, 75} 76 77pub async fn apply_writes( 78 State(state): State<AppState>, 79 headers: axum::http::HeaderMap, 80 Json(input): Json<ApplyWritesInput>, 81) -> Response { 82 let token = match crate::auth::extract_bearer_token_from_header( 83 headers.get("Authorization").and_then(|h| h.to_str().ok()), 84 ) { 85 Some(t) => t, 86 None => { 87 return ( 88 StatusCode::UNAUTHORIZED, 89 Json(json!({"error": "AuthenticationRequired"})), 90 ) 91 .into_response(); 92 } 93 }; 94 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 95 Ok(user) => user, 96 Err(_) => { 97 return ( 98 StatusCode::UNAUTHORIZED, 99 Json(json!({"error": "AuthenticationFailed"})), 100 ) 101 .into_response(); 102 } 103 }; 104 let did = auth_user.did; 105 if input.repo != did { 106 return ( 107 StatusCode::FORBIDDEN, 108 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})), 109 ) 110 .into_response(); 111 } 112 match has_verified_comms_channel(&state.db, &did).await { 113 Ok(true) => {} 114 Ok(false) => { 115 return ( 116 StatusCode::FORBIDDEN, 117 Json(json!({ 118 "error": "AccountNotVerified", 119 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records" 120 })), 121 ) 122 .into_response(); 123 } 124 Err(e) => { 125 error!("DB error checking notification channels: {}", e); 126 return ( 127 StatusCode::INTERNAL_SERVER_ERROR, 128 Json(json!({"error": "InternalError"})), 129 ) 130 .into_response(); 131 } 132 } 133 if input.writes.is_empty() { 134 return ( 135 StatusCode::BAD_REQUEST, 136 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})), 137 ) 138 .into_response(); 139 } 140 if input.writes.len() > MAX_BATCH_WRITES { 141 return ( 142 StatusCode::BAD_REQUEST, 143 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})), 144 ) 145 .into_response(); 146 } 147 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 148 .fetch_optional(&state.db) 149 .await 150 { 151 Ok(Some(id)) => id, 152 _ => { 153 return ( 154 StatusCode::INTERNAL_SERVER_ERROR, 155 Json(json!({"error": "InternalError", "message": "User not found"})), 156 ) 157 .into_response(); 158 } 159 }; 160 let root_cid_str: String = match sqlx::query_scalar!( 161 "SELECT repo_root_cid FROM repos WHERE user_id = $1", 162 user_id 163 ) 164 .fetch_optional(&state.db) 165 .await 166 { 167 Ok(Some(cid_str)) => cid_str, 168 _ => { 169 return ( 170 StatusCode::INTERNAL_SERVER_ERROR, 171 Json(json!({"error": "InternalError", "message": "Repo root not found"})), 172 ) 173 .into_response(); 174 } 175 }; 176 let current_root_cid = match Cid::from_str(&root_cid_str) { 177 Ok(c) => c, 178 Err(_) => { 179 return ( 180 StatusCode::INTERNAL_SERVER_ERROR, 181 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})), 182 ) 183 .into_response(); 184 } 185 }; 186 if let Some(swap_commit) = &input.swap_commit 187 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 188 return ( 189 StatusCode::CONFLICT, 190 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 191 ) 192 .into_response(); 193 } 194 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 195 let commit_bytes = match tracking_store.get(&current_root_cid).await { 196 Ok(Some(b)) => b, 197 _ => { 198 return ( 199 StatusCode::INTERNAL_SERVER_ERROR, 200 Json(json!({"error": "InternalError", "message": "Commit block not found"})), 201 ) 202 .into_response(); 203 } 204 }; 205 let commit = match Commit::from_cbor(&commit_bytes) { 206 Ok(c) => c, 207 _ => { 208 return ( 209 StatusCode::INTERNAL_SERVER_ERROR, 210 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 211 ) 212 .into_response(); 213 } 214 }; 215 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 216 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 217 let mut results: Vec<WriteResult> = Vec::new(); 218 let mut ops: Vec<RecordOp> = Vec::new(); 219 let mut modified_keys: Vec<String> = Vec::new(); 220 for write in &input.writes { 221 match write { 222 WriteOp::Create { 223 collection, 224 rkey, 225 value, 226 } => { 227 if input.validate.unwrap_or(true) 228 && let Err(err_response) = validate_record(value, collection) { 229 return *err_response; 230 } 231 let rkey = rkey 232 .clone() 233 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); 234 let mut record_bytes = Vec::new(); 235 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 236 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 237 } 238 let record_cid = match tracking_store.put(&record_bytes).await { 239 Ok(c) => c, 240 Err(_) => return ( 241 StatusCode::INTERNAL_SERVER_ERROR, 242 Json( 243 json!({"error": "InternalError", "message": "Failed to store record"}), 244 ), 245 ) 246 .into_response(), 247 }; 248 let collection_nsid = match collection.parse::<Nsid>() { 249 Ok(n) => n, 250 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 251 }; 252 let key = format!("{}/{}", collection_nsid, rkey); 253 modified_keys.push(key.clone()); 254 mst = match mst.add(&key, record_cid).await { 255 Ok(m) => m, 256 Err(_) => return ( 257 StatusCode::INTERNAL_SERVER_ERROR, 258 Json(json!({"error": "InternalError", "message": "Failed to add to MST"})), 259 ) 260 .into_response(), 261 }; 262 let uri = format!("at://{}/{}/{}", did, collection, rkey); 263 results.push(WriteResult::CreateResult { 264 uri, 265 cid: record_cid.to_string(), 266 }); 267 ops.push(RecordOp::Create { 268 collection: collection.clone(), 269 rkey, 270 cid: record_cid, 271 }); 272 } 273 WriteOp::Update { 274 collection, 275 rkey, 276 value, 277 } => { 278 if input.validate.unwrap_or(true) 279 && let Err(err_response) = validate_record(value, collection) { 280 return *err_response; 281 } 282 let mut record_bytes = Vec::new(); 283 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 284 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 285 } 286 let record_cid = match tracking_store.put(&record_bytes).await { 287 Ok(c) => c, 288 Err(_) => return ( 289 StatusCode::INTERNAL_SERVER_ERROR, 290 Json( 291 json!({"error": "InternalError", "message": "Failed to store record"}), 292 ), 293 ) 294 .into_response(), 295 }; 296 let collection_nsid = match collection.parse::<Nsid>() { 297 Ok(n) => n, 298 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 299 }; 300 let key = format!("{}/{}", collection_nsid, rkey); 301 modified_keys.push(key.clone()); 302 let prev_record_cid = mst.get(&key).await.ok().flatten(); 303 mst = match mst.update(&key, record_cid).await { 304 Ok(m) => m, 305 Err(_) => return ( 306 StatusCode::INTERNAL_SERVER_ERROR, 307 Json(json!({"error": "InternalError", "message": "Failed to update MST"})), 308 ) 309 .into_response(), 310 }; 311 let uri = format!("at://{}/{}/{}", did, collection, rkey); 312 results.push(WriteResult::UpdateResult { 313 uri, 314 cid: record_cid.to_string(), 315 }); 316 ops.push(RecordOp::Update { 317 collection: collection.clone(), 318 rkey: rkey.clone(), 319 cid: record_cid, 320 prev: prev_record_cid, 321 }); 322 } 323 WriteOp::Delete { collection, rkey } => { 324 let collection_nsid = match collection.parse::<Nsid>() { 325 Ok(n) => n, 326 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 327 }; 328 let key = format!("{}/{}", collection_nsid, rkey); 329 modified_keys.push(key.clone()); 330 let prev_record_cid = mst.get(&key).await.ok().flatten(); 331 mst = match mst.delete(&key).await { 332 Ok(m) => m, 333 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(), 334 }; 335 results.push(WriteResult::DeleteResult {}); 336 ops.push(RecordOp::Delete { 337 collection: collection.clone(), 338 rkey: rkey.clone(), 339 prev: prev_record_cid, 340 }); 341 } 342 } 343 } 344 let new_mst_root = match mst.persist().await { 345 Ok(c) => c, 346 Err(_) => { 347 return ( 348 StatusCode::INTERNAL_SERVER_ERROR, 349 Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), 350 ) 351 .into_response(); 352 } 353 }; 354 let mut relevant_blocks = std::collections::BTreeMap::new(); 355 for key in &modified_keys { 356 if mst.blocks_for_path(key, &mut relevant_blocks).await.is_err() { 357 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 358 } 359 if original_mst 360 .blocks_for_path(key, &mut relevant_blocks) 361 .await 362 .is_err() 363 { 364 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 365 } 366 } 367 let mut written_cids = tracking_store.get_all_relevant_cids(); 368 for cid in relevant_blocks.keys() { 369 if !written_cids.contains(cid) { 370 written_cids.push(*cid); 371 } 372 } 373 let written_cids_str = written_cids 374 .iter() 375 .map(|c| c.to_string()) 376 .collect::<Vec<_>>(); 377 let commit_res = match commit_and_log( 378 &state, 379 CommitParams { 380 did: &did, 381 user_id, 382 current_root_cid: Some(current_root_cid), 383 prev_data_cid: Some(commit.data), 384 new_mst_root, 385 ops, 386 blocks_cids: &written_cids_str, 387 }, 388 ) 389 .await 390 { 391 Ok(res) => res, 392 Err(e) => { 393 error!("Commit failed: {}", e); 394 return ( 395 StatusCode::INTERNAL_SERVER_ERROR, 396 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})), 397 ) 398 .into_response(); 399 } 400 }; 401 ( 402 StatusCode::OK, 403 Json(ApplyWritesOutput { 404 commit: CommitInfo { 405 cid: commit_res.commit_cid.to_string(), 406 rev: commit_res.rev, 407 }, 408 results, 409 }), 410 ) 411 .into_response() 412}