// This repository has no description.
use super::validation::validate_record;
use super::write::has_verified_comms_channel;
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
use crate::repo::tracking::TrackingBlockStore;
use crate::state::AppState;
use axum::{
    Json,
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
};
use cid::Cid;
use jacquard::types::{
    integer::LimitedU32,
    string::{Nsid, Tid},
};
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use tracing::error;

/// Upper bound on the number of operations accepted in one `applyWrites` request;
/// `apply_writes` rejects anything larger with `400 InvalidRequest`.
const MAX_BATCH_WRITES: usize = 200;

26#[derive(Deserialize)]
27#[serde(tag = "$type")]
28pub enum WriteOp {
29 #[serde(rename = "com.atproto.repo.applyWrites#create")]
30 Create {
31 collection: String,
32 rkey: Option<String>,
33 value: serde_json::Value,
34 },
35 #[serde(rename = "com.atproto.repo.applyWrites#update")]
36 Update {
37 collection: String,
38 rkey: String,
39 value: serde_json::Value,
40 },
41 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
42 Delete { collection: String, rkey: String },
43}
44
45#[derive(Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct ApplyWritesInput {
48 pub repo: String,
49 pub validate: Option<bool>,
50 pub writes: Vec<WriteOp>,
51 pub swap_commit: Option<String>,
52}
53
54#[derive(Serialize)]
55#[serde(tag = "$type")]
56pub enum WriteResult {
57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
58 CreateResult { uri: String, cid: String },
59 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
60 UpdateResult { uri: String, cid: String },
61 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
62 DeleteResult {},
63}
64
65#[derive(Serialize)]
66pub struct ApplyWritesOutput {
67 pub commit: CommitInfo,
68 pub results: Vec<WriteResult>,
69}
70
71#[derive(Serialize)]
72pub struct CommitInfo {
73 pub cid: String,
74 pub rev: String,
75}
76
77pub async fn apply_writes(
78 State(state): State<AppState>,
79 headers: axum::http::HeaderMap,
80 Json(input): Json<ApplyWritesInput>,
81) -> Response {
82 let token = match crate::auth::extract_bearer_token_from_header(
83 headers.get("Authorization").and_then(|h| h.to_str().ok()),
84 ) {
85 Some(t) => t,
86 None => {
87 return (
88 StatusCode::UNAUTHORIZED,
89 Json(json!({"error": "AuthenticationRequired"})),
90 )
91 .into_response();
92 }
93 };
94 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
95 Ok(user) => user,
96 Err(_) => {
97 return (
98 StatusCode::UNAUTHORIZED,
99 Json(json!({"error": "AuthenticationFailed"})),
100 )
101 .into_response();
102 }
103 };
104 let did = auth_user.did.clone();
105 let is_oauth = auth_user.is_oauth;
106 let scope = auth_user.scope;
107 if input.repo != did {
108 return (
109 StatusCode::FORBIDDEN,
110 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
111 )
112 .into_response();
113 }
114 match has_verified_comms_channel(&state.db, &did).await {
115 Ok(true) => {}
116 Ok(false) => {
117 return (
118 StatusCode::FORBIDDEN,
119 Json(json!({
120 "error": "AccountNotVerified",
121 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records"
122 })),
123 )
124 .into_response();
125 }
126 Err(e) => {
127 error!("DB error checking notification channels: {}", e);
128 return (
129 StatusCode::INTERNAL_SERVER_ERROR,
130 Json(json!({"error": "InternalError"})),
131 )
132 .into_response();
133 }
134 }
135 if input.writes.is_empty() {
136 return (
137 StatusCode::BAD_REQUEST,
138 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
139 )
140 .into_response();
141 }
142 if input.writes.len() > MAX_BATCH_WRITES {
143 return (
144 StatusCode::BAD_REQUEST,
145 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
146 )
147 .into_response();
148 }
149
150 if is_oauth {
151 use std::collections::HashSet;
152 let create_collections: HashSet<&str> = input
153 .writes
154 .iter()
155 .filter_map(|w| {
156 if let WriteOp::Create { collection, .. } = w {
157 Some(collection.as_str())
158 } else {
159 None
160 }
161 })
162 .collect();
163 let update_collections: HashSet<&str> = input
164 .writes
165 .iter()
166 .filter_map(|w| {
167 if let WriteOp::Update { collection, .. } = w {
168 Some(collection.as_str())
169 } else {
170 None
171 }
172 })
173 .collect();
174 let delete_collections: HashSet<&str> = input
175 .writes
176 .iter()
177 .filter_map(|w| {
178 if let WriteOp::Delete { collection, .. } = w {
179 Some(collection.as_str())
180 } else {
181 None
182 }
183 })
184 .collect();
185
186 for collection in create_collections {
187 if let Err(e) = crate::auth::scope_check::check_repo_scope(
188 is_oauth,
189 scope.as_deref(),
190 crate::oauth::RepoAction::Create,
191 collection,
192 ) {
193 return e;
194 }
195 }
196 for collection in update_collections {
197 if let Err(e) = crate::auth::scope_check::check_repo_scope(
198 is_oauth,
199 scope.as_deref(),
200 crate::oauth::RepoAction::Update,
201 collection,
202 ) {
203 return e;
204 }
205 }
206 for collection in delete_collections {
207 if let Err(e) = crate::auth::scope_check::check_repo_scope(
208 is_oauth,
209 scope.as_deref(),
210 crate::oauth::RepoAction::Delete,
211 collection,
212 ) {
213 return e;
214 }
215 }
216 }
217
218 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
219 .fetch_optional(&state.db)
220 .await
221 {
222 Ok(Some(id)) => id,
223 _ => {
224 return (
225 StatusCode::INTERNAL_SERVER_ERROR,
226 Json(json!({"error": "InternalError", "message": "User not found"})),
227 )
228 .into_response();
229 }
230 };
231 let root_cid_str: String = match sqlx::query_scalar!(
232 "SELECT repo_root_cid FROM repos WHERE user_id = $1",
233 user_id
234 )
235 .fetch_optional(&state.db)
236 .await
237 {
238 Ok(Some(cid_str)) => cid_str,
239 _ => {
240 return (
241 StatusCode::INTERNAL_SERVER_ERROR,
242 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
243 )
244 .into_response();
245 }
246 };
247 let current_root_cid = match Cid::from_str(&root_cid_str) {
248 Ok(c) => c,
249 Err(_) => {
250 return (
251 StatusCode::INTERNAL_SERVER_ERROR,
252 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
253 )
254 .into_response();
255 }
256 };
257 if let Some(swap_commit) = &input.swap_commit
258 && Cid::from_str(swap_commit).ok() != Some(current_root_cid)
259 {
260 return (
261 StatusCode::CONFLICT,
262 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
263 )
264 .into_response();
265 }
266 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
267 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
268 Ok(Some(b)) => b,
269 _ => {
270 return (
271 StatusCode::INTERNAL_SERVER_ERROR,
272 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
273 )
274 .into_response();
275 }
276 };
277 let commit = match Commit::from_cbor(&commit_bytes) {
278 Ok(c) => c,
279 _ => {
280 return (
281 StatusCode::INTERNAL_SERVER_ERROR,
282 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
283 )
284 .into_response();
285 }
286 };
287 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
288 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
289 let mut results: Vec<WriteResult> = Vec::new();
290 let mut ops: Vec<RecordOp> = Vec::new();
291 let mut modified_keys: Vec<String> = Vec::new();
292 for write in &input.writes {
293 match write {
294 WriteOp::Create {
295 collection,
296 rkey,
297 value,
298 } => {
299 if input.validate.unwrap_or(true)
300 && let Err(err_response) = validate_record(value, collection)
301 {
302 return *err_response;
303 }
304 let rkey = rkey
305 .clone()
306 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
307 let mut record_bytes = Vec::new();
308 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
309 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
310 }
311 let record_cid = match tracking_store.put(&record_bytes).await {
312 Ok(c) => c,
313 Err(_) => return (
314 StatusCode::INTERNAL_SERVER_ERROR,
315 Json(
316 json!({"error": "InternalError", "message": "Failed to store record"}),
317 ),
318 )
319 .into_response(),
320 };
321 let collection_nsid = match collection.parse::<Nsid>() {
322 Ok(n) => n,
323 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
324 };
325 let key = format!("{}/{}", collection_nsid, rkey);
326 modified_keys.push(key.clone());
327 mst = match mst.add(&key, record_cid).await {
328 Ok(m) => m,
329 Err(_) => return (
330 StatusCode::INTERNAL_SERVER_ERROR,
331 Json(json!({"error": "InternalError", "message": "Failed to add to MST"})),
332 )
333 .into_response(),
334 };
335 let uri = format!("at://{}/{}/{}", did, collection, rkey);
336 results.push(WriteResult::CreateResult {
337 uri,
338 cid: record_cid.to_string(),
339 });
340 ops.push(RecordOp::Create {
341 collection: collection.clone(),
342 rkey,
343 cid: record_cid,
344 });
345 }
346 WriteOp::Update {
347 collection,
348 rkey,
349 value,
350 } => {
351 if input.validate.unwrap_or(true)
352 && let Err(err_response) = validate_record(value, collection)
353 {
354 return *err_response;
355 }
356 let mut record_bytes = Vec::new();
357 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
358 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
359 }
360 let record_cid = match tracking_store.put(&record_bytes).await {
361 Ok(c) => c,
362 Err(_) => return (
363 StatusCode::INTERNAL_SERVER_ERROR,
364 Json(
365 json!({"error": "InternalError", "message": "Failed to store record"}),
366 ),
367 )
368 .into_response(),
369 };
370 let collection_nsid = match collection.parse::<Nsid>() {
371 Ok(n) => n,
372 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
373 };
374 let key = format!("{}/{}", collection_nsid, rkey);
375 modified_keys.push(key.clone());
376 let prev_record_cid = mst.get(&key).await.ok().flatten();
377 mst = match mst.update(&key, record_cid).await {
378 Ok(m) => m,
379 Err(_) => return (
380 StatusCode::INTERNAL_SERVER_ERROR,
381 Json(json!({"error": "InternalError", "message": "Failed to update MST"})),
382 )
383 .into_response(),
384 };
385 let uri = format!("at://{}/{}/{}", did, collection, rkey);
386 results.push(WriteResult::UpdateResult {
387 uri,
388 cid: record_cid.to_string(),
389 });
390 ops.push(RecordOp::Update {
391 collection: collection.clone(),
392 rkey: rkey.clone(),
393 cid: record_cid,
394 prev: prev_record_cid,
395 });
396 }
397 WriteOp::Delete { collection, rkey } => {
398 let collection_nsid = match collection.parse::<Nsid>() {
399 Ok(n) => n,
400 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
401 };
402 let key = format!("{}/{}", collection_nsid, rkey);
403 modified_keys.push(key.clone());
404 let prev_record_cid = mst.get(&key).await.ok().flatten();
405 mst = match mst.delete(&key).await {
406 Ok(m) => m,
407 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
408 };
409 results.push(WriteResult::DeleteResult {});
410 ops.push(RecordOp::Delete {
411 collection: collection.clone(),
412 rkey: rkey.clone(),
413 prev: prev_record_cid,
414 });
415 }
416 }
417 }
418 let new_mst_root = match mst.persist().await {
419 Ok(c) => c,
420 Err(_) => {
421 return (
422 StatusCode::INTERNAL_SERVER_ERROR,
423 Json(json!({"error": "InternalError", "message": "Failed to persist MST"})),
424 )
425 .into_response();
426 }
427 };
428 let mut relevant_blocks = std::collections::BTreeMap::new();
429 for key in &modified_keys {
430 if mst
431 .blocks_for_path(key, &mut relevant_blocks)
432 .await
433 .is_err()
434 {
435 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
436 }
437 if original_mst
438 .blocks_for_path(key, &mut relevant_blocks)
439 .await
440 .is_err()
441 {
442 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
443 }
444 }
445 let mut written_cids = tracking_store.get_all_relevant_cids();
446 for cid in relevant_blocks.keys() {
447 if !written_cids.contains(cid) {
448 written_cids.push(*cid);
449 }
450 }
451 let written_cids_str = written_cids
452 .iter()
453 .map(|c| c.to_string())
454 .collect::<Vec<_>>();
455 let commit_res = match commit_and_log(
456 &state,
457 CommitParams {
458 did: &did,
459 user_id,
460 current_root_cid: Some(current_root_cid),
461 prev_data_cid: Some(commit.data),
462 new_mst_root,
463 ops,
464 blocks_cids: &written_cids_str,
465 },
466 )
467 .await
468 {
469 Ok(res) => res,
470 Err(e) => {
471 error!("Commit failed: {}", e);
472 return (
473 StatusCode::INTERNAL_SERVER_ERROR,
474 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
475 )
476 .into_response();
477 }
478 };
479 (
480 StatusCode::OK,
481 Json(ApplyWritesOutput {
482 commit: CommitInfo {
483 cid: commit_res.commit_cid.to_string(),
484 rev: commit_res.rev,
485 },
486 results,
487 }),
488 )
489 .into_response()
490}