this repo has no description
use std::collections::{BTreeMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;

use axum::{
    Json,
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
};
use cid::Cid;
use jacquard::types::{
    integer::LimitedU32,
    string::{Nsid, Tid},
};
use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::error;

use super::validation::validate_record;
use super::write::has_verified_notification_channel;
use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log};
use crate::repo::tracking::TrackingBlockStore;
use crate::state::AppState;

/// Largest number of operations accepted in one `applyWrites` request;
/// larger batches are rejected with 400 `InvalidRequest`.
const MAX_BATCH_WRITES: usize = 200;

26#[derive(Deserialize)]
27#[serde(tag = "$type")]
28pub enum WriteOp {
29 #[serde(rename = "com.atproto.repo.applyWrites#create")]
30 Create {
31 collection: String,
32 rkey: Option<String>,
33 value: serde_json::Value,
34 },
35 #[serde(rename = "com.atproto.repo.applyWrites#update")]
36 Update {
37 collection: String,
38 rkey: String,
39 value: serde_json::Value,
40 },
41 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
42 Delete { collection: String, rkey: String },
43}
44
45#[derive(Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct ApplyWritesInput {
48 pub repo: String,
49 pub validate: Option<bool>,
50 pub writes: Vec<WriteOp>,
51 pub swap_commit: Option<String>,
52}
53
54#[derive(Serialize)]
55#[serde(tag = "$type")]
56pub enum WriteResult {
57 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
58 CreateResult { uri: String, cid: String },
59 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
60 UpdateResult { uri: String, cid: String },
61 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
62 DeleteResult {},
63}
64
65#[derive(Serialize)]
66pub struct ApplyWritesOutput {
67 pub commit: CommitInfo,
68 pub results: Vec<WriteResult>,
69}
70
71#[derive(Serialize)]
72pub struct CommitInfo {
73 pub cid: String,
74 pub rev: String,
75}
76
77pub async fn apply_writes(
78 State(state): State<AppState>,
79 headers: axum::http::HeaderMap,
80 Json(input): Json<ApplyWritesInput>,
81) -> Response {
82 let token = match crate::auth::extract_bearer_token_from_header(
83 headers.get("Authorization").and_then(|h| h.to_str().ok()),
84 ) {
85 Some(t) => t,
86 None => {
87 return (
88 StatusCode::UNAUTHORIZED,
89 Json(json!({"error": "AuthenticationRequired"})),
90 )
91 .into_response();
92 }
93 };
94 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
95 Ok(user) => user,
96 Err(_) => {
97 return (
98 StatusCode::UNAUTHORIZED,
99 Json(json!({"error": "AuthenticationFailed"})),
100 )
101 .into_response();
102 }
103 };
104 let did = auth_user.did;
105 if input.repo != did {
106 return (
107 StatusCode::FORBIDDEN,
108 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
109 )
110 .into_response();
111 }
112 match has_verified_notification_channel(&state.db, &did).await {
113 Ok(true) => {}
114 Ok(false) => {
115 return (
116 StatusCode::FORBIDDEN,
117 Json(json!({
118 "error": "AccountNotVerified",
119 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records"
120 })),
121 )
122 .into_response();
123 }
124 Err(e) => {
125 error!("DB error checking notification channels: {}", e);
126 return (
127 StatusCode::INTERNAL_SERVER_ERROR,
128 Json(json!({"error": "InternalError"})),
129 )
130 .into_response();
131 }
132 }
133 if input.writes.is_empty() {
134 return (
135 StatusCode::BAD_REQUEST,
136 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
137 )
138 .into_response();
139 }
140 if input.writes.len() > MAX_BATCH_WRITES {
141 return (
142 StatusCode::BAD_REQUEST,
143 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
144 )
145 .into_response();
146 }
147 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
148 .fetch_optional(&state.db)
149 .await
150 {
151 Ok(Some(id)) => id,
152 _ => {
153 return (
154 StatusCode::INTERNAL_SERVER_ERROR,
155 Json(json!({"error": "InternalError", "message": "User not found"})),
156 )
157 .into_response();
158 }
159 };
160 let root_cid_str: String = match sqlx::query_scalar!(
161 "SELECT repo_root_cid FROM repos WHERE user_id = $1",
162 user_id
163 )
164 .fetch_optional(&state.db)
165 .await
166 {
167 Ok(Some(cid_str)) => cid_str,
168 _ => {
169 return (
170 StatusCode::INTERNAL_SERVER_ERROR,
171 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
172 )
173 .into_response();
174 }
175 };
176 let current_root_cid = match Cid::from_str(&root_cid_str) {
177 Ok(c) => c,
178 Err(_) => {
179 return (
180 StatusCode::INTERNAL_SERVER_ERROR,
181 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
182 )
183 .into_response();
184 }
185 };
186 if let Some(swap_commit) = &input.swap_commit
187 && Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
188 return (
189 StatusCode::CONFLICT,
190 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
191 )
192 .into_response();
193 }
194 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
195 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
196 Ok(Some(b)) => b,
197 _ => {
198 return (
199 StatusCode::INTERNAL_SERVER_ERROR,
200 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
201 )
202 .into_response();
203 }
204 };
205 let commit = match Commit::from_cbor(&commit_bytes) {
206 Ok(c) => c,
207 _ => {
208 return (
209 StatusCode::INTERNAL_SERVER_ERROR,
210 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
211 )
212 .into_response();
213 }
214 };
215 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
216 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
217 let mut results: Vec<WriteResult> = Vec::new();
218 let mut ops: Vec<RecordOp> = Vec::new();
219 let mut modified_keys: Vec<String> = Vec::new();
220 for write in &input.writes {
221 match write {
222 WriteOp::Create {
223 collection,
224 rkey,
225 value,
226 } => {
227 if input.validate.unwrap_or(true)
228 && let Err(err_response) = validate_record(value, collection) {
229 return *err_response;
230 }
231 let rkey = rkey
232 .clone()
233 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
234 let mut record_bytes = Vec::new();
235 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
236 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
237 }
238 let record_cid = match tracking_store.put(&record_bytes).await {
239 Ok(c) => c,
240 Err(_) => return (
241 StatusCode::INTERNAL_SERVER_ERROR,
242 Json(
243 json!({"error": "InternalError", "message": "Failed to store record"}),
244 ),
245 )
246 .into_response(),
247 };
248 let collection_nsid = match collection.parse::<Nsid>() {
249 Ok(n) => n,
250 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
251 };
252 let key = format!("{}/{}", collection_nsid, rkey);
253 modified_keys.push(key.clone());
254 mst = match mst.add(&key, record_cid).await {
255 Ok(m) => m,
256 Err(_) => return (
257 StatusCode::INTERNAL_SERVER_ERROR,
258 Json(json!({"error": "InternalError", "message": "Failed to add to MST"})),
259 )
260 .into_response(),
261 };
262 let uri = format!("at://{}/{}/{}", did, collection, rkey);
263 results.push(WriteResult::CreateResult {
264 uri,
265 cid: record_cid.to_string(),
266 });
267 ops.push(RecordOp::Create {
268 collection: collection.clone(),
269 rkey,
270 cid: record_cid,
271 });
272 }
273 WriteOp::Update {
274 collection,
275 rkey,
276 value,
277 } => {
278 if input.validate.unwrap_or(true)
279 && let Err(err_response) = validate_record(value, collection) {
280 return *err_response;
281 }
282 let mut record_bytes = Vec::new();
283 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
284 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
285 }
286 let record_cid = match tracking_store.put(&record_bytes).await {
287 Ok(c) => c,
288 Err(_) => return (
289 StatusCode::INTERNAL_SERVER_ERROR,
290 Json(
291 json!({"error": "InternalError", "message": "Failed to store record"}),
292 ),
293 )
294 .into_response(),
295 };
296 let collection_nsid = match collection.parse::<Nsid>() {
297 Ok(n) => n,
298 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
299 };
300 let key = format!("{}/{}", collection_nsid, rkey);
301 modified_keys.push(key.clone());
302 let prev_record_cid = mst.get(&key).await.ok().flatten();
303 mst = match mst.update(&key, record_cid).await {
304 Ok(m) => m,
305 Err(_) => return (
306 StatusCode::INTERNAL_SERVER_ERROR,
307 Json(json!({"error": "InternalError", "message": "Failed to update MST"})),
308 )
309 .into_response(),
310 };
311 let uri = format!("at://{}/{}/{}", did, collection, rkey);
312 results.push(WriteResult::UpdateResult {
313 uri,
314 cid: record_cid.to_string(),
315 });
316 ops.push(RecordOp::Update {
317 collection: collection.clone(),
318 rkey: rkey.clone(),
319 cid: record_cid,
320 prev: prev_record_cid,
321 });
322 }
323 WriteOp::Delete { collection, rkey } => {
324 let collection_nsid = match collection.parse::<Nsid>() {
325 Ok(n) => n,
326 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
327 };
328 let key = format!("{}/{}", collection_nsid, rkey);
329 modified_keys.push(key.clone());
330 let prev_record_cid = mst.get(&key).await.ok().flatten();
331 mst = match mst.delete(&key).await {
332 Ok(m) => m,
333 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
334 };
335 results.push(WriteResult::DeleteResult {});
336 ops.push(RecordOp::Delete {
337 collection: collection.clone(),
338 rkey: rkey.clone(),
339 prev: prev_record_cid,
340 });
341 }
342 }
343 }
344 let new_mst_root = match mst.persist().await {
345 Ok(c) => c,
346 Err(_) => {
347 return (
348 StatusCode::INTERNAL_SERVER_ERROR,
349 Json(json!({"error": "InternalError", "message": "Failed to persist MST"})),
350 )
351 .into_response();
352 }
353 };
354 let mut relevant_blocks = std::collections::BTreeMap::new();
355 for key in &modified_keys {
356 if mst.blocks_for_path(key, &mut relevant_blocks).await.is_err() {
357 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
358 }
359 if original_mst
360 .blocks_for_path(key, &mut relevant_blocks)
361 .await
362 .is_err()
363 {
364 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
365 }
366 }
367 let mut written_cids = tracking_store.get_all_relevant_cids();
368 for cid in relevant_blocks.keys() {
369 if !written_cids.contains(cid) {
370 written_cids.push(*cid);
371 }
372 }
373 let written_cids_str = written_cids
374 .iter()
375 .map(|c| c.to_string())
376 .collect::<Vec<_>>();
377 let commit_res = match commit_and_log(
378 &state,
379 CommitParams {
380 did: &did,
381 user_id,
382 current_root_cid: Some(current_root_cid),
383 prev_data_cid: Some(commit.data),
384 new_mst_root,
385 ops,
386 blocks_cids: &written_cids_str,
387 },
388 )
389 .await
390 {
391 Ok(res) => res,
392 Err(e) => {
393 error!("Commit failed: {}", e);
394 return (
395 StatusCode::INTERNAL_SERVER_ERROR,
396 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
397 )
398 .into_response();
399 }
400 };
401 (
402 StatusCode::OK,
403 Json(ApplyWritesOutput {
404 commit: CommitInfo {
405 cid: commit_res.commit_cid.to_string(),
406 rev: commit_res.rev,
407 },
408 results,
409 }),
410 )
411 .into_response()
412}