// This repo has no description.
1use super::validation::validate_record; 2use super::write::has_verified_notification_channel; 3use crate::api::repo::record::utils::{commit_and_log, RecordOp}; 4use crate::repo::tracking::TrackingBlockStore; 5use crate::state::AppState; 6use axum::{ 7 extract::State, 8 http::StatusCode, 9 response::{IntoResponse, Response}, 10 Json, 11}; 12use cid::Cid; 13use jacquard::types::{integer::LimitedU32, string::{Nsid, Tid}}; 14use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 15use serde::{Deserialize, Serialize}; 16use serde_json::json; 17use std::str::FromStr; 18use std::sync::Arc; 19use tracing::error; 20const MAX_BATCH_WRITES: usize = 200; 21#[derive(Deserialize)] 22#[serde(tag = "$type")] 23pub enum WriteOp { 24 #[serde(rename = "com.atproto.repo.applyWrites#create")] 25 Create { 26 collection: String, 27 rkey: Option<String>, 28 value: serde_json::Value, 29 }, 30 #[serde(rename = "com.atproto.repo.applyWrites#update")] 31 Update { 32 collection: String, 33 rkey: String, 34 value: serde_json::Value, 35 }, 36 #[serde(rename = "com.atproto.repo.applyWrites#delete")] 37 Delete { collection: String, rkey: String }, 38} 39#[derive(Deserialize)] 40#[serde(rename_all = "camelCase")] 41pub struct ApplyWritesInput { 42 pub repo: String, 43 pub validate: Option<bool>, 44 pub writes: Vec<WriteOp>, 45 pub swap_commit: Option<String>, 46} 47#[derive(Serialize)] 48#[serde(tag = "$type")] 49pub enum WriteResult { 50 #[serde(rename = "com.atproto.repo.applyWrites#createResult")] 51 CreateResult { uri: String, cid: String }, 52 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")] 53 UpdateResult { uri: String, cid: String }, 54 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")] 55 DeleteResult {}, 56} 57#[derive(Serialize)] 58pub struct ApplyWritesOutput { 59 pub commit: CommitInfo, 60 pub results: Vec<WriteResult>, 61} 62#[derive(Serialize)] 63pub struct CommitInfo { 64 pub cid: String, 65 pub rev: String, 66} 67pub async fn apply_writes( 68 
State(state): State<AppState>, 69 headers: axum::http::HeaderMap, 70 Json(input): Json<ApplyWritesInput>, 71) -> Response { 72 let token = match crate::auth::extract_bearer_token_from_header( 73 headers.get("Authorization").and_then(|h| h.to_str().ok()) 74 ) { 75 Some(t) => t, 76 None => { 77 return ( 78 StatusCode::UNAUTHORIZED, 79 Json(json!({"error": "AuthenticationRequired"})), 80 ) 81 .into_response(); 82 } 83 }; 84 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await { 85 Ok(user) => user, 86 Err(_) => { 87 return ( 88 StatusCode::UNAUTHORIZED, 89 Json(json!({"error": "AuthenticationFailed"})), 90 ) 91 .into_response(); 92 } 93 }; 94 let did = auth_user.did; 95 if input.repo != did { 96 return ( 97 StatusCode::FORBIDDEN, 98 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})), 99 ) 100 .into_response(); 101 } 102 match has_verified_notification_channel(&state.db, &did).await { 103 Ok(true) => {} 104 Ok(false) => { 105 return ( 106 StatusCode::FORBIDDEN, 107 Json(json!({ 108 "error": "AccountNotVerified", 109 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records" 110 })), 111 ) 112 .into_response(); 113 } 114 Err(e) => { 115 error!("DB error checking notification channels: {}", e); 116 return ( 117 StatusCode::INTERNAL_SERVER_ERROR, 118 Json(json!({"error": "InternalError"})), 119 ) 120 .into_response(); 121 } 122 } 123 if input.writes.is_empty() { 124 return ( 125 StatusCode::BAD_REQUEST, 126 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})), 127 ) 128 .into_response(); 129 } 130 if input.writes.len() > MAX_BATCH_WRITES { 131 return ( 132 StatusCode::BAD_REQUEST, 133 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})), 134 ) 135 .into_response(); 136 } 137 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = 
$1", did) 138 .fetch_optional(&state.db) 139 .await 140 { 141 Ok(Some(id)) => id, 142 _ => { 143 return ( 144 StatusCode::INTERNAL_SERVER_ERROR, 145 Json(json!({"error": "InternalError", "message": "User not found"})), 146 ) 147 .into_response(); 148 } 149 }; 150 let root_cid_str: String = 151 match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id) 152 .fetch_optional(&state.db) 153 .await 154 { 155 Ok(Some(cid_str)) => cid_str, 156 _ => { 157 return ( 158 StatusCode::INTERNAL_SERVER_ERROR, 159 Json(json!({"error": "InternalError", "message": "Repo root not found"})), 160 ) 161 .into_response(); 162 } 163 }; 164 let current_root_cid = match Cid::from_str(&root_cid_str) { 165 Ok(c) => c, 166 Err(_) => { 167 return ( 168 StatusCode::INTERNAL_SERVER_ERROR, 169 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})), 170 ) 171 .into_response(); 172 } 173 }; 174 if let Some(swap_commit) = &input.swap_commit { 175 if Cid::from_str(swap_commit).ok() != Some(current_root_cid) { 176 return ( 177 StatusCode::CONFLICT, 178 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})), 179 ) 180 .into_response(); 181 } 182 } 183 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 184 let commit_bytes = match tracking_store.get(&current_root_cid).await { 185 Ok(Some(b)) => b, 186 _ => { 187 return ( 188 StatusCode::INTERNAL_SERVER_ERROR, 189 Json(json!({"error": "InternalError", "message": "Commit block not found"})), 190 ) 191 .into_response() 192 } 193 }; 194 let commit = match Commit::from_cbor(&commit_bytes) { 195 Ok(c) => c, 196 _ => { 197 return ( 198 StatusCode::INTERNAL_SERVER_ERROR, 199 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})), 200 ) 201 .into_response() 202 } 203 }; 204 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 205 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 206 let mut 
results: Vec<WriteResult> = Vec::new(); 207 let mut ops: Vec<RecordOp> = Vec::new(); 208 let mut modified_keys: Vec<String> = Vec::new(); 209 for write in &input.writes { 210 match write { 211 WriteOp::Create { 212 collection, 213 rkey, 214 value, 215 } => { 216 if input.validate.unwrap_or(true) { 217 if let Err(err_response) = validate_record(value, collection) { 218 return err_response; 219 } 220 } 221 let rkey = rkey 222 .clone() 223 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string()); 224 let mut record_bytes = Vec::new(); 225 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 226 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 227 } 228 let record_cid = match tracking_store.put(&record_bytes).await { 229 Ok(c) => c, 230 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(), 231 }; 232 let collection_nsid = match collection.parse::<Nsid>() { 233 Ok(n) => n, 234 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 235 }; 236 let key = format!("{}/{}", collection_nsid, rkey); 237 modified_keys.push(key.clone()); 238 mst = match mst.add(&key, record_cid).await { 239 Ok(m) => m, 240 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(), 241 }; 242 let uri = format!("at://{}/{}/{}", did, collection, rkey); 243 results.push(WriteResult::CreateResult { 244 uri, 245 cid: record_cid.to_string(), 246 }); 247 ops.push(RecordOp::Create { 248 collection: collection.clone(), 249 rkey, 250 cid: record_cid, 251 }); 252 } 253 WriteOp::Update { 254 collection, 255 rkey, 256 value, 257 } => { 258 if input.validate.unwrap_or(true) { 259 if let Err(err_response) = validate_record(value, collection) { 260 
return err_response; 261 } 262 } 263 let mut record_bytes = Vec::new(); 264 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() { 265 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response(); 266 } 267 let record_cid = match tracking_store.put(&record_bytes).await { 268 Ok(c) => c, 269 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(), 270 }; 271 let collection_nsid = match collection.parse::<Nsid>() { 272 Ok(n) => n, 273 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 274 }; 275 let key = format!("{}/{}", collection_nsid, rkey); 276 modified_keys.push(key.clone()); 277 let prev_record_cid = mst.get(&key).await.ok().flatten(); 278 mst = match mst.update(&key, record_cid).await { 279 Ok(m) => m, 280 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(), 281 }; 282 let uri = format!("at://{}/{}/{}", did, collection, rkey); 283 results.push(WriteResult::UpdateResult { 284 uri, 285 cid: record_cid.to_string(), 286 }); 287 ops.push(RecordOp::Update { 288 collection: collection.clone(), 289 rkey: rkey.clone(), 290 cid: record_cid, 291 prev: prev_record_cid, 292 }); 293 } 294 WriteOp::Delete { collection, rkey } => { 295 let collection_nsid = match collection.parse::<Nsid>() { 296 Ok(n) => n, 297 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(), 298 }; 299 let key = format!("{}/{}", collection_nsid, rkey); 300 modified_keys.push(key.clone()); 301 let prev_record_cid = mst.get(&key).await.ok().flatten(); 302 mst = match mst.delete(&key).await { 303 Ok(m) => m, 304 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, 
Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(), 305 }; 306 results.push(WriteResult::DeleteResult {}); 307 ops.push(RecordOp::Delete { 308 collection: collection.clone(), 309 rkey: rkey.clone(), 310 prev: prev_record_cid, 311 }); 312 } 313 } 314 } 315 let new_mst_root = match mst.persist().await { 316 Ok(c) => c, 317 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(), 318 }; 319 let mut relevant_blocks = std::collections::BTreeMap::new(); 320 for key in &modified_keys { 321 if let Err(_) = mst.blocks_for_path(key, &mut relevant_blocks).await { 322 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response(); 323 } 324 if let Err(_) = original_mst.blocks_for_path(key, &mut relevant_blocks).await { 325 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response(); 326 } 327 } 328 let mut written_cids = tracking_store.get_all_relevant_cids(); 329 for cid in relevant_blocks.keys() { 330 if !written_cids.contains(cid) { 331 written_cids.push(*cid); 332 } 333 } 334 let written_cids_str = written_cids 335 .iter() 336 .map(|c| c.to_string()) 337 .collect::<Vec<_>>(); 338 let commit_res = match commit_and_log( 339 &state, 340 &did, 341 user_id, 342 Some(current_root_cid), 343 Some(commit.data), 344 new_mst_root, 345 ops, 346 &written_cids_str, 347 ) 348 .await 349 { 350 Ok(res) => res, 351 Err(e) => { 352 error!("Commit failed: {}", e); 353 return ( 354 StatusCode::INTERNAL_SERVER_ERROR, 355 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})), 356 ) 357 .into_response(); 358 } 359 }; 360 ( 361 StatusCode::OK, 362 Json(ApplyWritesOutput { 363 commit: CommitInfo { 364 cid: commit_res.commit_cid.to_string(), 365 rev: 
commit_res.rev, 366 }, 367 results, 368 }), 369 ) 370 .into_response() 371}