This repository has no description.
1use super::validation::validate_record; 2use super::write::has_verified_comms_channel; 3use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 4use crate::repo::tracking::TrackingBlockStore; 5use crate::state::AppState; 6use axum::{ 7 Json, 8 extract::State, 9 http::StatusCode, 10 response::{IntoResponse, Response}, 11}; 12use cid::Cid; 13use jacquard::types::{ 14 integer::LimitedU32, 15 string::{Nsid, Tid}, 16}; 17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 18use serde::{Deserialize, Serialize}; 19use serde_json::json; 20use std::str::FromStr; 21use std::sync::Arc; 22use tracing::{error, info}; 23 24const MAX_BATCH_WRITES: usize = 200; 25 26#[derive(Deserialize)] 27#[serde(tag = "$type")] 28pub enum WriteOp { 29 #[serde(rename = "com.atproto.repo.applyWrites#create")] 30 Create { 31 collection: String, 32 rkey: Option<String>, 33 value: serde_json::Value, 34 }, 35 #[serde(rename = "com.atproto.repo.applyWrites#update")] 36 Update { 37 collection: String, 38 rkey: String, 39 value: serde_json::Value, 40 }, 41 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 42 Delete { collection: String, rkey: String }, 43} 44 45#[derive(Deserialize)] 46#[serde(rename_all = "camelCase")] 47pub struct ApplyWritesInput { 48 pub repo: String, 49 pub validate: Option<bool>, 50 pub writes: Vec<WriteOp>, 51 pub swap_commit: Option<String>, 52} 53 54#[derive(Serialize)] 55#[serde(tag = "$type")] 56pub enum WriteResult { 57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")] 58 CreateResult { uri: String, cid: String }, 59 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 60 UpdateResult { uri: String, cid: String }, 61 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 62 DeleteResult {}, 63} 64 65#[derive(Serialize)] 66pub struct ApplyWritesOutput { 67 pub commit: CommitInfo, 68 pub results: Vec<WriteResult>, 69} 70 71#[derive(Serialize)] 72pub struct CommitInfo { 73 pub cid: String, 74 pub rev: 
String, 75} 76 77pub async fn apply_writes( 78 State(state): State<AppState>, 79 headers: axum::http::HeaderMap, 80 Json(input): Json<ApplyWritesInput>, 81) -> Response { 82 info!( 83 "apply_writes called: repo={}, writes={}", 84 input.repo, 85 input.writes.len() 86 ); 87 let token = match crate::auth::extract_bearer_token_from_header( 88 headers.get("Authorization").and_then(|h| h.to_str().ok()), 89 ) { 90 Some(t) => t, 91 None => { 92 return ( 93 StatusCode::UNAUTHORIZED, 94 Json(json!({"error": "AuthenticationRequired"})), 95 ) 96 .into_response(); 97 } 98 }; 99 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 100 Ok(user) => user, 101 Err(_) => { 102 return ( 103 StatusCode::UNAUTHORIZED, 104 Json(json!({"error": "AuthenticationFailed"})), 105 ) 106 .into_response(); 107 } 108 }; 109 let did = auth_user.did.clone(); 110 let is_oauth = auth_user.is_oauth; 111 let scope = auth_user.scope; 112 if input.repo != did { 113 return ( 114 StatusCode::FORBIDDEN, 115 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})), 116 ) 117 .into_response(); 118 } 119 match has_verified_comms_channel(&state.db, &did).await { 120 Ok(true) => {} 121 Ok(false) => { 122 return ( 123 StatusCode::FORBIDDEN, 124 Json(json!({ 125 "error": "AccountNotVerified", 126 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records" 127 })), 128 ) 129 .into_response(); 130 } 131 Err(e) => { 132 error!("DB error checking notification channels: {}", e); 133 return ( 134 StatusCode::INTERNAL_SERVER_ERROR, 135 Json(json!({"error": "InternalError"})), 136 ) 137 .into_response(); 138 } 139 } 140 if input.writes.is_empty() { 141 return ( 142 StatusCode::BAD_REQUEST, 143 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})), 144 ) 145 .into_response(); 146 } 147 if input.writes.len() > MAX_BATCH_WRITES { 148 return ( 149 StatusCode::BAD_REQUEST, 150 
Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})), 151 ) 152 .into_response(); 153 } 154 155 let has_custom_scope = scope 156 .as_ref() 157 .map(|s| s != "com.atproto.access") 158 .unwrap_or(false); 159 if is_oauth || has_custom_scope { 160 use std::collections::HashSet; 161 let create_collections: HashSet<&str> = input 162 .writes 163 .iter() 164 .filter_map(|w| { 165 if let WriteOp::Create { collection, .. } = w { 166 Some(collection.as_str()) 167 } else { 168 None 169 } 170 }) 171 .collect(); 172 let update_collections: HashSet<&str> = input 173 .writes 174 .iter() 175 .filter_map(|w| { 176 if let WriteOp::Update { collection, .. } = w { 177 Some(collection.as_str()) 178 } else { 179 None 180 } 181 }) 182 .collect(); 183 let delete_collections: HashSet<&str> = input 184 .writes 185 .iter() 186 .filter_map(|w| { 187 if let WriteOp::Delete { collection, .. } = w { 188 Some(collection.as_str()) 189 } else { 190 None 191 } 192 }) 193 .collect(); 194 195 for collection in create_collections { 196 if let Err(e) = crate::auth::scope_check::check_repo_scope( 197 is_oauth, 198 scope.as_deref(), 199 crate::oauth::RepoAction::Create, 200 collection, 201 ) { 202 return e; 203 } 204 } 205 for collection in update_collections { 206 if let Err(e) = crate::auth::scope_check::check_repo_scope( 207 is_oauth, 208 scope.as_deref(), 209 crate::oauth::RepoAction::Update, 210 collection, 211 ) { 212 return e; 213 } 214 } 215 for collection in delete_collections { 216 if let Err(e) = crate::auth::scope_check::check_repo_scope( 217 is_oauth, 218 scope.as_deref(), 219 crate::oauth::RepoAction::Delete, 220 collection, 221 ) { 222 return e; 223 } 224 } 225 } 226 227 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 228 .fetch_optional(&state.db) 229 .await 230 { 231 Ok(Some(id)) => id, 232 _ => { 233 return ( 234 StatusCode::INTERNAL_SERVER_ERROR, 235 Json(json!({"error": "InternalError", 
"message": "User not found"})), 236 ) 237 .into_response(); 238 } 239 }; 240 let root_cid_str: String = match sqlx::query_scalar!( 241 "SELECT repo_root_cid FROM repos WHERE user_id = $1", 242 user_id 243 ) 244 .fetch_optional(&state.db) 245 .await 246 { 247 Ok(Some(cid_str)) => cid_str, 248 _ => { 249 return ( 250 StatusCode::INTERNAL_SERVER_ERROR, 251 Json(json!({"error": "InternalError", "message": "Repo root not found"})), 252 ) 253 .into_response(); 254 } 255 }; 256 let current_root_cid = match Cid::from_str(&root_cid_str) { 257 Ok(c) => c, 258 Err(_) => { 259 return ( 260 StatusCode::INTERNAL_SERVER_ERROR, 261 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})), 262 ) 263 .into_response(); 264 } 265 }; 266 if let Some(swap_commit) = &input.swap_commit 267 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 268 { 269 return ( 270 StatusCode::CONFLICT, 271 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 272 ) 273 .into_response(); 274 } 275 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 276 let commit_bytes = match tracking_store.get(&current_root_cid).await { 277 Ok(Some(b)) => b, 278 _ => { 279 return ( 280 StatusCode::INTERNAL_SERVER_ERROR, 281 Json(json!({"error": "InternalError", "message": "Commit block not found"})), 282 ) 283 .into_response(); 284 } 285 }; 286 let commit = match Commit::from_cbor(&commit_bytes) { 287 Ok(c) => c, 288 _ => { 289 return ( 290 StatusCode::INTERNAL_SERVER_ERROR, 291 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 292 ) 293 .into_response(); 294 } 295 }; 296 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 297 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 298 let mut results: Vec<WriteResult> = Vec::new(); 299 let mut ops: Vec<RecordOp> = Vec::new(); 300 let mut modified_keys: Vec<String> = Vec::new(); 301 for write in &input.writes { 302 match 
write { 303 WriteOp::Create { 304 collection, 305 rkey, 306 value, 307 } => { 308 if input.validate.unwrap_or(true) 309 && let Err(err_response) = validate_record(value, collection) 310 { 311 return *err_response; 312 } 313 let rkey = rkey 314 .clone() 315 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); 316 let mut record_bytes = Vec::new(); 317 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 318 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 319 } 320 let record_cid = match tracking_store.put(&record_bytes).await { 321 Ok(c) => c, 322 Err(_) => return ( 323 StatusCode::INTERNAL_SERVER_ERROR, 324 Json( 325 json!({"error": "InternalError", "message": "Failed to store record"}), 326 ), 327 ) 328 .into_response(), 329 }; 330 let collection_nsid = match collection.parse::<Nsid>() { 331 Ok(n) => n, 332 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 333 }; 334 let key = format!("{}/{}", collection_nsid, rkey); 335 modified_keys.push(key.clone()); 336 mst = match mst.add(&key, record_cid).await { 337 Ok(m) => m, 338 Err(_) => return ( 339 StatusCode::INTERNAL_SERVER_ERROR, 340 Json(json!({"error": "InternalError", "message": "Failed to add to MST"})), 341 ) 342 .into_response(), 343 }; 344 let uri = format!("at://{}/{}/{}", did, collection, rkey); 345 results.push(WriteResult::CreateResult { 346 uri, 347 cid: record_cid.to_string(), 348 }); 349 ops.push(RecordOp::Create { 350 collection: collection.clone(), 351 rkey, 352 cid: record_cid, 353 }); 354 } 355 WriteOp::Update { 356 collection, 357 rkey, 358 value, 359 } => { 360 if input.validate.unwrap_or(true) 361 && let Err(err_response) = validate_record(value, collection) 362 { 363 return *err_response; 364 } 365 let mut record_bytes = Vec::new(); 366 if serde_ipld_dagcbor::to_writer(&mut record_bytes, 
value).is_err() { 367 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 368 } 369 let record_cid = match tracking_store.put(&record_bytes).await { 370 Ok(c) => c, 371 Err(_) => return ( 372 StatusCode::INTERNAL_SERVER_ERROR, 373 Json( 374 json!({"error": "InternalError", "message": "Failed to store record"}), 375 ), 376 ) 377 .into_response(), 378 }; 379 let collection_nsid = match collection.parse::<Nsid>() { 380 Ok(n) => n, 381 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 382 }; 383 let key = format!("{}/{}", collection_nsid, rkey); 384 modified_keys.push(key.clone()); 385 let prev_record_cid = mst.get(&key).await.ok().flatten(); 386 mst = match mst.update(&key, record_cid).await { 387 Ok(m) => m, 388 Err(_) => return ( 389 StatusCode::INTERNAL_SERVER_ERROR, 390 Json(json!({"error": "InternalError", "message": "Failed to update MST"})), 391 ) 392 .into_response(), 393 }; 394 let uri = format!("at://{}/{}/{}", did, collection, rkey); 395 results.push(WriteResult::UpdateResult { 396 uri, 397 cid: record_cid.to_string(), 398 }); 399 ops.push(RecordOp::Update { 400 collection: collection.clone(), 401 rkey: rkey.clone(), 402 cid: record_cid, 403 prev: prev_record_cid, 404 }); 405 } 406 WriteOp::Delete { collection, rkey } => { 407 let collection_nsid = match collection.parse::<Nsid>() { 408 Ok(n) => n, 409 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 410 }; 411 let key = format!("{}/{}", collection_nsid, rkey); 412 modified_keys.push(key.clone()); 413 let prev_record_cid = mst.get(&key).await.ok().flatten(); 414 mst = match mst.delete(&key).await { 415 Ok(m) => m, 416 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from 
MST"}))).into_response(), 417 }; 418 results.push(WriteResult::DeleteResult {}); 419 ops.push(RecordOp::Delete { 420 collection: collection.clone(), 421 rkey: rkey.clone(), 422 prev: prev_record_cid, 423 }); 424 } 425 } 426 } 427 let new_mst_root = match mst.persist().await { 428 Ok(c) => c, 429 Err(_) => { 430 return ( 431 StatusCode::INTERNAL_SERVER_ERROR, 432 Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), 433 ) 434 .into_response(); 435 } 436 }; 437 let mut relevant_blocks = std::collections::BTreeMap::new(); 438 for key in &modified_keys { 439 if mst 440 .blocks_for_path(key, &mut relevant_blocks) 441 .await 442 .is_err() 443 { 444 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 445 } 446 if original_mst 447 .blocks_for_path(key, &mut relevant_blocks) 448 .await 449 .is_err() 450 { 451 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 452 } 453 } 454 let mut written_cids = tracking_store.get_all_relevant_cids(); 455 for cid in relevant_blocks.keys() { 456 if !written_cids.contains(cid) { 457 written_cids.push(*cid); 458 } 459 } 460 let written_cids_str = written_cids 461 .iter() 462 .map(|c| c.to_string()) 463 .collect::<Vec<_>>(); 464 let commit_res = match commit_and_log( 465 &state, 466 CommitParams { 467 did: &did, 468 user_id, 469 current_root_cid: Some(current_root_cid), 470 prev_data_cid: Some(commit.data), 471 new_mst_root, 472 ops, 473 blocks_cids: &written_cids_str, 474 }, 475 ) 476 .await 477 { 478 Ok(res) => res, 479 Err(e) => { 480 error!("Commit failed: {}", e); 481 return ( 482 StatusCode::INTERNAL_SERVER_ERROR, 483 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})), 484 ) 485 .into_response(); 486 } 487 }; 488 ( 489 StatusCode::OK, 490 Json(ApplyWritesOutput { 491 commit: CommitInfo { 
492 cid: commit_res.commit_cid.to_string(), 493 rev: commit_res.rev, 494 }, 495 results, 496 }), 497 ) 498 .into_response() 499}