this repo has no description
1use super::validation::validate_record_with_status; 2use super::write::has_verified_comms_channel; 3use crate::api::error::ApiError; 4use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 5use crate::auth::BearerAuth; 6use crate::delegation::{self, DelegationActionType}; 7use crate::repo::tracking::TrackingBlockStore; 8use crate::state::AppState; 9use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; 10use axum::{ 11 Json, 12 extract::State, 13 http::StatusCode, 14 response::{IntoResponse, Response}, 15}; 16use cid::Cid; 17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 18use serde::{Deserialize, Serialize}; 19use serde_json::json; 20use std::str::FromStr; 21use std::sync::Arc; 22use tracing::{error, info}; 23 24const MAX_BATCH_WRITES: usize = 200; 25 26struct WriteAccumulator { 27 mst: Mst<TrackingBlockStore>, 28 results: Vec<WriteResult>, 29 ops: Vec<RecordOp>, 30 modified_keys: Vec<String>, 31 all_blob_cids: Vec<String>, 32} 33 34async fn process_single_write( 35 write: &WriteOp, 36 acc: WriteAccumulator, 37 did: &Did, 38 validate: Option<bool>, 39 tracking_store: &TrackingBlockStore, 40) -> Result<WriteAccumulator, Response> { 41 let WriteAccumulator { 42 mst, 43 mut results, 44 mut ops, 45 mut modified_keys, 46 mut all_blob_cids, 47 } = acc; 48 49 match write { 50 WriteOp::Create { 51 collection, 52 rkey, 53 value, 54 } => { 55 let validation_status = match validate { 56 Some(false) => None, 57 _ => { 58 let require_lexicon = validate == Some(true); 59 match validate_record_with_status( 60 value, 61 collection, 62 rkey.as_ref(), 63 require_lexicon, 64 ) { 65 Ok(status) => Some(status), 66 Err(err_response) => return Err(*err_response), 67 } 68 } 69 }; 70 all_blob_cids.extend(extract_blob_cids(value)); 71 let rkey = rkey.clone().unwrap_or_else(Rkey::generate); 72 let record_ipld = crate::util::json_to_ipld(value); 73 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| { 74 ApiError::InvalidRecord("Failed to serialize record".into()).into_response() 75 })?; 76 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| { 77 ApiError::InternalError(Some("Failed to store record".into())).into_response() 78 })?; 79 let key = format!("{}/{}", collection, rkey); 80 modified_keys.push(key.clone()); 81 let new_mst = mst.add(&key, record_cid).await.map_err(|_| { 82 ApiError::InternalError(Some("Failed to add to MST".into())).into_response() 83 })?; 84 let uri = AtUri::from_parts(did, collection, &rkey); 85 results.push(WriteResult::CreateResult { 86 uri, 87 cid: record_cid.to_string(), 88 validation_status: validation_status.map(|s| s.to_string()), 89 }); 90 ops.push(RecordOp::Create { 91 collection: collection.clone(), 92 rkey: rkey.clone(), 93 cid: record_cid, 94 }); 95 Ok(WriteAccumulator { 96 mst: new_mst, 97 results, 98 ops, 99 modified_keys, 100 all_blob_cids, 101 }) 102 } 103 WriteOp::Update { 104 collection, 105 rkey, 106 value, 107 } => { 108 let validation_status = match validate { 109 Some(false) => None, 110 _ => { 111 let require_lexicon = validate == Some(true); 112 match validate_record_with_status( 113 value, 114 collection, 115 Some(rkey), 116 require_lexicon, 117 ) { 118 Ok(status) => Some(status), 119 Err(err_response) => return Err(*err_response), 120 } 121 } 122 }; 123 all_blob_cids.extend(extract_blob_cids(value)); 124 let record_ipld = crate::util::json_to_ipld(value); 125 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| { 126 ApiError::InvalidRecord("Failed to serialize record".into()).into_response() 127 })?; 128 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| { 129 ApiError::InternalError(Some("Failed to store record".into())).into_response() 130 })?; 131 let key = format!("{}/{}", collection, rkey); 132 modified_keys.push(key.clone()); 133 let prev_record_cid = mst.get(&key).await.ok().flatten(); 134 let new_mst = mst.update(&key, record_cid).await.map_err(|_| { 135 ApiError::InternalError(Some("Failed to update MST".into())).into_response() 136 })?; 137 let uri = AtUri::from_parts(did, collection, rkey); 138 results.push(WriteResult::UpdateResult { 139 uri, 140 cid: record_cid.to_string(), 141 validation_status: validation_status.map(|s| s.to_string()), 142 }); 143 ops.push(RecordOp::Update { 144 collection: collection.clone(), 145 rkey: rkey.clone(), 146 cid: record_cid, 147 prev: prev_record_cid, 148 }); 149 Ok(WriteAccumulator { 150 mst: new_mst, 151 results, 152 ops, 153 modified_keys, 154 all_blob_cids, 155 }) 156 } 157 WriteOp::Delete { collection, rkey } => { 158 let key = format!("{}/{}", collection, rkey); 159 modified_keys.push(key.clone()); 160 let prev_record_cid = mst.get(&key).await.ok().flatten(); 161 let new_mst = mst.delete(&key).await.map_err(|_| { 162 ApiError::InternalError(Some("Failed to delete from MST".into())).into_response() 163 })?; 164 results.push(WriteResult::DeleteResult {}); 165 ops.push(RecordOp::Delete { 166 collection: collection.clone(), 167 rkey: rkey.clone(), 168 prev: prev_record_cid, 169 }); 170 Ok(WriteAccumulator { 171 mst: new_mst, 172 results, 173 ops, 174 modified_keys, 175 all_blob_cids, 176 }) 177 } 178 } 179} 180 181async fn process_writes( 182 writes: &[WriteOp], 183 initial_mst: Mst<TrackingBlockStore>, 184 did: &Did, 185 validate: Option<bool>, 186 tracking_store: &TrackingBlockStore, 187) -> Result<WriteAccumulator, Response> { 188 use futures::stream::{self, TryStreamExt}; 189 let initial_acc = WriteAccumulator { 190 mst: initial_mst, 191 results: Vec::new(), 192 ops: Vec::new(), 193 modified_keys: Vec::new(), 194 all_blob_cids: Vec::new(), 195 }; 196 stream::iter(writes.iter().map(Ok::<_, Response>)) 197 .try_fold(initial_acc, |acc, write| async move { 198 process_single_write(write, acc, did, validate, tracking_store).await 199 }) 200 .await 201} 202 203#[derive(Deserialize)] 204#[serde(tag = "$type")] 205pub enum WriteOp { 206 #[serde(rename = "com.atproto.repo.applyWrites#create")] 207 Create { 208 collection: Nsid, 209 rkey: Option<Rkey>, 210 value: serde_json::Value, 211 }, 212 #[serde(rename = "com.atproto.repo.applyWrites#update")] 213 Update { 214 collection: Nsid, 215 rkey: Rkey, 216 value: serde_json::Value, 217 }, 218 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 219 Delete { collection: Nsid, rkey: Rkey }, 220} 221 222#[derive(Deserialize)] 223#[serde(rename_all = "camelCase")] 224pub struct ApplyWritesInput { 225 pub repo: AtIdentifier, 226 pub validate: Option<bool>, 227 pub writes: Vec<WriteOp>, 228 pub swap_commit: Option<String>, 229} 230 231#[derive(Serialize)] 232#[serde(tag = "$type")] 233pub enum WriteResult { 234 #[serde(rename = "com.atproto.repo.applyWrites#createResult")] 235 CreateResult { 236 uri: AtUri, 237 cid: String, 238 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 239 validation_status: Option<String>, 240 }, 241 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 242 UpdateResult { 243 uri: AtUri, 244 cid: String, 245 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 246 validation_status: Option<String>, 247 }, 248 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 249 DeleteResult {}, 250} 251 252#[derive(Serialize)] 253pub struct ApplyWritesOutput { 254 pub commit: CommitInfo, 255 pub results: Vec<WriteResult>, 256} 257 258#[derive(Serialize)] 259pub struct CommitInfo { 260 pub cid: String, 261 pub rev: String, 262} 263 264pub async fn apply_writes( 265 State(state): State<AppState>, 266 auth: BearerAuth, 267 Json(input): Json<ApplyWritesInput>, 268) -> Response { 269 info!( 270 "apply_writes called: repo={}, writes={}", 271 input.repo, 272 input.writes.len() 273 ); 274 let auth_user = auth.0; 275 let did = auth_user.did.clone(); 276 let is_oauth = auth_user.is_oauth; 277 let scope = auth_user.scope; 278 let controller_did = auth_user.controller_did.clone(); 279 if input.repo.as_str() != did { 280 return ApiError::InvalidRepo("Repo does not match authenticated user".into()) 281 .into_response(); 282 } 283 if crate::util::is_account_migrated(&state.db, &did) 284 .await 285 .unwrap_or(false) 286 { 287 return ApiError::AccountMigrated.into_response(); 288 } 289 let is_verified = has_verified_comms_channel(&state.db, &did) 290 .await 291 .unwrap_or(false); 292 let is_delegated = crate::delegation::is_delegated_account(&state.db, &did) 293 .await 294 .unwrap_or(false); 295 if !is_verified && !is_delegated { 296 return ApiError::AccountNotVerified.into_response(); 297 } 298 if input.writes.is_empty() { 299 return ApiError::InvalidRequest("writes array is empty".into()).into_response(); 300 } 301 if input.writes.len() > MAX_BATCH_WRITES { 302 return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES)) 303 .into_response(); 304 } 305 306 let has_custom_scope = scope 307 .as_ref() 308 .map(|s| s != "com.atproto.access") 309 .unwrap_or(false); 310 if is_oauth || has_custom_scope { 311 use std::collections::HashSet; 312 let create_collections: HashSet<&Nsid> = input 313 .writes 314 .iter() 315 .filter_map(|w| { 316 if let WriteOp::Create { collection, .. } = w { 317 Some(collection) 318 } else { 319 None 320 } 321 }) 322 .collect(); 323 let update_collections: HashSet<&Nsid> = input 324 .writes 325 .iter() 326 .filter_map(|w| { 327 if let WriteOp::Update { collection, .. } = w { 328 Some(collection) 329 } else { 330 None 331 } 332 }) 333 .collect(); 334 let delete_collections: HashSet<&Nsid> = input 335 .writes 336 .iter() 337 .filter_map(|w| { 338 if let WriteOp::Delete { collection, .. } = w { 339 Some(collection) 340 } else { 341 None 342 } 343 }) 344 .collect(); 345 346 for collection in create_collections { 347 if let Err(e) = crate::auth::scope_check::check_repo_scope( 348 is_oauth, 349 scope.as_deref(), 350 crate::oauth::RepoAction::Create, 351 collection, 352 ) { 353 return e; 354 } 355 } 356 for collection in update_collections { 357 if let Err(e) = crate::auth::scope_check::check_repo_scope( 358 is_oauth, 359 scope.as_deref(), 360 crate::oauth::RepoAction::Update, 361 collection, 362 ) { 363 return e; 364 } 365 } 366 for collection in delete_collections { 367 if let Err(e) = crate::auth::scope_check::check_repo_scope( 368 is_oauth, 369 scope.as_deref(), 370 crate::oauth::RepoAction::Delete, 371 collection, 372 ) { 373 return e; 374 } 375 } 376 } 377 378 let user_id: uuid::Uuid = 379 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str()) 380 .fetch_optional(&state.db) 381 .await 382 { 383 Ok(Some(id)) => id, 384 _ => return ApiError::InternalError(Some("User not found".into())).into_response(), 385 }; 386 let root_cid_str: String = match sqlx::query_scalar!( 387 "SELECT repo_root_cid FROM repos WHERE user_id = $1", 388 user_id 389 ) 390 .fetch_optional(&state.db) 391 .await 392 { 393 Ok(Some(cid_str)) => cid_str, 394 _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(), 395 }; 396 let current_root_cid = match Cid::from_str(&root_cid_str) { 397 Ok(c) => c, 398 Err(_) => { 399 return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(); 400 } 401 }; 402 if let Some(swap_commit) = &input.swap_commit 403 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 404 { 405 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 406 } 407 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 408 let commit_bytes = match tracking_store.get(&current_root_cid).await { 409 Ok(Some(b)) => b, 410 _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 411 }; 412 let commit = match Commit::from_cbor(&commit_bytes) { 413 Ok(c) => c, 414 _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 415 }; 416 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 417 let initial_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 418 let WriteAccumulator { 419 mst, 420 results, 421 ops, 422 modified_keys, 423 all_blob_cids, 424 } = match process_writes(&input.writes, initial_mst, &did, input.validate, &tracking_store).await 425 { 426 Ok(acc) => acc, 427 Err(response) => return response, 428 }; 429 let new_mst_root = match mst.persist().await { 430 Ok(c) => c, 431 Err(_) => { 432 return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 433 } 434 }; 435 let mut new_mst_blocks = std::collections::BTreeMap::new(); 436 let mut old_mst_blocks = std::collections::BTreeMap::new(); 437 for key in &modified_keys { 438 if mst.blocks_for_path(key, &mut new_mst_blocks).await.is_err() { 439 return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 440 .into_response(); 441 } 442 if original_mst 443 .blocks_for_path(key, &mut old_mst_blocks) 444 .await 445 .is_err() 446 { 447 return ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 448 .into_response(); 449 } 450 } 451 let mut relevant_blocks = new_mst_blocks.clone(); 452 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 453 let written_cids: Vec<Cid> = tracking_store 454 .get_all_relevant_cids() 455 .into_iter() 456 .chain(relevant_blocks.keys().copied()) 457 .collect::<std::collections::HashSet<_>>() 458 .into_iter() 459 .collect(); 460 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 461 let prev_record_cids = ops.iter().filter_map(|op| match op { 462 RecordOp::Update { 463 prev: Some(cid), .. 464 } 465 | RecordOp::Delete { 466 prev: Some(cid), .. 467 } => Some(*cid), 468 _ => None, 469 }); 470 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 471 .chain( 472 old_mst_blocks 473 .keys() 474 .filter(|cid| !new_mst_blocks.contains_key(*cid)) 475 .copied(), 476 ) 477 .chain(prev_record_cids) 478 .collect::<std::collections::HashSet<_>>() 479 .into_iter() 480 .collect(); 481 let commit_res = match commit_and_log( 482 &state, 483 CommitParams { 484 did: &did, 485 user_id, 486 current_root_cid: Some(current_root_cid), 487 prev_data_cid: Some(commit.data), 488 new_mst_root, 489 ops, 490 blocks_cids: &written_cids_str, 491 blobs: &all_blob_cids, 492 obsolete_cids, 493 }, 494 ) 495 .await 496 { 497 Ok(res) => res, 498 Err(e) if e.contains("ConcurrentModification") => { 499 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 500 } 501 Err(e) => { 502 error!("Commit failed: {}", e); 503 return ApiError::InternalError(Some("Failed to commit changes".into())) 504 .into_response(); 505 } 506 }; 507 508 if let Some(ref controller) = controller_did { 509 let write_summary: Vec<serde_json::Value> = input 510 .writes 511 .iter() 512 .map(|w| match w { 513 WriteOp::Create { 514 collection, rkey, .. 515 } => json!({ 516 "action": "create", 517 "collection": collection, 518 "rkey": rkey 519 }), 520 WriteOp::Update { 521 collection, rkey, .. 522 } => json!({ 523 "action": "update", 524 "collection": collection, 525 "rkey": rkey 526 }), 527 WriteOp::Delete { collection, rkey } => json!({ 528 "action": "delete", 529 "collection": collection, 530 "rkey": rkey 531 }), 532 }) 533 .collect(); 534 535 let _ = delegation::log_delegation_action( 536 &state.db, 537 &did, 538 controller, 539 Some(controller), 540 DelegationActionType::RepoWrite, 541 Some(json!({ 542 "action": "apply_writes", 543 "count": input.writes.len(), 544 "writes": write_summary 545 })), 546 None, 547 None, 548 ) 549 .await; 550 } 551 552 ( 553 StatusCode::OK, 554 Json(ApplyWritesOutput { 555 commit: CommitInfo { 556 cid: commit_res.commit_cid.to_string(), 557 rev: commit_res.rev, 558 }, 559 results, 560 }), 561 ) 562 .into_response() 563}