this repo has no description
1use super::validation::validate_record;
2use super::write::has_verified_comms_channel;
3use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
4use crate::repo::tracking::TrackingBlockStore;
5use crate::state::AppState;
6use axum::{
7 Json,
8 extract::State,
9 http::StatusCode,
10 response::{IntoResponse, Response},
11};
12use cid::Cid;
13use jacquard::types::{
14 integer::LimitedU32,
15 string::{Nsid, Tid},
16};
17use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::str::FromStr;
21use std::sync::Arc;
22use tracing::{error, info};
23
/// Upper bound on the number of write operations accepted in a single
/// `applyWrites` request; larger batches are rejected with `InvalidRequest`.
const MAX_BATCH_WRITES: usize = 200;
25
26#[derive(Deserialize)]
27#[serde(tag = "$type")]
28pub enum WriteOp {
29 #[serde(rename = "com.atproto.repo.applyWrites#create")]
30 Create {
31 collection: String,
32 rkey: Option<String>,
33 value: serde_json::Value,
34 },
35 #[serde(rename = "com.atproto.repo.applyWrites#update")]
36 Update {
37 collection: String,
38 rkey: String,
39 value: serde_json::Value,
40 },
41 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
42 Delete { collection: String, rkey: String },
43}
44
45#[derive(Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct ApplyWritesInput {
48 pub repo: String,
49 pub validate: Option<bool>,
50 pub writes: Vec<WriteOp>,
51 pub swap_commit: Option<String>,
52}
53
54#[derive(Serialize)]
55#[serde(tag = "$type")]
56pub enum WriteResult {
57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
58 CreateResult { uri: String, cid: String },
59 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
60 UpdateResult { uri: String, cid: String },
61 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
62 DeleteResult {},
63}
64
65#[derive(Serialize)]
66pub struct ApplyWritesOutput {
67 pub commit: CommitInfo,
68 pub results: Vec<WriteResult>,
69}
70
71#[derive(Serialize)]
72pub struct CommitInfo {
73 pub cid: String,
74 pub rev: String,
75}
76
77pub async fn apply_writes(
78 State(state): State<AppState>,
79 headers: axum::http::HeaderMap,
80 Json(input): Json<ApplyWritesInput>,
81) -> Response {
82 info!(
83 "apply_writes called: repo={}, writes={}",
84 input.repo,
85 input.writes.len()
86 );
87 let token = match crate::auth::extract_bearer_token_from_header(
88 headers.get("Authorization").and_then(|h| h.to_str().ok()),
89 ) {
90 Some(t) => t,
91 None => {
92 return (
93 StatusCode::UNAUTHORIZED,
94 Json(json!({"error": "AuthenticationRequired"})),
95 )
96 .into_response();
97 }
98 };
99 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
100 Ok(user) => user,
101 Err(_) => {
102 return (
103 StatusCode::UNAUTHORIZED,
104 Json(json!({"error": "AuthenticationFailed"})),
105 )
106 .into_response();
107 }
108 };
109 let did = auth_user.did.clone();
110 let is_oauth = auth_user.is_oauth;
111 let scope = auth_user.scope;
112 if input.repo != did {
113 return (
114 StatusCode::FORBIDDEN,
115 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
116 )
117 .into_response();
118 }
119 match has_verified_comms_channel(&state.db, &did).await {
120 Ok(true) => {}
121 Ok(false) => {
122 return (
123 StatusCode::FORBIDDEN,
124 Json(json!({
125 "error": "AccountNotVerified",
126 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records"
127 })),
128 )
129 .into_response();
130 }
131 Err(e) => {
132 error!("DB error checking notification channels: {}", e);
133 return (
134 StatusCode::INTERNAL_SERVER_ERROR,
135 Json(json!({"error": "InternalError"})),
136 )
137 .into_response();
138 }
139 }
140 if input.writes.is_empty() {
141 return (
142 StatusCode::BAD_REQUEST,
143 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
144 )
145 .into_response();
146 }
147 if input.writes.len() > MAX_BATCH_WRITES {
148 return (
149 StatusCode::BAD_REQUEST,
150 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
151 )
152 .into_response();
153 }
154
155 let has_custom_scope = scope
156 .as_ref()
157 .map(|s| s != "com.atproto.access")
158 .unwrap_or(false);
159 if is_oauth || has_custom_scope {
160 use std::collections::HashSet;
161 let create_collections: HashSet<&str> = input
162 .writes
163 .iter()
164 .filter_map(|w| {
165 if let WriteOp::Create { collection, .. } = w {
166 Some(collection.as_str())
167 } else {
168 None
169 }
170 })
171 .collect();
172 let update_collections: HashSet<&str> = input
173 .writes
174 .iter()
175 .filter_map(|w| {
176 if let WriteOp::Update { collection, .. } = w {
177 Some(collection.as_str())
178 } else {
179 None
180 }
181 })
182 .collect();
183 let delete_collections: HashSet<&str> = input
184 .writes
185 .iter()
186 .filter_map(|w| {
187 if let WriteOp::Delete { collection, .. } = w {
188 Some(collection.as_str())
189 } else {
190 None
191 }
192 })
193 .collect();
194
195 for collection in create_collections {
196 if let Err(e) = crate::auth::scope_check::check_repo_scope(
197 is_oauth,
198 scope.as_deref(),
199 crate::oauth::RepoAction::Create,
200 collection,
201 ) {
202 return e;
203 }
204 }
205 for collection in update_collections {
206 if let Err(e) = crate::auth::scope_check::check_repo_scope(
207 is_oauth,
208 scope.as_deref(),
209 crate::oauth::RepoAction::Update,
210 collection,
211 ) {
212 return e;
213 }
214 }
215 for collection in delete_collections {
216 if let Err(e) = crate::auth::scope_check::check_repo_scope(
217 is_oauth,
218 scope.as_deref(),
219 crate::oauth::RepoAction::Delete,
220 collection,
221 ) {
222 return e;
223 }
224 }
225 }
226
227 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
228 .fetch_optional(&state.db)
229 .await
230 {
231 Ok(Some(id)) => id,
232 _ => {
233 return (
234 StatusCode::INTERNAL_SERVER_ERROR,
235 Json(json!({"error": "InternalError", "message": "User not found"})),
236 )
237 .into_response();
238 }
239 };
240 let root_cid_str: String = match sqlx::query_scalar!(
241 "SELECT repo_root_cid FROM repos WHERE user_id = $1",
242 user_id
243 )
244 .fetch_optional(&state.db)
245 .await
246 {
247 Ok(Some(cid_str)) => cid_str,
248 _ => {
249 return (
250 StatusCode::INTERNAL_SERVER_ERROR,
251 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
252 )
253 .into_response();
254 }
255 };
256 let current_root_cid = match Cid::from_str(&root_cid_str) {
257 Ok(c) => c,
258 Err(_) => {
259 return (
260 StatusCode::INTERNAL_SERVER_ERROR,
261 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
262 )
263 .into_response();
264 }
265 };
266 if let Some(swap_commit) = &input.swap_commit
267 && Cid::from_str(swap_commit).ok() != Some(current_root_cid)
268 {
269 return (
270 StatusCode::CONFLICT,
271 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
272 )
273 .into_response();
274 }
275 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
276 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
277 Ok(Some(b)) => b,
278 _ => {
279 return (
280 StatusCode::INTERNAL_SERVER_ERROR,
281 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
282 )
283 .into_response();
284 }
285 };
286 let commit = match Commit::from_cbor(&commit_bytes) {
287 Ok(c) => c,
288 _ => {
289 return (
290 StatusCode::INTERNAL_SERVER_ERROR,
291 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
292 )
293 .into_response();
294 }
295 };
296 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
297 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
298 let mut results: Vec<WriteResult> = Vec::new();
299 let mut ops: Vec<RecordOp> = Vec::new();
300 let mut modified_keys: Vec<String> = Vec::new();
301 for write in &input.writes {
302 match write {
303 WriteOp::Create {
304 collection,
305 rkey,
306 value,
307 } => {
308 if input.validate.unwrap_or(true)
309 && let Err(err_response) = validate_record(value, collection)
310 {
311 return *err_response;
312 }
313 let rkey = rkey
314 .clone()
315 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
316 let mut record_bytes = Vec::new();
317 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
318 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
319 }
320 let record_cid = match tracking_store.put(&record_bytes).await {
321 Ok(c) => c,
322 Err(_) => return (
323 StatusCode::INTERNAL_SERVER_ERROR,
324 Json(
325 json!({"error": "InternalError", "message": "Failed to store record"}),
326 ),
327 )
328 .into_response(),
329 };
330 let collection_nsid = match collection.parse::<Nsid>() {
331 Ok(n) => n,
332 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
333 };
334 let key = format!("{}/{}", collection_nsid, rkey);
335 modified_keys.push(key.clone());
336 mst = match mst.add(&key, record_cid).await {
337 Ok(m) => m,
338 Err(_) => return (
339 StatusCode::INTERNAL_SERVER_ERROR,
340 Json(json!({"error": "InternalError", "message": "Failed to add to MST"})),
341 )
342 .into_response(),
343 };
344 let uri = format!("at://{}/{}/{}", did, collection, rkey);
345 results.push(WriteResult::CreateResult {
346 uri,
347 cid: record_cid.to_string(),
348 });
349 ops.push(RecordOp::Create {
350 collection: collection.clone(),
351 rkey,
352 cid: record_cid,
353 });
354 }
355 WriteOp::Update {
356 collection,
357 rkey,
358 value,
359 } => {
360 if input.validate.unwrap_or(true)
361 && let Err(err_response) = validate_record(value, collection)
362 {
363 return *err_response;
364 }
365 let mut record_bytes = Vec::new();
366 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
367 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
368 }
369 let record_cid = match tracking_store.put(&record_bytes).await {
370 Ok(c) => c,
371 Err(_) => return (
372 StatusCode::INTERNAL_SERVER_ERROR,
373 Json(
374 json!({"error": "InternalError", "message": "Failed to store record"}),
375 ),
376 )
377 .into_response(),
378 };
379 let collection_nsid = match collection.parse::<Nsid>() {
380 Ok(n) => n,
381 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
382 };
383 let key = format!("{}/{}", collection_nsid, rkey);
384 modified_keys.push(key.clone());
385 let prev_record_cid = mst.get(&key).await.ok().flatten();
386 mst = match mst.update(&key, record_cid).await {
387 Ok(m) => m,
388 Err(_) => return (
389 StatusCode::INTERNAL_SERVER_ERROR,
390 Json(json!({"error": "InternalError", "message": "Failed to update MST"})),
391 )
392 .into_response(),
393 };
394 let uri = format!("at://{}/{}/{}", did, collection, rkey);
395 results.push(WriteResult::UpdateResult {
396 uri,
397 cid: record_cid.to_string(),
398 });
399 ops.push(RecordOp::Update {
400 collection: collection.clone(),
401 rkey: rkey.clone(),
402 cid: record_cid,
403 prev: prev_record_cid,
404 });
405 }
406 WriteOp::Delete { collection, rkey } => {
407 let collection_nsid = match collection.parse::<Nsid>() {
408 Ok(n) => n,
409 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
410 };
411 let key = format!("{}/{}", collection_nsid, rkey);
412 modified_keys.push(key.clone());
413 let prev_record_cid = mst.get(&key).await.ok().flatten();
414 mst = match mst.delete(&key).await {
415 Ok(m) => m,
416 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
417 };
418 results.push(WriteResult::DeleteResult {});
419 ops.push(RecordOp::Delete {
420 collection: collection.clone(),
421 rkey: rkey.clone(),
422 prev: prev_record_cid,
423 });
424 }
425 }
426 }
427 let new_mst_root = match mst.persist().await {
428 Ok(c) => c,
429 Err(_) => {
430 return (
431 StatusCode::INTERNAL_SERVER_ERROR,
432 Json(json!({"error": "InternalError", "message": "Failed to persist MST"})),
433 )
434 .into_response();
435 }
436 };
437 let mut relevant_blocks = std::collections::BTreeMap::new();
438 for key in &modified_keys {
439 if mst
440 .blocks_for_path(key, &mut relevant_blocks)
441 .await
442 .is_err()
443 {
444 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
445 }
446 if original_mst
447 .blocks_for_path(key, &mut relevant_blocks)
448 .await
449 .is_err()
450 {
451 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
452 }
453 }
454 let mut written_cids = tracking_store.get_all_relevant_cids();
455 for cid in relevant_blocks.keys() {
456 if !written_cids.contains(cid) {
457 written_cids.push(*cid);
458 }
459 }
460 let written_cids_str = written_cids
461 .iter()
462 .map(|c| c.to_string())
463 .collect::<Vec<_>>();
464 let commit_res = match commit_and_log(
465 &state,
466 CommitParams {
467 did: &did,
468 user_id,
469 current_root_cid: Some(current_root_cid),
470 prev_data_cid: Some(commit.data),
471 new_mst_root,
472 ops,
473 blocks_cids: &written_cids_str,
474 },
475 )
476 .await
477 {
478 Ok(res) => res,
479 Err(e) => {
480 error!("Commit failed: {}", e);
481 return (
482 StatusCode::INTERNAL_SERVER_ERROR,
483 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
484 )
485 .into_response();
486 }
487 };
488 (
489 StatusCode::OK,
490 Json(ApplyWritesOutput {
491 commit: CommitInfo {
492 cid: commit_res.commit_cid.to_string(),
493 rev: commit_res.rev,
494 },
495 results,
496 }),
497 )
498 .into_response()
499}