// This repository has no description.
use super::validation::validate_record_with_status;
use super::write::has_verified_comms_channel;
use crate::api::error::ApiError;
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids};
use crate::auth::BearerAuth;
use crate::delegation::{self, DelegationActionType};
use crate::repo::tracking::TrackingBlockStore;
use crate::state::AppState;
use crate::types::{AtIdentifier, AtUri, Nsid, Rkey};
use axum::{
    Json,
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
};
use cid::Cid;
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::{BTreeMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use tracing::{error, info, warn};

/// Largest number of operations accepted in one `applyWrites` request.
///
/// Batches with more entries are rejected with `InvalidRequest` before any repository
/// state is read.
const MAX_BATCH_WRITES: usize = 200;

26#[derive(Deserialize)]
27#[serde(tag = "$type")]
28pub enum WriteOp {
29 #[serde(rename = "com.atproto.repo.applyWrites#create")]
30 Create {
31 collection: Nsid,
32 rkey: Option<Rkey>,
33 value: serde_json::Value,
34 },
35 #[serde(rename = "com.atproto.repo.applyWrites#update")]
36 Update {
37 collection: Nsid,
38 rkey: Rkey,
39 value: serde_json::Value,
40 },
41 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
42 Delete { collection: Nsid, rkey: Rkey },
43}
44
45#[derive(Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct ApplyWritesInput {
48 pub repo: AtIdentifier,
49 pub validate: Option<bool>,
50 pub writes: Vec<WriteOp>,
51 pub swap_commit: Option<String>,
52}
53
54#[derive(Serialize)]
55#[serde(tag = "$type")]
56pub enum WriteResult {
57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
58 CreateResult {
59 uri: AtUri,
60 cid: String,
61 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")]
62 validation_status: Option<String>,
63 },
64 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
65 UpdateResult {
66 uri: AtUri,
67 cid: String,
68 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")]
69 validation_status: Option<String>,
70 },
71 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
72 DeleteResult {},
73}
74
75#[derive(Serialize)]
76pub struct ApplyWritesOutput {
77 pub commit: CommitInfo,
78 pub results: Vec<WriteResult>,
79}
80
81#[derive(Serialize)]
82pub struct CommitInfo {
83 pub cid: String,
84 pub rev: String,
85}
86
87pub async fn apply_writes(
88 State(state): State<AppState>,
89 auth: BearerAuth,
90 Json(input): Json<ApplyWritesInput>,
91) -> Response {
92 info!(
93 "apply_writes called: repo={}, writes={}",
94 input.repo,
95 input.writes.len()
96 );
97 let auth_user = auth.0;
98 let did = auth_user.did.clone();
99 let is_oauth = auth_user.is_oauth;
100 let scope = auth_user.scope;
101 let controller_did = auth_user.controller_did.clone();
102 if input.repo.as_str() != did {
103 return ApiError::InvalidRepo("Repo does not match authenticated user".into())
104 .into_response();
105 }
106 if crate::util::is_account_migrated(&state.db, &did)
107 .await
108 .unwrap_or(false)
109 {
110 return ApiError::AccountMigrated.into_response();
111 }
112 let is_verified = has_verified_comms_channel(&state.db, &did)
113 .await
114 .unwrap_or(false);
115 let is_delegated = crate::delegation::is_delegated_account(&state.db, &did)
116 .await
117 .unwrap_or(false);
118 if !is_verified && !is_delegated {
119 return ApiError::AccountNotVerified.into_response();
120 }
121 if input.writes.is_empty() {
122 return ApiError::InvalidRequest("writes array is empty".into()).into_response();
123 }
124 if input.writes.len() > MAX_BATCH_WRITES {
125 return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES))
126 .into_response();
127 }
128
129 let has_custom_scope = scope
130 .as_ref()
131 .map(|s| s != "com.atproto.access")
132 .unwrap_or(false);
133 if is_oauth || has_custom_scope {
134 use std::collections::HashSet;
135 let create_collections: HashSet<&Nsid> = input
136 .writes
137 .iter()
138 .filter_map(|w| {
139 if let WriteOp::Create { collection, .. } = w {
140 Some(collection)
141 } else {
142 None
143 }
144 })
145 .collect();
146 let update_collections: HashSet<&Nsid> = input
147 .writes
148 .iter()
149 .filter_map(|w| {
150 if let WriteOp::Update { collection, .. } = w {
151 Some(collection)
152 } else {
153 None
154 }
155 })
156 .collect();
157 let delete_collections: HashSet<&Nsid> = input
158 .writes
159 .iter()
160 .filter_map(|w| {
161 if let WriteOp::Delete { collection, .. } = w {
162 Some(collection)
163 } else {
164 None
165 }
166 })
167 .collect();
168
169 for collection in create_collections {
170 if let Err(e) = crate::auth::scope_check::check_repo_scope(
171 is_oauth,
172 scope.as_deref(),
173 crate::oauth::RepoAction::Create,
174 collection,
175 ) {
176 return e;
177 }
178 }
179 for collection in update_collections {
180 if let Err(e) = crate::auth::scope_check::check_repo_scope(
181 is_oauth,
182 scope.as_deref(),
183 crate::oauth::RepoAction::Update,
184 collection,
185 ) {
186 return e;
187 }
188 }
189 for collection in delete_collections {
190 if let Err(e) = crate::auth::scope_check::check_repo_scope(
191 is_oauth,
192 scope.as_deref(),
193 crate::oauth::RepoAction::Delete,
194 collection,
195 ) {
196 return e;
197 }
198 }
199 }
200
201 let user_id: uuid::Uuid =
202 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str())
203 .fetch_optional(&state.db)
204 .await
205 {
206 Ok(Some(id)) => id,
207 _ => return ApiError::InternalError(Some("User not found".into())).into_response(),
208 };
209 let root_cid_str: String = match sqlx::query_scalar!(
210 "SELECT repo_root_cid FROM repos WHERE user_id = $1",
211 user_id
212 )
213 .fetch_optional(&state.db)
214 .await
215 {
216 Ok(Some(cid_str)) => cid_str,
217 _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(),
218 };
219 let current_root_cid = match Cid::from_str(&root_cid_str) {
220 Ok(c) => c,
221 Err(_) => {
222 return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response();
223 }
224 };
225 if let Some(swap_commit) = &input.swap_commit
226 && Cid::from_str(swap_commit).ok() != Some(current_root_cid)
227 {
228 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response();
229 }
230 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
231 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
232 Ok(Some(b)) => b,
233 _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(),
234 };
235 let commit = match Commit::from_cbor(&commit_bytes) {
236 Ok(c) => c,
237 _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(),
238 };
239 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
240 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
241 let mut results: Vec<WriteResult> = Vec::new();
242 let mut ops: Vec<RecordOp> = Vec::new();
243 let mut modified_keys: Vec<String> = Vec::new();
244 let mut all_blob_cids: Vec<String> = Vec::new();
245 for write in &input.writes {
246 match write {
247 WriteOp::Create {
248 collection,
249 rkey,
250 value,
251 } => {
252 let validation_status = if input.validate == Some(false) {
253 None
254 } else {
255 let require_lexicon = input.validate == Some(true);
256 match validate_record_with_status(
257 value,
258 collection,
259 rkey.as_ref().map(|r| r.as_str()),
260 require_lexicon,
261 ) {
262 Ok(status) => Some(status),
263 Err(err_response) => return *err_response,
264 }
265 };
266 all_blob_cids.extend(extract_blob_cids(value));
267 let rkey = rkey.clone().unwrap_or_else(Rkey::generate);
268 let record_ipld = crate::util::json_to_ipld(value);
269 let mut record_bytes = Vec::new();
270 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() {
271 return ApiError::InvalidRecord("Failed to serialize record".into())
272 .into_response();
273 }
274 let record_cid = match tracking_store.put(&record_bytes).await {
275 Ok(c) => c,
276 Err(_) => {
277 return ApiError::InternalError(Some("Failed to store record".into()))
278 .into_response();
279 }
280 };
281 let key = format!("{}/{}", collection, rkey);
282 modified_keys.push(key.clone());
283 mst = match mst.add(&key, record_cid).await {
284 Ok(m) => m,
285 Err(_) => {
286 return ApiError::InternalError(Some("Failed to add to MST".into()))
287 .into_response();
288 }
289 };
290 let uri = AtUri::from_parts(&did, collection, &rkey);
291 results.push(WriteResult::CreateResult {
292 uri,
293 cid: record_cid.to_string(),
294 validation_status: validation_status.map(|s| s.to_string()),
295 });
296 ops.push(RecordOp::Create {
297 collection: collection.to_string(),
298 rkey: rkey.to_string(),
299 cid: record_cid,
300 });
301 }
302 WriteOp::Update {
303 collection,
304 rkey,
305 value,
306 } => {
307 let validation_status = if input.validate == Some(false) {
308 None
309 } else {
310 let require_lexicon = input.validate == Some(true);
311 match validate_record_with_status(
312 value,
313 collection,
314 Some(rkey.as_str()),
315 require_lexicon,
316 ) {
317 Ok(status) => Some(status),
318 Err(err_response) => return *err_response,
319 }
320 };
321 all_blob_cids.extend(extract_blob_cids(value));
322 let record_ipld = crate::util::json_to_ipld(value);
323 let mut record_bytes = Vec::new();
324 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() {
325 return ApiError::InvalidRecord("Failed to serialize record".into())
326 .into_response();
327 }
328 let record_cid = match tracking_store.put(&record_bytes).await {
329 Ok(c) => c,
330 Err(_) => {
331 return ApiError::InternalError(Some("Failed to store record".into()))
332 .into_response();
333 }
334 };
335 let key = format!("{}/{}", collection, rkey);
336 modified_keys.push(key.clone());
337 let prev_record_cid = mst.get(&key).await.ok().flatten();
338 mst = match mst.update(&key, record_cid).await {
339 Ok(m) => m,
340 Err(_) => {
341 return ApiError::InternalError(Some("Failed to update MST".into()))
342 .into_response();
343 }
344 };
345 let uri = AtUri::from_parts(&did, collection, rkey);
346 results.push(WriteResult::UpdateResult {
347 uri,
348 cid: record_cid.to_string(),
349 validation_status: validation_status.map(|s| s.to_string()),
350 });
351 ops.push(RecordOp::Update {
352 collection: collection.to_string(),
353 rkey: rkey.to_string(),
354 cid: record_cid,
355 prev: prev_record_cid,
356 });
357 }
358 WriteOp::Delete { collection, rkey } => {
359 let key = format!("{}/{}", collection, rkey);
360 modified_keys.push(key.clone());
361 let prev_record_cid = mst.get(&key).await.ok().flatten();
362 mst = match mst.delete(&key).await {
363 Ok(m) => m,
364 Err(_) => {
365 return ApiError::InternalError(Some("Failed to delete from MST".into()))
366 .into_response();
367 }
368 };
369 results.push(WriteResult::DeleteResult {});
370 ops.push(RecordOp::Delete {
371 collection: collection.to_string(),
372 rkey: rkey.to_string(),
373 prev: prev_record_cid,
374 });
375 }
376 }
377 }
378 let new_mst_root = match mst.persist().await {
379 Ok(c) => c,
380 Err(_) => {
381 return ApiError::InternalError(Some("Failed to persist MST".into())).into_response();
382 }
383 };
384 let mut new_mst_blocks = std::collections::BTreeMap::new();
385 let mut old_mst_blocks = std::collections::BTreeMap::new();
386 for key in &modified_keys {
387 if mst.blocks_for_path(key, &mut new_mst_blocks).await.is_err() {
388 return ApiError::InternalError(Some("Failed to get new MST blocks for path".into()))
389 .into_response();
390 }
391 if original_mst
392 .blocks_for_path(key, &mut old_mst_blocks)
393 .await
394 .is_err()
395 {
396 return ApiError::InternalError(Some("Failed to get old MST blocks for path".into()))
397 .into_response();
398 }
399 }
400 let mut relevant_blocks = new_mst_blocks.clone();
401 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone())));
402 let written_cids: Vec<Cid> = tracking_store
403 .get_all_relevant_cids()
404 .into_iter()
405 .chain(relevant_blocks.keys().copied())
406 .collect::<std::collections::HashSet<_>>()
407 .into_iter()
408 .collect();
409 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
410 let prev_record_cids = ops.iter().filter_map(|op| match op {
411 RecordOp::Update {
412 prev: Some(cid), ..
413 }
414 | RecordOp::Delete {
415 prev: Some(cid), ..
416 } => Some(*cid),
417 _ => None,
418 });
419 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid)
420 .chain(
421 old_mst_blocks
422 .keys()
423 .filter(|cid| !new_mst_blocks.contains_key(*cid))
424 .copied(),
425 )
426 .chain(prev_record_cids)
427 .collect::<std::collections::HashSet<_>>()
428 .into_iter()
429 .collect();
430 let commit_res = match commit_and_log(
431 &state,
432 CommitParams {
433 did: &did,
434 user_id,
435 current_root_cid: Some(current_root_cid),
436 prev_data_cid: Some(commit.data),
437 new_mst_root,
438 ops,
439 blocks_cids: &written_cids_str,
440 blobs: &all_blob_cids,
441 obsolete_cids,
442 },
443 )
444 .await
445 {
446 Ok(res) => res,
447 Err(e) if e.contains("ConcurrentModification") => {
448 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response();
449 }
450 Err(e) => {
451 error!("Commit failed: {}", e);
452 return ApiError::InternalError(Some("Failed to commit changes".into()))
453 .into_response();
454 }
455 };
456
457 if let Some(ref controller) = controller_did {
458 let write_summary: Vec<serde_json::Value> = input
459 .writes
460 .iter()
461 .map(|w| match w {
462 WriteOp::Create {
463 collection, rkey, ..
464 } => json!({
465 "action": "create",
466 "collection": collection,
467 "rkey": rkey
468 }),
469 WriteOp::Update {
470 collection, rkey, ..
471 } => json!({
472 "action": "update",
473 "collection": collection,
474 "rkey": rkey
475 }),
476 WriteOp::Delete { collection, rkey } => json!({
477 "action": "delete",
478 "collection": collection,
479 "rkey": rkey
480 }),
481 })
482 .collect();
483
484 let _ = delegation::log_delegation_action(
485 &state.db,
486 &did,
487 controller,
488 Some(controller),
489 DelegationActionType::RepoWrite,
490 Some(json!({
491 "action": "apply_writes",
492 "count": input.writes.len(),
493 "writes": write_summary
494 })),
495 None,
496 None,
497 )
498 .await;
499 }
500
501 (
502 StatusCode::OK,
503 Json(ApplyWritesOutput {
504 commit: CommitInfo {
505 cid: commit_res.commit_cid.to_string(),
506 rev: commit_res.rev,
507 },
508 results,
509 }),
510 )
511 .into_response()
512}