this repo has no description
1use super::validation::validate_record_with_status; 2use super::write::has_verified_comms_channel; 3use crate::api::error::ApiError; 4use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 5use crate::auth::BearerAuth; 6use crate::delegation::{self, DelegationActionType}; 7use crate::repo::tracking::TrackingBlockStore; 8use crate::state::AppState; 9use crate::types::{AtIdentifier, AtUri, Nsid, Rkey}; 10use axum::{ 11 Json, 12 extract::State, 13 http::StatusCode, 14 response::{IntoResponse, Response}, 15}; 16use cid::Cid; 17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 18use serde::{Deserialize, Serialize}; 19use serde_json::json; 20use std::str::FromStr; 21use std::sync::Arc; 22use tracing::{error, info}; 23 24const MAX_BATCH_WRITES: usize = 200; 25 26#[derive(Deserialize)] 27#[serde(tag = "$type")] 28pub enum WriteOp { 29 #[serde(rename = "com.atproto.repo.applyWrites#create")] 30 Create { 31 collection: Nsid, 32 rkey: Option<Rkey>, 33 value: serde_json::Value, 34 }, 35 #[serde(rename = "com.atproto.repo.applyWrites#update")] 36 Update { 37 collection: Nsid, 38 rkey: Rkey, 39 value: serde_json::Value, 40 }, 41 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 42 Delete { collection: Nsid, rkey: Rkey }, 43} 44 45#[derive(Deserialize)] 46#[serde(rename_all = "camelCase")] 47pub struct ApplyWritesInput { 48 pub repo: AtIdentifier, 49 pub validate: Option<bool>, 50 pub writes: Vec<WriteOp>, 51 pub swap_commit: Option<String>, 52} 53 54#[derive(Serialize)] 55#[serde(tag = "$type")] 56pub enum WriteResult { 57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")] 58 CreateResult { 59 uri: AtUri, 60 cid: String, 61 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 62 validation_status: Option<String>, 63 }, 64 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 65 UpdateResult { 66 uri: AtUri, 67 cid: String, 68 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")] 69 validation_status: Option<String>, 70 }, 71 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 72 DeleteResult {}, 73} 74 75#[derive(Serialize)] 76pub struct ApplyWritesOutput { 77 pub commit: CommitInfo, 78 pub results: Vec<WriteResult>, 79} 80 81#[derive(Serialize)] 82pub struct CommitInfo { 83 pub cid: String, 84 pub rev: String, 85} 86 87pub async fn apply_writes( 88 State(state): State<AppState>, 89 auth: BearerAuth, 90 Json(input): Json<ApplyWritesInput>, 91) -> Response { 92 info!( 93 "apply_writes called: repo={}, writes={}", 94 input.repo, 95 input.writes.len() 96 ); 97 let auth_user = auth.0; 98 let did = auth_user.did.clone(); 99 let is_oauth = auth_user.is_oauth; 100 let scope = auth_user.scope; 101 let controller_did = auth_user.controller_did.clone(); 102 if input.repo.as_str() != did { 103 return ApiError::InvalidRepo("Repo does not match authenticated user".into()) 104 .into_response(); 105 } 106 if crate::util::is_account_migrated(&state.db, &did) 107 .await 108 .unwrap_or(false) 109 { 110 return ApiError::AccountMigrated.into_response(); 111 } 112 let is_verified = has_verified_comms_channel(&state.db, &did) 113 .await 114 .unwrap_or(false); 115 let is_delegated = crate::delegation::is_delegated_account(&state.db, &did) 116 .await 117 .unwrap_or(false); 118 if !is_verified && !is_delegated { 119 return ApiError::AccountNotVerified.into_response(); 120 } 121 if input.writes.is_empty() { 122 return ApiError::InvalidRequest("writes array is empty".into()).into_response(); 123 } 124 if input.writes.len() > MAX_BATCH_WRITES { 125 return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES)) 126 .into_response(); 127 } 128 129 let has_custom_scope = scope 130 .as_ref() 131 .map(|s| s != "com.atproto.access") 132 .unwrap_or(false); 133 if is_oauth || has_custom_scope { 134 use std::collections::HashSet; 135 let create_collections: HashSet<&Nsid> = input 136 .writes 137 .iter() 138 .filter_map(|w| { 139 if let WriteOp::Create { collection, .. } = w { 140 Some(collection) 141 } else { 142 None 143 } 144 }) 145 .collect(); 146 let update_collections: HashSet<&Nsid> = input 147 .writes 148 .iter() 149 .filter_map(|w| { 150 if let WriteOp::Update { collection, .. } = w { 151 Some(collection) 152 } else { 153 None 154 } 155 }) 156 .collect(); 157 let delete_collections: HashSet<&Nsid> = input 158 .writes 159 .iter() 160 .filter_map(|w| { 161 if let WriteOp::Delete { collection, .. } = w { 162 Some(collection) 163 } else { 164 None 165 } 166 }) 167 .collect(); 168 169 for collection in create_collections { 170 if let Err(e) = crate::auth::scope_check::check_repo_scope( 171 is_oauth, 172 scope.as_deref(), 173 crate::oauth::RepoAction::Create, 174 collection, 175 ) { 176 return e; 177 } 178 } 179 for collection in update_collections { 180 if let Err(e) = crate::auth::scope_check::check_repo_scope( 181 is_oauth, 182 scope.as_deref(), 183 crate::oauth::RepoAction::Update, 184 collection, 185 ) { 186 return e; 187 } 188 } 189 for collection in delete_collections { 190 if let Err(e) = crate::auth::scope_check::check_repo_scope( 191 is_oauth, 192 scope.as_deref(), 193 crate::oauth::RepoAction::Delete, 194 collection, 195 ) { 196 return e; 197 } 198 } 199 } 200 201 let user_id: uuid::Uuid = 202 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str()) 203 .fetch_optional(&state.db) 204 .await 205 { 206 Ok(Some(id)) => id, 207 _ => return ApiError::InternalError(Some("User not found".into())).into_response(), 208 }; 209 let root_cid_str: String = match sqlx::query_scalar!( 210 "SELECT repo_root_cid FROM repos WHERE user_id = $1", 211 user_id 212 ) 213 .fetch_optional(&state.db) 214 .await 215 { 216 Ok(Some(cid_str)) => cid_str, 217 _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(), 218 }; 219 let current_root_cid = match Cid::from_str(&root_cid_str) { 220 Ok(c) => c, 221 Err(_) => { 222 return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response(); 223 } 224 }; 225 if let Some(swap_commit) = &input.swap_commit 226 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 227 { 228 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 229 } 230 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 231 let commit_bytes = match tracking_store.get(&current_root_cid).await { 232 Ok(Some(b)) => b, 233 _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 234 }; 235 let commit = match Commit::from_cbor(&commit_bytes) { 236 Ok(c) => c, 237 _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 238 }; 239 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 240 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 241 let mut results: Vec<WriteResult> = Vec::new(); 242 let mut ops: Vec<RecordOp> = Vec::new(); 243 let mut modified_keys: Vec<String> = Vec::new(); 244 let mut all_blob_cids: Vec<String> = Vec::new(); 245 for write in &input.writes { 246 match write { 247 WriteOp::Create { 248 collection, 249 rkey, 250 value, 251 } => { 252 let validation_status = if input.validate == Some(false) { 253 None 254 } else { 255 let require_lexicon = input.validate == Some(true); 256 match validate_record_with_status( 257 value, 258 collection, 259 rkey.as_ref().map(|r| r.as_str()), 260 require_lexicon, 261 ) { 262 Ok(status) => Some(status), 263 Err(err_response) => return *err_response, 264 } 265 }; 266 all_blob_cids.extend(extract_blob_cids(value)); 267 let rkey = rkey.clone().unwrap_or_else(Rkey::generate); 268 let record_ipld = crate::util::json_to_ipld(value); 269 let mut record_bytes = Vec::new(); 270 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 271 return ApiError::InvalidRecord("Failed to serialize record".into()) 272 .into_response(); 273 } 274 let record_cid = match tracking_store.put(&record_bytes).await { 275 Ok(c) => c, 276 Err(_) => { 277 return ApiError::InternalError(Some("Failed to store record".into())) 278 .into_response(); 279 } 280 }; 281 let key = format!("{}/{}", collection, rkey); 282 modified_keys.push(key.clone()); 283 mst = match mst.add(&key, record_cid).await { 284 Ok(m) => m, 285 Err(_) => { 286 return ApiError::InternalError(Some("Failed to add to MST".into())) 287 .into_response(); 288 } 289 }; 290 let uri = AtUri::from_parts(&did, collection, &rkey); 291 results.push(WriteResult::CreateResult { 292 uri, 293 cid: record_cid.to_string(), 294 validation_status: validation_status.map(|s| s.to_string()), 295 }); 296 ops.push(RecordOp::Create { 297 collection: collection.to_string(), 298 rkey: rkey.to_string(), 299 cid: record_cid, 300 }); 301 } 302 WriteOp::Update { 303 collection, 304 rkey, 305 value, 306 } => { 307 let validation_status = if input.validate == Some(false) { 308 None 309 } else { 310 let require_lexicon = input.validate == Some(true); 311 match validate_record_with_status( 312 value, 313 collection, 314 Some(rkey.as_str()), 315 require_lexicon, 316 ) { 317 Ok(status) => Some(status), 318 Err(err_response) => return *err_response, 319 } 320 }; 321 all_blob_cids.extend(extract_blob_cids(value)); 322 let record_ipld = crate::util::json_to_ipld(value); 323 let mut record_bytes = Vec::new(); 324 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 325 return ApiError::InvalidRecord("Failed to serialize record".into()) 326 .into_response(); 327 } 328 let record_cid = match tracking_store.put(&record_bytes).await { 329 Ok(c) => c, 330 Err(_) => { 331 return ApiError::InternalError(Some("Failed to store record".into())) 332 .into_response(); 333 } 334 }; 335 let key = format!("{}/{}", collection, rkey); 336 modified_keys.push(key.clone()); 337 let prev_record_cid = mst.get(&key).await.ok().flatten(); 338 mst = match mst.update(&key, record_cid).await { 339 Ok(m) => m, 340 Err(_) => { 341 return ApiError::InternalError(Some("Failed to update MST".into())) 342 .into_response(); 343 } 344 }; 345 let uri = AtUri::from_parts(&did, collection, rkey); 346 results.push(WriteResult::UpdateResult { 347 uri, 348 cid: record_cid.to_string(), 349 validation_status: validation_status.map(|s| s.to_string()), 350 }); 351 ops.push(RecordOp::Update { 352 collection: collection.to_string(), 353 rkey: rkey.to_string(), 354 cid: record_cid, 355 prev: prev_record_cid, 356 }); 357 } 358 WriteOp::Delete { collection, rkey } => { 359 let key = format!("{}/{}", collection, rkey); 360 modified_keys.push(key.clone()); 361 let prev_record_cid = mst.get(&key).await.ok().flatten(); 362 mst = match mst.delete(&key).await { 363 Ok(m) => m, 364 Err(_) => { 365 return ApiError::InternalError(Some("Failed to delete from MST".into())) 366 .into_response(); 367 } 368 }; 369 results.push(WriteResult::DeleteResult {}); 370 ops.push(RecordOp::Delete { 371 collection: collection.to_string(), 372 rkey: rkey.to_string(), 373 prev: prev_record_cid, 374 }); 375 } 376 } 377 } 378 let new_mst_root = match mst.persist().await { 379 Ok(c) => c, 380 Err(_) => { 381 return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(); 382 } 383 }; 384 let mut new_mst_blocks = std::collections::BTreeMap::new(); 385 let mut old_mst_blocks = std::collections::BTreeMap::new(); 386 for key in &modified_keys { 387 if mst.blocks_for_path(key, &mut new_mst_blocks).await.is_err() { 388 return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 389 .into_response(); 390 } 391 if original_mst 392 .blocks_for_path(key, &mut old_mst_blocks) 393 .await 394 .is_err() 395 { 396 return ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 397 .into_response(); 398 } 399 } 400 let mut relevant_blocks = new_mst_blocks.clone(); 401 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 402 let written_cids: Vec<Cid> = tracking_store 403 .get_all_relevant_cids() 404 .into_iter() 405 .chain(relevant_blocks.keys().copied()) 406 .collect::<std::collections::HashSet<_>>() 407 .into_iter() 408 .collect(); 409 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 410 let prev_record_cids = ops.iter().filter_map(|op| match op { 411 RecordOp::Update { 412 prev: Some(cid), .. 413 } 414 | RecordOp::Delete { 415 prev: Some(cid), .. 416 } => Some(*cid), 417 _ => None, 418 }); 419 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 420 .chain( 421 old_mst_blocks 422 .keys() 423 .filter(|cid| !new_mst_blocks.contains_key(*cid)) 424 .copied(), 425 ) 426 .chain(prev_record_cids) 427 .collect::<std::collections::HashSet<_>>() 428 .into_iter() 429 .collect(); 430 let commit_res = match commit_and_log( 431 &state, 432 CommitParams { 433 did: &did, 434 user_id, 435 current_root_cid: Some(current_root_cid), 436 prev_data_cid: Some(commit.data), 437 new_mst_root, 438 ops, 439 blocks_cids: &written_cids_str, 440 blobs: &all_blob_cids, 441 obsolete_cids, 442 }, 443 ) 444 .await 445 { 446 Ok(res) => res, 447 Err(e) => { 448 error!("Commit failed: {}", e); 449 return ApiError::InternalError(Some("Failed to commit changes".into())) 450 .into_response(); 451 } 452 }; 453 454 if let Some(ref controller) = controller_did { 455 let write_summary: Vec<serde_json::Value> = input 456 .writes 457 .iter() 458 .map(|w| match w { 459 WriteOp::Create { 460 collection, rkey, .. 461 } => json!({ 462 "action": "create", 463 "collection": collection, 464 "rkey": rkey 465 }), 466 WriteOp::Update { 467 collection, rkey, .. 468 } => json!({ 469 "action": "update", 470 "collection": collection, 471 "rkey": rkey 472 }), 473 WriteOp::Delete { collection, rkey } => json!({ 474 "action": "delete", 475 "collection": collection, 476 "rkey": rkey 477 }), 478 }) 479 .collect(); 480 481 let _ = delegation::log_delegation_action( 482 &state.db, 483 &did, 484 controller, 485 Some(controller), 486 DelegationActionType::RepoWrite, 487 Some(json!({ 488 "action": "apply_writes", 489 "count": input.writes.len(), 490 "writes": write_summary 491 })), 492 None, 493 None, 494 ) 495 .await; 496 } 497 498 ( 499 StatusCode::OK, 500 Json(ApplyWritesOutput { 501 commit: CommitInfo { 502 cid: commit_res.commit_cid.to_string(), 503 rev: commit_res.rev, 504 }, 505 results, 506 }), 507 ) 508 .into_response() 509}