this repo has no description
at main 19 kB view raw
1use super::validation::validate_record_with_status; 2use super::write::has_verified_comms_channel; 3use crate::api::error::ApiError; 4use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 5use crate::auth::BearerAuth; 6use crate::delegation::{self, DelegationActionType}; 7use crate::repo::tracking::TrackingBlockStore; 8use crate::state::AppState; 9use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; 10use axum::{ 11 Json, 12 extract::State, 13 http::StatusCode, 14 response::{IntoResponse, Response}, 15}; 16use cid::Cid; 17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 18use serde::{Deserialize, Serialize}; 19use serde_json::json; 20use std::str::FromStr; 21use std::sync::Arc; 22use tracing::{error, info}; 23 24const MAX_BATCH_WRITES: usize = 200; 25 26struct WriteAccumulator { 27 mst: Mst<TrackingBlockStore>, 28 results: Vec<WriteResult>, 29 ops: Vec<RecordOp>, 30 modified_keys: Vec<String>, 31 all_blob_cids: Vec<String>, 32} 33 34async fn process_single_write( 35 write: &WriteOp, 36 acc: WriteAccumulator, 37 did: &Did, 38 validate: Option<bool>, 39 tracking_store: &TrackingBlockStore, 40) -> Result<WriteAccumulator, Response> { 41 let WriteAccumulator { 42 mst, 43 mut results, 44 mut ops, 45 mut modified_keys, 46 mut all_blob_cids, 47 } = acc; 48 49 match write { 50 WriteOp::Create { 51 collection, 52 rkey, 53 value, 54 } => { 55 let validation_status = match validate { 56 Some(false) => None, 57 _ => { 58 let require_lexicon = validate == Some(true); 59 match validate_record_with_status( 60 value, 61 collection, 62 rkey.as_ref(), 63 require_lexicon, 64 ) { 65 Ok(status) => Some(status), 66 Err(err_response) => return Err(*err_response), 67 } 68 } 69 }; 70 all_blob_cids.extend(extract_blob_cids(value)); 71 let rkey = rkey.clone().unwrap_or_else(Rkey::generate); 72 let record_ipld = crate::util::json_to_ipld(value); 73 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| { 74 ApiError::InvalidRecord("Failed to serialize record".into()).into_response() 75 })?; 76 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| { 77 ApiError::InternalError(Some("Failed to store record".into())).into_response() 78 })?; 79 let key = format!("{}/{}", collection, rkey); 80 modified_keys.push(key.clone()); 81 let new_mst = mst.add(&key, record_cid).await.map_err(|_| { 82 ApiError::InternalError(Some("Failed to add to MST".into())).into_response() 83 })?; 84 let uri = AtUri::from_parts(did, collection, &rkey); 85 results.push(WriteResult::CreateResult { 86 uri, 87 cid: record_cid.to_string(), 88 validation_status: validation_status.map(|s| s.to_string()), 89 }); 90 ops.push(RecordOp::Create { 91 collection: collection.clone(), 92 rkey: rkey.clone(), 93 cid: record_cid, 94 }); 95 Ok(WriteAccumulator { 96 mst: new_mst, 97 results, 98 ops, 99 modified_keys, 100 all_blob_cids, 101 }) 102 } 103 WriteOp::Update { 104 collection, 105 rkey, 106 value, 107 } => { 108 let validation_status = match validate { 109 Some(false) => None, 110 _ => { 111 let require_lexicon = validate == Some(true); 112 match validate_record_with_status( 113 value, 114 collection, 115 Some(rkey), 116 require_lexicon, 117 ) { 118 Ok(status) => Some(status), 119 Err(err_response) => return Err(*err_response), 120 } 121 } 122 }; 123 all_blob_cids.extend(extract_blob_cids(value)); 124 let record_ipld = crate::util::json_to_ipld(value); 125 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| { 126 ApiError::InvalidRecord("Failed to serialize record".into()).into_response() 127 })?; 128 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| { 129 ApiError::InternalError(Some("Failed to store record".into())).into_response() 130 })?; 131 let key = format!("{}/{}", collection, rkey); 132 modified_keys.push(key.clone()); 133 let prev_record_cid = mst.get(&key).await.ok().flatten(); 134 let new_mst = mst.update(&key, record_cid).await.map_err(|_| { 135 ApiError::InternalError(Some("Failed to update MST".into())).into_response() 136 })?; 137 let uri = AtUri::from_parts(did, collection, rkey); 138 results.push(WriteResult::UpdateResult { 139 uri, 140 cid: record_cid.to_string(), 141 validation_status: validation_status.map(|s| s.to_string()), 142 }); 143 ops.push(RecordOp::Update { 144 collection: collection.clone(), 145 rkey: rkey.clone(), 146 cid: record_cid, 147 prev: prev_record_cid, 148 }); 149 Ok(WriteAccumulator { 150 mst: new_mst, 151 results, 152 ops, 153 modified_keys, 154 all_blob_cids, 155 }) 156 } 157 WriteOp::Delete { collection, rkey } => { 158 let key = format!("{}/{}", collection, rkey); 159 modified_keys.push(key.clone()); 160 let prev_record_cid = mst.get(&key).await.ok().flatten(); 161 let new_mst = mst.delete(&key).await.map_err(|_| { 162 ApiError::InternalError(Some("Failed to delete from MST".into())).into_response() 163 })?; 164 results.push(WriteResult::DeleteResult {}); 165 ops.push(RecordOp::Delete { 166 collection: collection.clone(), 167 rkey: rkey.clone(), 168 prev: prev_record_cid, 169 }); 170 Ok(WriteAccumulator { 171 mst: new_mst, 172 results, 173 ops, 174 modified_keys, 175 all_blob_cids, 176 }) 177 } 178 } 179} 180 181async fn process_writes( 182 writes: &[WriteOp], 183 initial_mst: Mst<TrackingBlockStore>, 184 did: &Did, 185 validate: Option<bool>, 186 tracking_store: &TrackingBlockStore, 187) -> Result<WriteAccumulator, Response> { 188 use futures::stream::{self, TryStreamExt}; 189 let initial_acc = WriteAccumulator { 190 mst: initial_mst, 191 results: Vec::new(), 192 ops: Vec::new(), 193 modified_keys: Vec::new(), 194 all_blob_cids: Vec::new(), 195 }; 196 stream::iter(writes.iter().map(Ok::<_, Response>)) 197 .try_fold(initial_acc, |acc, write| async move { 198 process_single_write(write, acc, did, validate, tracking_store).await 199 }) 200 .await 201} 202 203#[derive(Deserialize)] 204#[serde(tag = "$type")] 205pub enum WriteOp { 206 #[serde(rename = "com.atproto.repo.applyWrites#create")] 207 Create { 208 collection: Nsid, 209 rkey: Option<Rkey>, 210 value: serde_json::Value, 211 }, 212 #[serde(rename = "com.atproto.repo.applyWrites#update")] 213 Update { 214 collection: Nsid, 215 rkey: Rkey, 216 value: serde_json::Value, 217 }, 218 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 219 Delete { collection: Nsid, rkey: Rkey }, 220} 221 222#[derive(Deserialize)] 223#[serde(rename_all = "camelCase")] 224pub struct ApplyWritesInput { 225 pub repo: AtIdentifier, 226 pub validate: Option<bool>, 227 pub writes: Vec<WriteOp>, 228 pub swap_commit: Option<String>, 229} 230 231#[derive(Serialize)] 232#[serde(tag = "$type")] 233pub enum WriteResult { 234 #[serde(rename = "com.atproto.repo.applyWrites#createResult")] 235 CreateResult { 236 uri: AtUri, 237 cid: String, 238 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 239 validation_status: Option<String>, 240 }, 241 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 242 UpdateResult { 243 uri: AtUri, 244 cid: String, 245 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 246 validation_status: Option<String>, 247 }, 248 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 249 DeleteResult {}, 250} 251 252#[derive(Serialize)] 253pub struct ApplyWritesOutput { 254 pub commit: CommitInfo, 255 pub results: Vec<WriteResult>, 256} 257 258#[derive(Serialize)] 259pub struct CommitInfo { 260 pub cid: String, 261 pub rev: String, 262} 263 264pub async fn apply_writes( 265 State(state): State<AppState>, 266 auth: BearerAuth, 267 Json(input): Json<ApplyWritesInput>, 268) -> Response { 269 info!( 270 "apply_writes called: repo={}, writes={}", 271 input.repo, 272 input.writes.len() 273 ); 274 let auth_user = auth.0; 275 let did = auth_user.did.clone(); 276 let is_oauth = auth_user.is_oauth; 277 let scope = auth_user.scope; 278 let controller_did = auth_user.controller_did.clone(); 279 if input.repo.as_str() != did { 280 return ApiError::InvalidRepo("Repo does not match authenticated user".into()) 281 .into_response(); 282 } 283 if crate::util::is_account_migrated(&state.db, &did) 284 .await 285 .unwrap_or(false) 286 { 287 return ApiError::AccountMigrated.into_response(); 288 } 289 let is_verified = has_verified_comms_channel(&state.db, &did) 290 .await 291 .unwrap_or(false); 292 let is_delegated = crate::delegation::is_delegated_account(&state.db, &did) 293 .await 294 .unwrap_or(false); 295 if !is_verified && !is_delegated { 296 return ApiError::AccountNotVerified.into_response(); 297 } 298 if input.writes.is_empty() { 299 return ApiError::InvalidRequest("writes array is empty".into()).into_response(); 300 } 301 if input.writes.len() > MAX_BATCH_WRITES { 302 return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES)) 303 .into_response(); 304 } 305 306 let has_custom_scope = scope 307 .as_ref() 308 .map(|s| s != "com.atproto.access") 309 .unwrap_or(false); 310 if is_oauth || has_custom_scope { 311 use std::collections::HashSet; 312 let create_collections: HashSet<&Nsid> = input 313 .writes 314 .iter() 315 .filter_map(|w| { 316 if let WriteOp::Create { collection, .. } = w { 317 Some(collection) 318 } else { 319 None 320 } 321 }) 322 .collect(); 323 let update_collections: HashSet<&Nsid> = input 324 .writes 325 .iter() 326 .filter_map(|w| { 327 if let WriteOp::Update { collection, .. } = w { 328 Some(collection) 329 } else { 330 None 331 } 332 }) 333 .collect(); 334 let delete_collections: HashSet<&Nsid> = input 335 .writes 336 .iter() 337 .filter_map(|w| { 338 if let WriteOp::Delete { collection, .. } = w { 339 Some(collection) 340 } else { 341 None 342 } 343 }) 344 .collect(); 345 346 let scope_checks = create_collections 347 .iter() 348 .map(|c| (crate::oauth::RepoAction::Create, c)) 349 .chain( 350 update_collections 351 .iter() 352 .map(|c| (crate::oauth::RepoAction::Update, c)), 353 ) 354 .chain( 355 delete_collections 356 .iter() 357 .map(|c| (crate::oauth::RepoAction::Delete, c)), 358 ); 359 360 if let Some(err) = scope_checks 361 .filter_map(|(action, collection)| { 362 crate::auth::scope_check::check_repo_scope( 363 is_oauth, 364 scope.as_deref(), 365 action, 366 collection, 367 ) 368 .err() 369 }) 370 .next() 371 { 372 return err; 373 } 374 } 375 376 let user_id: uuid::Uuid = 377 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str()) 378 .fetch_optional(&state.db) 379 .await 380 { 381 Ok(Some(id)) => id, 382 _ => return ApiError::InternalError(Some("User not found".into())).into_response(), 383 }; 384 let root_cid_str: String = match sqlx::query_scalar!( 385 "SELECT repo_root_cid FROM repos WHERE user_id = $1", 386 user_id 387 ) 388 .fetch_optional(&state.db) 389 .await 390 { 391 Ok(Some(cid_str)) => cid_str, 392 _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(), 393 }; 394 let current_root_cid = match Cid::from_str(&root_cid_str) { 395 Ok(c) => c, 396 Err(_) => { 397 return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(); 398 } 399 }; 400 if let Some(swap_commit) = &input.swap_commit 401 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 402 { 403 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 404 } 405 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 406 let commit_bytes = match tracking_store.get(&current_root_cid).await { 407 Ok(Some(b)) => b, 408 _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 409 }; 410 let commit = match Commit::from_cbor(&commit_bytes) { 411 Ok(c) => c, 412 _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 413 }; 414 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 415 let initial_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 416 let WriteAccumulator { 417 mst, 418 results, 419 ops, 420 modified_keys, 421 all_blob_cids, 422 } = match process_writes( 423 &input.writes, 424 initial_mst, 425 &did, 426 input.validate, 427 &tracking_store, 428 ) 429 .await 430 { 431 Ok(acc) => acc, 432 Err(response) => return response, 433 }; 434 let new_mst_root = match mst.persist().await { 435 Ok(c) => c, 436 Err(_) => { 437 return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 438 } 439 }; 440 let (new_mst_blocks, old_mst_blocks) = { 441 let mut new_blocks = std::collections::BTreeMap::new(); 442 let mut old_blocks = std::collections::BTreeMap::new(); 443 for key in &modified_keys { 444 if mst.blocks_for_path(key, &mut new_blocks).await.is_err() { 445 return ApiError::InternalError(Some( 446 "Failed to get new MST blocks for path".into(), 447 )) 448 .into_response(); 449 } 450 if original_mst 451 .blocks_for_path(key, &mut old_blocks) 452 .await 453 .is_err() 454 { 455 return ApiError::InternalError(Some( 456 "Failed to get old MST blocks for path".into(), 457 )) 458 .into_response(); 459 } 460 } 461 (new_blocks, old_blocks) 462 }; 463 let mut relevant_blocks = new_mst_blocks.clone(); 464 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 465 let written_cids: Vec<Cid> = tracking_store 466 .get_all_relevant_cids() 467 .into_iter() 468 .chain(relevant_blocks.keys().copied()) 469 .collect::<std::collections::HashSet<_>>() 470 .into_iter() 471 .collect(); 472 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 473 let prev_record_cids = ops.iter().filter_map(|op| match op { 474 RecordOp::Update { 475 prev: Some(cid), .. 476 } 477 | RecordOp::Delete { 478 prev: Some(cid), .. 479 } => Some(*cid), 480 _ => None, 481 }); 482 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 483 .chain( 484 old_mst_blocks 485 .keys() 486 .filter(|cid| !new_mst_blocks.contains_key(*cid)) 487 .copied(), 488 ) 489 .chain(prev_record_cids) 490 .collect::<std::collections::HashSet<_>>() 491 .into_iter() 492 .collect(); 493 let commit_res = match commit_and_log( 494 &state, 495 CommitParams { 496 did: &did, 497 user_id, 498 current_root_cid: Some(current_root_cid), 499 prev_data_cid: Some(commit.data), 500 new_mst_root, 501 ops, 502 blocks_cids: &written_cids_str, 503 blobs: &all_blob_cids, 504 obsolete_cids, 505 }, 506 ) 507 .await 508 { 509 Ok(res) => res, 510 Err(e) if e.contains("ConcurrentModification") => { 511 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 512 } 513 Err(e) => { 514 error!("Commit failed: {}", e); 515 return ApiError::InternalError(Some("Failed to commit changes".into())) 516 .into_response(); 517 } 518 }; 519 520 if let Some(ref controller) = controller_did { 521 let write_summary: Vec<serde_json::Value> = input 522 .writes 523 .iter() 524 .map(|w| match w { 525 WriteOp::Create { 526 collection, rkey, .. 527 } => json!({ 528 "action": "create", 529 "collection": collection, 530 "rkey": rkey 531 }), 532 WriteOp::Update { 533 collection, rkey, .. 534 } => json!({ 535 "action": "update", 536 "collection": collection, 537 "rkey": rkey 538 }), 539 WriteOp::Delete { collection, rkey } => json!({ 540 "action": "delete", 541 "collection": collection, 542 "rkey": rkey 543 }), 544 }) 545 .collect(); 546 547 let _ = delegation::log_delegation_action( 548 &state.db, 549 &did, 550 controller, 551 Some(controller), 552 DelegationActionType::RepoWrite, 553 Some(json!({ 554 "action": "apply_writes", 555 "count": input.writes.len(), 556 "writes": write_summary 557 })), 558 None, 559 None, 560 ) 561 .await; 562 } 563 564 ( 565 StatusCode::OK, 566 Json(ApplyWritesOutput { 567 commit: CommitInfo { 568 cid: commit_res.commit_cid.to_string(), 569 rev: commit_res.rev, 570 }, 571 results, 572 }), 573 ) 574 .into_response() 575}