// this repo has no description
1use super::validation::validate_record; 2use super::write::has_verified_comms_channel; 3use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 4use crate::repo::tracking::TrackingBlockStore; 5use crate::state::AppState; 6use axum::{ 7 Json, 8 extract::State, 9 http::StatusCode, 10 response::{IntoResponse, Response}, 11}; 12use cid::Cid; 13use jacquard::types::{ 14 integer::LimitedU32, 15 string::{Nsid, Tid}, 16}; 17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 18use serde::{Deserialize, Serialize}; 19use serde_json::json; 20use std::str::FromStr; 21use std::sync::Arc; 22use tracing::error; 23 24const MAX_BATCH_WRITES: usize = 200; 25 26#[derive(Deserialize)] 27#[serde(tag = "$type")] 28pub enum WriteOp { 29 #[serde(rename = "com.atproto.repo.applyWrites#create")] 30 Create { 31 collection: String, 32 rkey: Option<String>, 33 value: serde_json::Value, 34 }, 35 #[serde(rename = "com.atproto.repo.applyWrites#update")] 36 Update { 37 collection: String, 38 rkey: String, 39 value: serde_json::Value, 40 }, 41 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 42 Delete { collection: String, rkey: String }, 43} 44 45#[derive(Deserialize)] 46#[serde(rename_all = "camelCase")] 47pub struct ApplyWritesInput { 48 pub repo: String, 49 pub validate: Option<bool>, 50 pub writes: Vec<WriteOp>, 51 pub swap_commit: Option<String>, 52} 53 54#[derive(Serialize)] 55#[serde(tag = "$type")] 56pub enum WriteResult { 57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")] 58 CreateResult { uri: String, cid: String }, 59 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 60 UpdateResult { uri: String, cid: String }, 61 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 62 DeleteResult {}, 63} 64 65#[derive(Serialize)] 66pub struct ApplyWritesOutput { 67 pub commit: CommitInfo, 68 pub results: Vec<WriteResult>, 69} 70 71#[derive(Serialize)] 72pub struct CommitInfo { 73 pub cid: String, 74 pub rev: String, 
75} 76 77pub async fn apply_writes( 78 State(state): State<AppState>, 79 headers: axum::http::HeaderMap, 80 Json(input): Json<ApplyWritesInput>, 81) -> Response { 82 let token = match crate::auth::extract_bearer_token_from_header( 83 headers.get("Authorization").and_then(|h| h.to_str().ok()), 84 ) { 85 Some(t) => t, 86 None => { 87 return ( 88 StatusCode::UNAUTHORIZED, 89 Json(json!({"error": "AuthenticationRequired"})), 90 ) 91 .into_response(); 92 } 93 }; 94 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 95 Ok(user) => user, 96 Err(_) => { 97 return ( 98 StatusCode::UNAUTHORIZED, 99 Json(json!({"error": "AuthenticationFailed"})), 100 ) 101 .into_response(); 102 } 103 }; 104 let did = auth_user.did.clone(); 105 let is_oauth = auth_user.is_oauth; 106 let scope = auth_user.scope; 107 if input.repo != did { 108 return ( 109 StatusCode::FORBIDDEN, 110 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})), 111 ) 112 .into_response(); 113 } 114 match has_verified_comms_channel(&state.db, &did).await { 115 Ok(true) => {} 116 Ok(false) => { 117 return ( 118 StatusCode::FORBIDDEN, 119 Json(json!({ 120 "error": "AccountNotVerified", 121 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records" 122 })), 123 ) 124 .into_response(); 125 } 126 Err(e) => { 127 error!("DB error checking notification channels: {}", e); 128 return ( 129 StatusCode::INTERNAL_SERVER_ERROR, 130 Json(json!({"error": "InternalError"})), 131 ) 132 .into_response(); 133 } 134 } 135 if input.writes.is_empty() { 136 return ( 137 StatusCode::BAD_REQUEST, 138 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})), 139 ) 140 .into_response(); 141 } 142 if input.writes.len() > MAX_BATCH_WRITES { 143 return ( 144 StatusCode::BAD_REQUEST, 145 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})), 146 
) 147 .into_response(); 148 } 149 150 if is_oauth { 151 use std::collections::HashSet; 152 let create_collections: HashSet<&str> = input 153 .writes 154 .iter() 155 .filter_map(|w| { 156 if let WriteOp::Create { collection, .. } = w { 157 Some(collection.as_str()) 158 } else { 159 None 160 } 161 }) 162 .collect(); 163 let update_collections: HashSet<&str> = input 164 .writes 165 .iter() 166 .filter_map(|w| { 167 if let WriteOp::Update { collection, .. } = w { 168 Some(collection.as_str()) 169 } else { 170 None 171 } 172 }) 173 .collect(); 174 let delete_collections: HashSet<&str> = input 175 .writes 176 .iter() 177 .filter_map(|w| { 178 if let WriteOp::Delete { collection, .. } = w { 179 Some(collection.as_str()) 180 } else { 181 None 182 } 183 }) 184 .collect(); 185 186 for collection in create_collections { 187 if let Err(e) = crate::auth::scope_check::check_repo_scope( 188 is_oauth, 189 scope.as_deref(), 190 crate::oauth::RepoAction::Create, 191 collection, 192 ) { 193 return e; 194 } 195 } 196 for collection in update_collections { 197 if let Err(e) = crate::auth::scope_check::check_repo_scope( 198 is_oauth, 199 scope.as_deref(), 200 crate::oauth::RepoAction::Update, 201 collection, 202 ) { 203 return e; 204 } 205 } 206 for collection in delete_collections { 207 if let Err(e) = crate::auth::scope_check::check_repo_scope( 208 is_oauth, 209 scope.as_deref(), 210 crate::oauth::RepoAction::Delete, 211 collection, 212 ) { 213 return e; 214 } 215 } 216 } 217 218 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 219 .fetch_optional(&state.db) 220 .await 221 { 222 Ok(Some(id)) => id, 223 _ => { 224 return ( 225 StatusCode::INTERNAL_SERVER_ERROR, 226 Json(json!({"error": "InternalError", "message": "User not found"})), 227 ) 228 .into_response(); 229 } 230 }; 231 let root_cid_str: String = match sqlx::query_scalar!( 232 "SELECT repo_root_cid FROM repos WHERE user_id = $1", 233 user_id 234 ) 235 .fetch_optional(&state.db) 236 
.await 237 { 238 Ok(Some(cid_str)) => cid_str, 239 _ => { 240 return ( 241 StatusCode::INTERNAL_SERVER_ERROR, 242 Json(json!({"error": "InternalError", "message": "Repo root not found"})), 243 ) 244 .into_response(); 245 } 246 }; 247 let current_root_cid = match Cid::from_str(&root_cid_str) { 248 Ok(c) => c, 249 Err(_) => { 250 return ( 251 StatusCode::INTERNAL_SERVER_ERROR, 252 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})), 253 ) 254 .into_response(); 255 } 256 }; 257 if let Some(swap_commit) = &input.swap_commit 258 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) 259 { 260 return ( 261 StatusCode::CONFLICT, 262 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 263 ) 264 .into_response(); 265 } 266 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 267 let commit_bytes = match tracking_store.get(&current_root_cid).await { 268 Ok(Some(b)) => b, 269 _ => { 270 return ( 271 StatusCode::INTERNAL_SERVER_ERROR, 272 Json(json!({"error": "InternalError", "message": "Commit block not found"})), 273 ) 274 .into_response(); 275 } 276 }; 277 let commit = match Commit::from_cbor(&commit_bytes) { 278 Ok(c) => c, 279 _ => { 280 return ( 281 StatusCode::INTERNAL_SERVER_ERROR, 282 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 283 ) 284 .into_response(); 285 } 286 }; 287 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 288 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 289 let mut results: Vec<WriteResult> = Vec::new(); 290 let mut ops: Vec<RecordOp> = Vec::new(); 291 let mut modified_keys: Vec<String> = Vec::new(); 292 for write in &input.writes { 293 match write { 294 WriteOp::Create { 295 collection, 296 rkey, 297 value, 298 } => { 299 if input.validate.unwrap_or(true) 300 && let Err(err_response) = validate_record(value, collection) 301 { 302 return *err_response; 303 } 304 let rkey = rkey 305 
.clone() 306 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); 307 let mut record_bytes = Vec::new(); 308 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 309 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 310 } 311 let record_cid = match tracking_store.put(&record_bytes).await { 312 Ok(c) => c, 313 Err(_) => return ( 314 StatusCode::INTERNAL_SERVER_ERROR, 315 Json( 316 json!({"error": "InternalError", "message": "Failed to store record"}), 317 ), 318 ) 319 .into_response(), 320 }; 321 let collection_nsid = match collection.parse::<Nsid>() { 322 Ok(n) => n, 323 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 324 }; 325 let key = format!("{}/{}", collection_nsid, rkey); 326 modified_keys.push(key.clone()); 327 mst = match mst.add(&key, record_cid).await { 328 Ok(m) => m, 329 Err(_) => return ( 330 StatusCode::INTERNAL_SERVER_ERROR, 331 Json(json!({"error": "InternalError", "message": "Failed to add to MST"})), 332 ) 333 .into_response(), 334 }; 335 let uri = format!("at://{}/{}/{}", did, collection, rkey); 336 results.push(WriteResult::CreateResult { 337 uri, 338 cid: record_cid.to_string(), 339 }); 340 ops.push(RecordOp::Create { 341 collection: collection.clone(), 342 rkey, 343 cid: record_cid, 344 }); 345 } 346 WriteOp::Update { 347 collection, 348 rkey, 349 value, 350 } => { 351 if input.validate.unwrap_or(true) 352 && let Err(err_response) = validate_record(value, collection) 353 { 354 return *err_response; 355 } 356 let mut record_bytes = Vec::new(); 357 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 358 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 359 } 360 let record_cid = match tracking_store.put(&record_bytes).await { 361 Ok(c) => c, 362 
Err(_) => return ( 363 StatusCode::INTERNAL_SERVER_ERROR, 364 Json( 365 json!({"error": "InternalError", "message": "Failed to store record"}), 366 ), 367 ) 368 .into_response(), 369 }; 370 let collection_nsid = match collection.parse::<Nsid>() { 371 Ok(n) => n, 372 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 373 }; 374 let key = format!("{}/{}", collection_nsid, rkey); 375 modified_keys.push(key.clone()); 376 let prev_record_cid = mst.get(&key).await.ok().flatten(); 377 mst = match mst.update(&key, record_cid).await { 378 Ok(m) => m, 379 Err(_) => return ( 380 StatusCode::INTERNAL_SERVER_ERROR, 381 Json(json!({"error": "InternalError", "message": "Failed to update MST"})), 382 ) 383 .into_response(), 384 }; 385 let uri = format!("at://{}/{}/{}", did, collection, rkey); 386 results.push(WriteResult::UpdateResult { 387 uri, 388 cid: record_cid.to_string(), 389 }); 390 ops.push(RecordOp::Update { 391 collection: collection.clone(), 392 rkey: rkey.clone(), 393 cid: record_cid, 394 prev: prev_record_cid, 395 }); 396 } 397 WriteOp::Delete { collection, rkey } => { 398 let collection_nsid = match collection.parse::<Nsid>() { 399 Ok(n) => n, 400 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 401 }; 402 let key = format!("{}/{}", collection_nsid, rkey); 403 modified_keys.push(key.clone()); 404 let prev_record_cid = mst.get(&key).await.ok().flatten(); 405 mst = match mst.delete(&key).await { 406 Ok(m) => m, 407 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(), 408 }; 409 results.push(WriteResult::DeleteResult {}); 410 ops.push(RecordOp::Delete { 411 collection: collection.clone(), 412 rkey: rkey.clone(), 413 prev: prev_record_cid, 414 }); 415 } 416 } 417 } 418 let new_mst_root = 
match mst.persist().await { 419 Ok(c) => c, 420 Err(_) => { 421 return ( 422 StatusCode::INTERNAL_SERVER_ERROR, 423 Json(json!({"error": "InternalError", "message": "Failed to persist MST"})), 424 ) 425 .into_response(); 426 } 427 }; 428 let mut relevant_blocks = std::collections::BTreeMap::new(); 429 for key in &modified_keys { 430 if mst 431 .blocks_for_path(key, &mut relevant_blocks) 432 .await 433 .is_err() 434 { 435 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 436 } 437 if original_mst 438 .blocks_for_path(key, &mut relevant_blocks) 439 .await 440 .is_err() 441 { 442 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 443 } 444 } 445 let mut written_cids = tracking_store.get_all_relevant_cids(); 446 for cid in relevant_blocks.keys() { 447 if !written_cids.contains(cid) { 448 written_cids.push(*cid); 449 } 450 } 451 let written_cids_str = written_cids 452 .iter() 453 .map(|c| c.to_string()) 454 .collect::<Vec<_>>(); 455 let commit_res = match commit_and_log( 456 &state, 457 CommitParams { 458 did: &did, 459 user_id, 460 current_root_cid: Some(current_root_cid), 461 prev_data_cid: Some(commit.data), 462 new_mst_root, 463 ops, 464 blocks_cids: &written_cids_str, 465 }, 466 ) 467 .await 468 { 469 Ok(res) => res, 470 Err(e) => { 471 error!("Commit failed: {}", e); 472 return ( 473 StatusCode::INTERNAL_SERVER_ERROR, 474 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})), 475 ) 476 .into_response(); 477 } 478 }; 479 ( 480 StatusCode::OK, 481 Json(ApplyWritesOutput { 482 commit: CommitInfo { 483 cid: commit_res.commit_cid.to_string(), 484 rev: commit_res.rev, 485 }, 486 results, 487 }), 488 ) 489 .into_response() 490}