this repo has no description
1use super::validation::validate_record_with_status;
2use super::write::has_verified_comms_channel;
3use crate::api::error::ApiError;
4use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids};
5use crate::auth::BearerAuth;
6use crate::delegation::{self, DelegationActionType};
7use crate::repo::tracking::TrackingBlockStore;
8use crate::state::AppState;
9use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey};
10use axum::{
11 Json,
12 extract::State,
13 http::StatusCode,
14 response::{IntoResponse, Response},
15};
16use cid::Cid;
17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::str::FromStr;
21use std::sync::Arc;
22use tracing::{error, info};
23
/// Largest number of operations accepted in one `applyWrites` request; larger batches are
/// rejected with `InvalidRequest`.
const MAX_BATCH_WRITES: usize = 200;
25
26struct WriteAccumulator {
27 mst: Mst<TrackingBlockStore>,
28 results: Vec<WriteResult>,
29 ops: Vec<RecordOp>,
30 modified_keys: Vec<String>,
31 all_blob_cids: Vec<String>,
32}
33
34async fn process_single_write(
35 write: &WriteOp,
36 acc: WriteAccumulator,
37 did: &Did,
38 validate: Option<bool>,
39 tracking_store: &TrackingBlockStore,
40) -> Result<WriteAccumulator, Response> {
41 let WriteAccumulator {
42 mst,
43 mut results,
44 mut ops,
45 mut modified_keys,
46 mut all_blob_cids,
47 } = acc;
48
49 match write {
50 WriteOp::Create {
51 collection,
52 rkey,
53 value,
54 } => {
55 let validation_status = match validate {
56 Some(false) => None,
57 _ => {
58 let require_lexicon = validate == Some(true);
59 match validate_record_with_status(
60 value,
61 collection,
62 rkey.as_ref(),
63 require_lexicon,
64 ) {
65 Ok(status) => Some(status),
66 Err(err_response) => return Err(*err_response),
67 }
68 }
69 };
70 all_blob_cids.extend(extract_blob_cids(value));
71 let rkey = rkey.clone().unwrap_or_else(Rkey::generate);
72 let record_ipld = crate::util::json_to_ipld(value);
73 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| {
74 ApiError::InvalidRecord("Failed to serialize record".into()).into_response()
75 })?;
76 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| {
77 ApiError::InternalError(Some("Failed to store record".into())).into_response()
78 })?;
79 let key = format!("{}/{}", collection, rkey);
80 modified_keys.push(key.clone());
81 let new_mst = mst.add(&key, record_cid).await.map_err(|_| {
82 ApiError::InternalError(Some("Failed to add to MST".into())).into_response()
83 })?;
84 let uri = AtUri::from_parts(did, collection, &rkey);
85 results.push(WriteResult::CreateResult {
86 uri,
87 cid: record_cid.to_string(),
88 validation_status: validation_status.map(|s| s.to_string()),
89 });
90 ops.push(RecordOp::Create {
91 collection: collection.clone(),
92 rkey: rkey.clone(),
93 cid: record_cid,
94 });
95 Ok(WriteAccumulator {
96 mst: new_mst,
97 results,
98 ops,
99 modified_keys,
100 all_blob_cids,
101 })
102 }
103 WriteOp::Update {
104 collection,
105 rkey,
106 value,
107 } => {
108 let validation_status = match validate {
109 Some(false) => None,
110 _ => {
111 let require_lexicon = validate == Some(true);
112 match validate_record_with_status(
113 value,
114 collection,
115 Some(rkey),
116 require_lexicon,
117 ) {
118 Ok(status) => Some(status),
119 Err(err_response) => return Err(*err_response),
120 }
121 }
122 };
123 all_blob_cids.extend(extract_blob_cids(value));
124 let record_ipld = crate::util::json_to_ipld(value);
125 let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld).map_err(|_| {
126 ApiError::InvalidRecord("Failed to serialize record".into()).into_response()
127 })?;
128 let record_cid = tracking_store.put(&record_bytes).await.map_err(|_| {
129 ApiError::InternalError(Some("Failed to store record".into())).into_response()
130 })?;
131 let key = format!("{}/{}", collection, rkey);
132 modified_keys.push(key.clone());
133 let prev_record_cid = mst.get(&key).await.ok().flatten();
134 let new_mst = mst.update(&key, record_cid).await.map_err(|_| {
135 ApiError::InternalError(Some("Failed to update MST".into())).into_response()
136 })?;
137 let uri = AtUri::from_parts(did, collection, rkey);
138 results.push(WriteResult::UpdateResult {
139 uri,
140 cid: record_cid.to_string(),
141 validation_status: validation_status.map(|s| s.to_string()),
142 });
143 ops.push(RecordOp::Update {
144 collection: collection.clone(),
145 rkey: rkey.clone(),
146 cid: record_cid,
147 prev: prev_record_cid,
148 });
149 Ok(WriteAccumulator {
150 mst: new_mst,
151 results,
152 ops,
153 modified_keys,
154 all_blob_cids,
155 })
156 }
157 WriteOp::Delete { collection, rkey } => {
158 let key = format!("{}/{}", collection, rkey);
159 modified_keys.push(key.clone());
160 let prev_record_cid = mst.get(&key).await.ok().flatten();
161 let new_mst = mst.delete(&key).await.map_err(|_| {
162 ApiError::InternalError(Some("Failed to delete from MST".into())).into_response()
163 })?;
164 results.push(WriteResult::DeleteResult {});
165 ops.push(RecordOp::Delete {
166 collection: collection.clone(),
167 rkey: rkey.clone(),
168 prev: prev_record_cid,
169 });
170 Ok(WriteAccumulator {
171 mst: new_mst,
172 results,
173 ops,
174 modified_keys,
175 all_blob_cids,
176 })
177 }
178 }
179}
180
181async fn process_writes(
182 writes: &[WriteOp],
183 initial_mst: Mst<TrackingBlockStore>,
184 did: &Did,
185 validate: Option<bool>,
186 tracking_store: &TrackingBlockStore,
187) -> Result<WriteAccumulator, Response> {
188 use futures::stream::{self, TryStreamExt};
189 let initial_acc = WriteAccumulator {
190 mst: initial_mst,
191 results: Vec::new(),
192 ops: Vec::new(),
193 modified_keys: Vec::new(),
194 all_blob_cids: Vec::new(),
195 };
196 stream::iter(writes.iter().map(Ok::<_, Response>))
197 .try_fold(initial_acc, |acc, write| async move {
198 process_single_write(write, acc, did, validate, tracking_store).await
199 })
200 .await
201}
202
203#[derive(Deserialize)]
204#[serde(tag = "$type")]
205pub enum WriteOp {
206 #[serde(rename = "com.atproto.repo.applyWrites#create")]
207 Create {
208 collection: Nsid,
209 rkey: Option<Rkey>,
210 value: serde_json::Value,
211 },
212 #[serde(rename = "com.atproto.repo.applyWrites#update")]
213 Update {
214 collection: Nsid,
215 rkey: Rkey,
216 value: serde_json::Value,
217 },
218 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
219 Delete { collection: Nsid, rkey: Rkey },
220}
221
222#[derive(Deserialize)]
223#[serde(rename_all = "camelCase")]
224pub struct ApplyWritesInput {
225 pub repo: AtIdentifier,
226 pub validate: Option<bool>,
227 pub writes: Vec<WriteOp>,
228 pub swap_commit: Option<String>,
229}
230
231#[derive(Serialize)]
232#[serde(tag = "$type")]
233pub enum WriteResult {
234 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
235 CreateResult {
236 uri: AtUri,
237 cid: String,
238 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")]
239 validation_status: Option<String>,
240 },
241 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
242 UpdateResult {
243 uri: AtUri,
244 cid: String,
245 #[serde(rename = "validationStatus", skip_serializing_if = "Option::is_none")]
246 validation_status: Option<String>,
247 },
248 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
249 DeleteResult {},
250}
251
252#[derive(Serialize)]
253pub struct ApplyWritesOutput {
254 pub commit: CommitInfo,
255 pub results: Vec<WriteResult>,
256}
257
258#[derive(Serialize)]
259pub struct CommitInfo {
260 pub cid: String,
261 pub rev: String,
262}
263
264pub async fn apply_writes(
265 State(state): State<AppState>,
266 auth: BearerAuth,
267 Json(input): Json<ApplyWritesInput>,
268) -> Response {
269 info!(
270 "apply_writes called: repo={}, writes={}",
271 input.repo,
272 input.writes.len()
273 );
274 let auth_user = auth.0;
275 let did = auth_user.did.clone();
276 let is_oauth = auth_user.is_oauth;
277 let scope = auth_user.scope;
278 let controller_did = auth_user.controller_did.clone();
279 if input.repo.as_str() != did {
280 return ApiError::InvalidRepo("Repo does not match authenticated user".into())
281 .into_response();
282 }
283 if crate::util::is_account_migrated(&state.db, &did)
284 .await
285 .unwrap_or(false)
286 {
287 return ApiError::AccountMigrated.into_response();
288 }
289 let is_verified = has_verified_comms_channel(&state.db, &did)
290 .await
291 .unwrap_or(false);
292 let is_delegated = crate::delegation::is_delegated_account(&state.db, &did)
293 .await
294 .unwrap_or(false);
295 if !is_verified && !is_delegated {
296 return ApiError::AccountNotVerified.into_response();
297 }
298 if input.writes.is_empty() {
299 return ApiError::InvalidRequest("writes array is empty".into()).into_response();
300 }
301 if input.writes.len() > MAX_BATCH_WRITES {
302 return ApiError::InvalidRequest(format!("Too many writes (max {})", MAX_BATCH_WRITES))
303 .into_response();
304 }
305
306 let has_custom_scope = scope
307 .as_ref()
308 .map(|s| s != "com.atproto.access")
309 .unwrap_or(false);
310 if is_oauth || has_custom_scope {
311 use std::collections::HashSet;
312 let create_collections: HashSet<&Nsid> = input
313 .writes
314 .iter()
315 .filter_map(|w| {
316 if let WriteOp::Create { collection, .. } = w {
317 Some(collection)
318 } else {
319 None
320 }
321 })
322 .collect();
323 let update_collections: HashSet<&Nsid> = input
324 .writes
325 .iter()
326 .filter_map(|w| {
327 if let WriteOp::Update { collection, .. } = w {
328 Some(collection)
329 } else {
330 None
331 }
332 })
333 .collect();
334 let delete_collections: HashSet<&Nsid> = input
335 .writes
336 .iter()
337 .filter_map(|w| {
338 if let WriteOp::Delete { collection, .. } = w {
339 Some(collection)
340 } else {
341 None
342 }
343 })
344 .collect();
345
346 for collection in create_collections {
347 if let Err(e) = crate::auth::scope_check::check_repo_scope(
348 is_oauth,
349 scope.as_deref(),
350 crate::oauth::RepoAction::Create,
351 collection,
352 ) {
353 return e;
354 }
355 }
356 for collection in update_collections {
357 if let Err(e) = crate::auth::scope_check::check_repo_scope(
358 is_oauth,
359 scope.as_deref(),
360 crate::oauth::RepoAction::Update,
361 collection,
362 ) {
363 return e;
364 }
365 }
366 for collection in delete_collections {
367 if let Err(e) = crate::auth::scope_check::check_repo_scope(
368 is_oauth,
369 scope.as_deref(),
370 crate::oauth::RepoAction::Delete,
371 collection,
372 ) {
373 return e;
374 }
375 }
376 }
377
378 let user_id: uuid::Uuid =
379 match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str())
380 .fetch_optional(&state.db)
381 .await
382 {
383 Ok(Some(id)) => id,
384 _ => return ApiError::InternalError(Some("User not found".into())).into_response(),
385 };
386 let root_cid_str: String = match sqlx::query_scalar!(
387 "SELECT repo_root_cid FROM repos WHERE user_id = $1",
388 user_id
389 )
390 .fetch_optional(&state.db)
391 .await
392 {
393 Ok(Some(cid_str)) => cid_str,
394 _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(),
395 };
396 let current_root_cid = match Cid::from_str(&root_cid_str) {
397 Ok(c) => c,
398 Err(_) => {
399 return ApiError::InternalError(Some("Invalid repo root CID".into())).into_response();
400 }
401 };
402 if let Some(swap_commit) = &input.swap_commit
403 && Cid::from_str(swap_commit).ok() != Some(current_root_cid)
404 {
405 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response();
406 }
407 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
408 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
409 Ok(Some(b)) => b,
410 _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(),
411 };
412 let commit = match Commit::from_cbor(&commit_bytes) {
413 Ok(c) => c,
414 _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(),
415 };
416 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
417 let initial_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
418 let WriteAccumulator {
419 mst,
420 results,
421 ops,
422 modified_keys,
423 all_blob_cids,
424 } = match process_writes(
425 &input.writes,
426 initial_mst,
427 &did,
428 input.validate,
429 &tracking_store,
430 )
431 .await
432 {
433 Ok(acc) => acc,
434 Err(response) => return response,
435 };
436 let new_mst_root = match mst.persist().await {
437 Ok(c) => c,
438 Err(_) => {
439 return ApiError::InternalError(Some("Failed to persist MST".into())).into_response();
440 }
441 };
442 let mut new_mst_blocks = std::collections::BTreeMap::new();
443 let mut old_mst_blocks = std::collections::BTreeMap::new();
444 for key in &modified_keys {
445 if mst.blocks_for_path(key, &mut new_mst_blocks).await.is_err() {
446 return ApiError::InternalError(Some("Failed to get new MST blocks for path".into()))
447 .into_response();
448 }
449 if original_mst
450 .blocks_for_path(key, &mut old_mst_blocks)
451 .await
452 .is_err()
453 {
454 return ApiError::InternalError(Some("Failed to get old MST blocks for path".into()))
455 .into_response();
456 }
457 }
458 let mut relevant_blocks = new_mst_blocks.clone();
459 relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone())));
460 let written_cids: Vec<Cid> = tracking_store
461 .get_all_relevant_cids()
462 .into_iter()
463 .chain(relevant_blocks.keys().copied())
464 .collect::<std::collections::HashSet<_>>()
465 .into_iter()
466 .collect();
467 let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect();
468 let prev_record_cids = ops.iter().filter_map(|op| match op {
469 RecordOp::Update {
470 prev: Some(cid), ..
471 }
472 | RecordOp::Delete {
473 prev: Some(cid), ..
474 } => Some(*cid),
475 _ => None,
476 });
477 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid)
478 .chain(
479 old_mst_blocks
480 .keys()
481 .filter(|cid| !new_mst_blocks.contains_key(*cid))
482 .copied(),
483 )
484 .chain(prev_record_cids)
485 .collect::<std::collections::HashSet<_>>()
486 .into_iter()
487 .collect();
488 let commit_res = match commit_and_log(
489 &state,
490 CommitParams {
491 did: &did,
492 user_id,
493 current_root_cid: Some(current_root_cid),
494 prev_data_cid: Some(commit.data),
495 new_mst_root,
496 ops,
497 blocks_cids: &written_cids_str,
498 blobs: &all_blob_cids,
499 obsolete_cids,
500 },
501 )
502 .await
503 {
504 Ok(res) => res,
505 Err(e) if e.contains("ConcurrentModification") => {
506 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response();
507 }
508 Err(e) => {
509 error!("Commit failed: {}", e);
510 return ApiError::InternalError(Some("Failed to commit changes".into()))
511 .into_response();
512 }
513 };
514
515 if let Some(ref controller) = controller_did {
516 let write_summary: Vec<serde_json::Value> = input
517 .writes
518 .iter()
519 .map(|w| match w {
520 WriteOp::Create {
521 collection, rkey, ..
522 } => json!({
523 "action": "create",
524 "collection": collection,
525 "rkey": rkey
526 }),
527 WriteOp::Update {
528 collection, rkey, ..
529 } => json!({
530 "action": "update",
531 "collection": collection,
532 "rkey": rkey
533 }),
534 WriteOp::Delete { collection, rkey } => json!({
535 "action": "delete",
536 "collection": collection,
537 "rkey": rkey
538 }),
539 })
540 .collect();
541
542 let _ = delegation::log_delegation_action(
543 &state.db,
544 &did,
545 controller,
546 Some(controller),
547 DelegationActionType::RepoWrite,
548 Some(json!({
549 "action": "apply_writes",
550 "count": input.writes.len(),
551 "writes": write_summary
552 })),
553 None,
554 None,
555 )
556 .await;
557 }
558
559 (
560 StatusCode::OK,
561 Json(ApplyWritesOutput {
562 commit: CommitInfo {
563 cid: commit_res.commit_cid.to_string(),
564 rev: commit_res.rev,
565 },
566 results,
567 }),
568 )
569 .into_response()
570}