this repo has no description
1use super::validation::validate_record_with_status;
2use super::write::has_verified_comms_channel;
3use crate::api::error::ApiError;
4use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids};
5use crate::auth::BearerAuth;
6use crate::delegation::{self, DelegationActionType};
7use crate::repo::tracking::TrackingBlockStore;
8use crate::state::AppState;
9use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey};
10use axum::{
11 Json,
12 extract::State,
13 http::StatusCode,
14 response::{IntoResponse, Response},
15};
16use cid::Cid;
17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::str::FromStr;
21use std::sync::Arc;
22use tracing::{error, info};
23
/// Upper bound on the number of operations accepted in one `applyWrites` request;
/// larger batches are rejected with `InvalidRequest` before any repo work is done.
const MAX_BATCH_WRITES: usize = 200;
25
26struct WriteAccumulator {
27 mst: Mst<TrackingBlockStore>,
28 results: Vec<WriteResult>,
29 ops: Vec<RecordOp>,
30 modified_keys: Vec<String>,
31 all_blob_cids: Vec<String>,
32}
33
34async fn process_single_write(
35 write: &WriteOp,
36 acc: WriteAccumulator,
37 did: &Did,
38 validate: Option<bool>,
39 tracking_store: &TrackingBlockStore,
40) -> Result<WriteAccumulator, Response> {
41 let WriteAccumulator {
42 mst,
43 mut results,
44 mut ops,
45 mut modified_keys,
46 mut all_blob_cids,
47 } = acc;
48
49 match write {
50 WriteOp::Create {
51 collection,
52 rkey,
53 value,
54 } => {
55 let validation_status = match validate {
56 Some(false) => None,
57 _ => {
58 let require_lexicon = validate == Some(true);
59 match validate_record_with_status(
60 value,
61 collection,
62 rkey.as_ref(),
63 require_lexicon,
64 ) {
65 Ok(status) => Some(status),
66 Err(err_response) => return Err(*err_response),
67 }
68 }
69 };
70 all_blob_cids.extend(extract_blob_cids(value));
71 let rkey = rkey.clone().unwrap_or_else(Rkey::generate);
72 let record_ipld = crate::util::json_to_ipld(value);
73 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| {
74 ApiError::InvalidRecord("Failed to serialize record".into()).into_response()
75 })?;
76 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| {
77 ApiError::InternalError(Some("Failed to store record".into())).into_response()
78 })?;
79 let key = format!("{}/{}", collection, rkey);
80 modified_keys.push(key.clone());
81 let new_mst = mst.add(&key, record_cid).await.map_err(|_| {
82 ApiError::InternalError(Some("Failed to add to MST".into())).into_response()
83 })?;
84 let uri = AtUri::from_parts(did, collection, &rkey);
85 results.push(WriteResult::CreateResult {
86 uri,
87 cid: record_cid.to_string(),
88 validation_status: validation_status.map(|s| s.to_string()),
89 });
90 ops.push(RecordOp::Create {
91 collection: collection.clone(),
92 rkey: rkey.clone(),
93 cid: record_cid,
94 });
95 Ok(WriteAccumulator {
96 mst: new_mst,
97 results,
98 ops,
99 modified_keys,
100 all_blob_cids,
101 })
102 }
103 WriteOp::Update {
104 collection,
105 rkey,
106 value,
107 } => {
108 let validation_status = match validate {
109 Some(false) => None,
110 _ => {
111 let require_lexicon = validate == Some(true);
112 match validate_record_with_status(
113 value,
114 collection,
115 Some(rkey),
116 require_lexicon,
117 ) {
118 Ok(status) => Some(status),
119 Err(err_response) => return Err(*err_response),
120 }
121 }
122 };
123 all_blob_cids.extend(extract_blob_cids(value));
124 let record_ipld = crate::util::json_to_ipld(value);
125 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| {
126 ApiError::InvalidRecord("Failed to serialize record".into()).into_response()
127 })?;
128 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| {
129 ApiError::InternalError(Some("Failed to store record".into())).into_response()
130 })?;
131 let key = format!("{}/{}", collection, rkey);
132 modified_keys.push(key.clone());
133 let prev_record_cid = mst.get(&key).await.ok().flatten();
134 let new_mst = mst.update(&key, record_cid).await.map_err(|_| {
135 ApiError::InternalError(Some("Failed to update MST".into())).into_response()
136 })?;
137 let uri = AtUri::from_parts(did, collection, rkey);
138 results.push(WriteResult::UpdateResult {
139 uri,
140 cid: record_cid.to_string(),
141 validation_status: validation_status.map(|s| s.to_string()),
142 });
143 ops.push(RecordOp::Update {
144 collection: collection.clone(),
145 rkey: rkey.clone(),
146 cid: record_cid,
147 prev: prev_record_cid,
148 });
149 Ok(WriteAccumulator {
150 mst: new_mst,
151 results,
152 ops,
153 modified_keys,
154 all_blob_cids,
155 })
156 }
157 WriteOp::Delete { collection, rkey } => {
158 let key = format!("{}/{}", collection, rkey);
159 modified_keys.push(key.clone());
160 let prev_record_cid = mst.get(&key).await.ok().flatten();
161 let new_mst = mst.delete(&key).await.map_err(|_| {
162 ApiError::InternalError(Some("Failed to delete from MST".into())).into_response()
163 })?;
164 results.push(WriteResult::DeleteResult {});
165 ops.push(RecordOp::Delete {
166 collection: collection.clone(),
167 rkey: rkey.clone(),
168 prev: prev_record_cid,
169 });
170 Ok(WriteAccumulator {
171 mst: new_mst,
172 results,
173 ops,
174 modified_keys,
175 all_blob_cids,
176 })
177 }
178 }
179}
180
181async fn process_writes(
182 writes: &[WriteOp],
183 initial_mst: Mst<TrackingBlockStore>,
184 did: &Did,
185 validate: Option<bool>,
186 tracking_store: &TrackingBlockStore,
187) -> Result<WriteAccumulator, Response> {
188 use futures::stream::{self, TryStreamExt};
189 let initial_acc = WriteAccumulator {
190 mst: initial_mst,
191 results: Vec::new(),
192 ops: Vec::new(),
193 modified_keys: Vec::new(),
194 all_blob_cids: Vec::new(),
195 };
196 stream::iter(writes.iter().map(Ok::<_, Response>))
197 .try_fold(initial_acc, |acc, write| async move {
198 process_single_write(write, acc, did, validate, tracking_store).await
199 })
200 .await
201}
202
203#[derive(Deserialize)]
204#[serde(tag = "$type")]
205pub enum WriteOp {
206 #[serde(rename = "com.atproto.repo.applyWrites#create")]
207 Create {
208 collection: Nsid,
209 rkey: Option<Rkey>,
210 value: serde_json::Value,
211 },
212 #[serde(rename = "com.atproto.repo.applyWrites#update")]
213 Update {
214 collection: Nsid,
215 rkey: Rkey,
216 value: serde_json::Value,
217 },
218 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
219 Delete { collection: Nsid, rkey: Rkey },
220}
221
222#[derive(Deserialize)]
223#[serde(rename_all = "camelCase")]
224pub struct ApplyWritesInput {
225 pub repo: AtIdentifier,
226 pub validate: Option<bool>,
227 pub writes: Vec<WriteOp>,
228 pub swap_commit: Option<String>,
229}
230
231#[derive(Serialize)]
232#[serde(tag = "$type")]
233pub enum WriteResult {
234 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
235 CreateResult {
236 uri: AtUri,
237 cid: String,
238 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")]
239 validation_status: Option<String>,
240 },
241 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
242 UpdateResult {
243 uri: AtUri,
244 cid: String,
245 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")]
246 validation_status: Option<String>,
247 },
248 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
249 DeleteResult {},
250}
251
252#[derive(Serialize)]
253pub struct ApplyWritesOutput {
254 pub commit: CommitInfo,
255 pub results: Vec<WriteResult>,
256}
257
258#[derive(Serialize)]
259pub struct CommitInfo {
260 pub cid: String,
261 pub rev: String,
262}
263
264pub async fn apply_writes(
265 State(state): State<AppState>,
266 auth: BearerAuth,
267 Json(input): Json<ApplyWritesInput>,
268) -> Response {
269 info!(
270 "apply_writes called: repo={}, writes={}",
271 input.repo,
272 input.writes.len()
273 );
274 let auth_user = auth.0;
275 let did = auth_user.did.clone();
276 let is_oauth = auth_user.is_oauth;
277 let scope = auth_user.scope;
278 let controller_did = auth_user.controller_did.clone();
279 if input.repo.as_str() != did {
280 return ApiError::InvalidRepo("Repo does not match authenticated user".into())
281 .into_response();
282 }
283 if crate::util::is_account_migrated(&state.db, &did)
284 .await
285 .unwrap_or(false)
286 {
287 return ApiError::AccountMigrated.into_response();
288 }
289 let is_verified = has_verified_comms_channel(&state.db, &did)
290 .await
291 .unwrap_or(false);
292 let is_delegated = crate::delegation::is_delegated_account(&state.db, &did)
293 .await
294 .unwrap_or(false);
295 if !is_verified && !is_delegated {
296 return ApiError::AccountNotVerified.into_response();
297 }
298 if input.writes.is_empty() {
299 return ApiError::InvalidRequest("writes array is empty".into()).into_response();
300 }
301 if input.writes.len() > MAX_BATCH_WRITES {
302 return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES))
303 .into_response();
304 }
305
306 let has_custom_scope = scope
307 .as_ref()
308 .map(|s| s != "com.atproto.access")
309 .unwrap_or(false);
310 if is_oauth || has_custom_scope {
311 use std::collections::HashSet;
312 let create_collections: HashSet<&Nsid> = input
313 .writes
314 .iter()
315 .filter_map(|w| {
316 if let WriteOp::Create { collection, .. } = w {
317 Some(collection)
318 } else {
319 None
320 }
321 })
322 .collect();
323 let update_collections: HashSet<&Nsid> = input
324 .writes
325 .iter()
326 .filter_map(|w| {
327 if let WriteOp::Update { collection, .. } = w {
328 Some(collection)
329 } else {
330 None
331 }
332 })
333 .collect();
334 let delete_collections: HashSet<&Nsid> = input
335 .writes
336 .iter()
337 .filter_map(|w| {
338 if let WriteOp::Delete { collection, .. } = w {
339 Some(collection)
340 } else {
341 None
342 }
343 })
344 .collect();
345
346 let scope_checks = create_collections
347 .iter()
348 .map(|c| (crate::oauth::RepoAction::Create, c))
349 .chain(
350 update_collections
351 .iter()
352 .map(|c| (crate::oauth::RepoAction::Update, c)),
353 )
354 .chain(
355 delete_collections
356 .iter()
357 .map(|c| (crate::oauth::RepoAction::Delete, c)),
358 );
359
360 if let Some(err) = scope_checks
361 .filter_map(|(action, collection)| {
362 crate::auth::scope_check::check_repo_scope(
363 is_oauth,
364 scope.as_deref(),
365 action,
366 collection,
367 )
368 .err()
369 })
370 .next()
371 {
372 return err;
373 }
374 }
375
376 let user_id: uuid::Uuid =
377 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str())
378 .fetch_optional(&state.db)
379 .await
380 {
381 Ok(Some(id)) => id,
382 _ => return ApiError::InternalError(Some("User not found".into())).into_response(),
383 };
384 let root_cid_str: String = match sqlx::query_scalar!(
385 "SELECT repo_root_cid FROM repos WHERE user_id = $1",
386 user_id
387 )
388 .fetch_optional(&state.db)
389 .await
390 {
391 Ok(Some(cid_str)) => cid_str,
392 _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(),
393 };
394 let current_root_cid = match Cid::from_str(&root_cid_str) {
395 Ok(c) => c,
396 Err(_) => {
397 return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response();
398 }
399 };
400 if let Some(swap_commit) = &input.swap_commit
401 && Cid::from_str(swap_commit).ok() != Some(current_root_cid)
402 {
403 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response();
404 }
405 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
406 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
407 Ok(Some(b)) => b,
408 _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(),
409 };
410 let commit = match Commit::from_cbor(&commit_bytes) {
411 Ok(c) => c,
412 _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(),
413 };
414 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
415 let initial_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
416 let WriteAccumulator {
417 mst,
418 results,
419 ops,
420 modified_keys,
421 all_blob_cids,
422 } = match process_writes(
423 &input.writes,
424 initial_mst,
425 &did,
426 input.validate,
427 &tracking_store,
428 )
429 .await
430 {
431 Ok(acc) => acc,
432 Err(response) => return response,
433 };
434 let new_mst_root = match mst.persist().await {
435 Ok(c) => c,
436 Err(_) => {
437 return ApiError::InternalError(Some("Failed to persist MST".into())).into_response();
438 }
439 };
440 let (new_mst_blocks, old_mst_blocks) = {
441 let mut new_blocks = std::collections::BTreeMap::new();
442 let mut old_blocks = std::collections::BTreeMap::new();
443 for key in &modified_keys {
444 if mst.blocks_for_path(key, &mut new_blocks).await.is_err() {
445 return ApiError::InternalError(Some(
446 "Failed to get new MST blocks for path".into(),
447 ))
448 .into_response();
449 }
450 if original_mst
451 .blocks_for_path(key, &mut old_blocks)
452 .await
453 .is_err()
454 {
455 return ApiError::InternalError(Some(
456 "Failed to get old MST blocks for path".into(),
457 ))
458 .into_response();
459 }
460 }
461 (new_blocks, old_blocks)
462 };
463 let mut relevant_blocks = new_mst_blocks.clone();
464 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone())));
465 let written_cids: Vec<Cid> = tracking_store
466 .get_all_relevant_cids()
467 .into_iter()
468 .chain(relevant_blocks.keys().copied())
469 .collect::<std::collections::HashSet<_>>()
470 .into_iter()
471 .collect();
472 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
473 let prev_record_cids = ops.iter().filter_map(|op| match op {
474 RecordOp::Update {
475 prev: Some(cid), ..
476 }
477 | RecordOp::Delete {
478 prev: Some(cid), ..
479 } => Some(*cid),
480 _ => None,
481 });
482 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid)
483 .chain(
484 old_mst_blocks
485 .keys()
486 .filter(|cid| !new_mst_blocks.contains_key(*cid))
487 .copied(),
488 )
489 .chain(prev_record_cids)
490 .collect::<std::collections::HashSet<_>>()
491 .into_iter()
492 .collect();
493 let commit_res = match commit_and_log(
494 &state,
495 CommitParams {
496 did: &did,
497 user_id,
498 current_root_cid: Some(current_root_cid),
499 prev_data_cid: Some(commit.data),
500 new_mst_root,
501 ops,
502 blocks_cids: &written_cids_str,
503 blobs: &all_blob_cids,
504 obsolete_cids,
505 },
506 )
507 .await
508 {
509 Ok(res) => res,
510 Err(e) if e.contains("ConcurrentModification") => {
511 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response();
512 }
513 Err(e) => {
514 error!("Commit failed: {}", e);
515 return ApiError::InternalError(Some("Failed to commit changes".into()))
516 .into_response();
517 }
518 };
519
520 if let Some(ref controller) = controller_did {
521 let write_summary: Vec<serde_json::Value> = input
522 .writes
523 .iter()
524 .map(|w| match w {
525 WriteOp::Create {
526 collection, rkey, ..
527 } => json!({
528 "action": "create",
529 "collection": collection,
530 "rkey": rkey
531 }),
532 WriteOp::Update {
533 collection, rkey, ..
534 } => json!({
535 "action": "update",
536 "collection": collection,
537 "rkey": rkey
538 }),
539 WriteOp::Delete { collection, rkey } => json!({
540 "action": "delete",
541 "collection": collection,
542 "rkey": rkey
543 }),
544 })
545 .collect();
546
547 let _ = delegation::log_delegation_action(
548 &state.db,
549 &did,
550 controller,
551 Some(controller),
552 DelegationActionType::RepoWrite,
553 Some(json!({
554 "action": "apply_writes",
555 "count": input.writes.len(),
556 "writes": write_summary
557 })),
558 None,
559 None,
560 )
561 .await;
562 }
563
564 (
565 StatusCode::OK,
566 Json(ApplyWritesOutput {
567 commit: CommitInfo {
568 cid: commit_res.commit_cid.to_string(),
569 rev: commit_res.rev,
570 },
571 results,
572 }),
573 )
574 .into_response()
575}