this repo has no description
1use super::validation::validate_record;
2use super::write::has_verified_notification_channel;
3use crate::api::repo::record::utils::{commit_and_log, RecordOp};
4use crate::repo::tracking::TrackingBlockStore;
5use crate::state::AppState;
6use axum::{
7 extract::State,
8 http::StatusCode,
9 response::{IntoResponse, Response},
10 Json,
11};
12use chrono::Utc;
13use cid::Cid;
14use jacquard::types::string::Nsid;
15use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
16use serde::{Deserialize, Serialize};
17use serde_json::json;
18use std::str::FromStr;
19use std::sync::Arc;
20use tracing::error;
21
22const MAX_BATCH_WRITES: usize = 200;
23
24#[derive(Deserialize)]
25#[serde(tag = "$type")]
26pub enum WriteOp {
27 #[serde(rename = "com.atproto.repo.applyWrites#create")]
28 Create {
29 collection: String,
30 rkey: Option<String>,
31 value: serde_json::Value,
32 },
33 #[serde(rename = "com.atproto.repo.applyWrites#update")]
34 Update {
35 collection: String,
36 rkey: String,
37 value: serde_json::Value,
38 },
39 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
40 Delete { collection: String, rkey: String },
41}
42
43#[derive(Deserialize)]
44#[serde(rename_all = "camelCase")]
45pub struct ApplyWritesInput {
46 pub repo: String,
47 pub validate: Option<bool>,
48 pub writes: Vec<WriteOp>,
49 pub swap_commit: Option<String>,
50}
51
52#[derive(Serialize)]
53#[serde(tag = "$type")]
54pub enum WriteResult {
55 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
56 CreateResult { uri: String, cid: String },
57 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
58 UpdateResult { uri: String, cid: String },
59 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
60 DeleteResult {},
61}
62
63#[derive(Serialize)]
64pub struct ApplyWritesOutput {
65 pub commit: CommitInfo,
66 pub results: Vec<WriteResult>,
67}
68
69#[derive(Serialize)]
70pub struct CommitInfo {
71 pub cid: String,
72 pub rev: String,
73}
74
75pub async fn apply_writes(
76 State(state): State<AppState>,
77 headers: axum::http::HeaderMap,
78 Json(input): Json<ApplyWritesInput>,
79) -> Response {
80 let token = match crate::auth::extract_bearer_token_from_header(
81 headers.get("Authorization").and_then(|h| h.to_str().ok())
82 ) {
83 Some(t) => t,
84 None => {
85 return (
86 StatusCode::UNAUTHORIZED,
87 Json(json!({"error": "AuthenticationRequired"})),
88 )
89 .into_response();
90 }
91 };
92
93 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
94 Ok(user) => user,
95 Err(_) => {
96 return (
97 StatusCode::UNAUTHORIZED,
98 Json(json!({"error": "AuthenticationFailed"})),
99 )
100 .into_response();
101 }
102 };
103
104 let did = auth_user.did;
105
106 if input.repo != did {
107 return (
108 StatusCode::FORBIDDEN,
109 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
110 )
111 .into_response();
112 }
113
114 match has_verified_notification_channel(&state.db, &did).await {
115 Ok(true) => {}
116 Ok(false) => {
117 return (
118 StatusCode::FORBIDDEN,
119 Json(json!({
120 "error": "AccountNotVerified",
121 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records"
122 })),
123 )
124 .into_response();
125 }
126 Err(e) => {
127 error!("DB error checking notification channels: {}", e);
128 return (
129 StatusCode::INTERNAL_SERVER_ERROR,
130 Json(json!({"error": "InternalError"})),
131 )
132 .into_response();
133 }
134 }
135
136 if input.writes.is_empty() {
137 return (
138 StatusCode::BAD_REQUEST,
139 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
140 )
141 .into_response();
142 }
143
144 if input.writes.len() > MAX_BATCH_WRITES {
145 return (
146 StatusCode::BAD_REQUEST,
147 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
148 )
149 .into_response();
150 }
151
152 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
153 .fetch_optional(&state.db)
154 .await
155 {
156 Ok(Some(id)) => id,
157 _ => {
158 return (
159 StatusCode::INTERNAL_SERVER_ERROR,
160 Json(json!({"error": "InternalError", "message": "User not found"})),
161 )
162 .into_response();
163 }
164 };
165
166 let root_cid_str: String =
167 match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
168 .fetch_optional(&state.db)
169 .await
170 {
171 Ok(Some(cid_str)) => cid_str,
172 _ => {
173 return (
174 StatusCode::INTERNAL_SERVER_ERROR,
175 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
176 )
177 .into_response();
178 }
179 };
180
181 let current_root_cid = match Cid::from_str(&root_cid_str) {
182 Ok(c) => c,
183 Err(_) => {
184 return (
185 StatusCode::INTERNAL_SERVER_ERROR,
186 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
187 )
188 .into_response();
189 }
190 };
191
192 if let Some(swap_commit) = &input.swap_commit {
193 if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
194 return (
195 StatusCode::CONFLICT,
196 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
197 )
198 .into_response();
199 }
200 }
201
202 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
203
204 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
205 Ok(Some(b)) => b,
206 _ => {
207 return (
208 StatusCode::INTERNAL_SERVER_ERROR,
209 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
210 )
211 .into_response()
212 }
213 };
214
215 let commit = match Commit::from_cbor(&commit_bytes) {
216 Ok(c) => c,
217 _ => {
218 return (
219 StatusCode::INTERNAL_SERVER_ERROR,
220 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
221 )
222 .into_response()
223 }
224 };
225
226 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
227
228 let mut results: Vec<WriteResult> = Vec::new();
229 let mut ops: Vec<RecordOp> = Vec::new();
230
231 for write in &input.writes {
232 match write {
233 WriteOp::Create {
234 collection,
235 rkey,
236 value,
237 } => {
238 if input.validate.unwrap_or(true) {
239 if let Err(err_response) = validate_record(value, collection) {
240 return err_response;
241 }
242 }
243 let rkey = rkey
244 .clone()
245 .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
246 let mut record_bytes = Vec::new();
247 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
248 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
249 }
250 let record_cid = match tracking_store.put(&record_bytes).await {
251 Ok(c) => c,
252 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
253 };
254
255 let collection_nsid = match collection.parse::<Nsid>() {
256 Ok(n) => n,
257 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
258 };
259 let key = format!("{}/{}", collection_nsid, rkey);
260 mst = match mst.add(&key, record_cid).await {
261 Ok(m) => m,
262 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
263 };
264
265 let uri = format!("at://{}/{}/{}", did, collection, rkey);
266 results.push(WriteResult::CreateResult {
267 uri,
268 cid: record_cid.to_string(),
269 });
270 ops.push(RecordOp::Create {
271 collection: collection.clone(),
272 rkey,
273 cid: record_cid,
274 });
275 }
276 WriteOp::Update {
277 collection,
278 rkey,
279 value,
280 } => {
281 if input.validate.unwrap_or(true) {
282 if let Err(err_response) = validate_record(value, collection) {
283 return err_response;
284 }
285 }
286 let mut record_bytes = Vec::new();
287 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
288 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
289 }
290 let record_cid = match tracking_store.put(&record_bytes).await {
291 Ok(c) => c,
292 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
293 };
294
295 let collection_nsid = match collection.parse::<Nsid>() {
296 Ok(n) => n,
297 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
298 };
299 let key = format!("{}/{}", collection_nsid, rkey);
300 mst = match mst.update(&key, record_cid).await {
301 Ok(m) => m,
302 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
303 };
304
305 let uri = format!("at://{}/{}/{}", did, collection, rkey);
306 results.push(WriteResult::UpdateResult {
307 uri,
308 cid: record_cid.to_string(),
309 });
310 ops.push(RecordOp::Update {
311 collection: collection.clone(),
312 rkey: rkey.clone(),
313 cid: record_cid,
314 });
315 }
316 WriteOp::Delete { collection, rkey } => {
317 let collection_nsid = match collection.parse::<Nsid>() {
318 Ok(n) => n,
319 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
320 };
321 let key = format!("{}/{}", collection_nsid, rkey);
322 mst = match mst.delete(&key).await {
323 Ok(m) => m,
324 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
325 };
326
327 results.push(WriteResult::DeleteResult {});
328 ops.push(RecordOp::Delete {
329 collection: collection.clone(),
330 rkey: rkey.clone(),
331 });
332 }
333 }
334 }
335
336 let new_mst_root = match mst.persist().await {
337 Ok(c) => c,
338 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
339 };
340 let written_cids = tracking_store.get_written_cids();
341 let written_cids_str = written_cids
342 .iter()
343 .map(|c| c.to_string())
344 .collect::<Vec<_>>();
345
346 let commit_res = match commit_and_log(
347 &state,
348 &did,
349 user_id,
350 Some(current_root_cid),
351 new_mst_root,
352 ops,
353 &written_cids_str,
354 )
355 .await
356 {
357 Ok(res) => res,
358 Err(e) => {
359 error!("Commit failed: {}", e);
360 return (
361 StatusCode::INTERNAL_SERVER_ERROR,
362 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
363 )
364 .into_response();
365 }
366 };
367
368 (
369 StatusCode::OK,
370 Json(ApplyWritesOutput {
371 commit: CommitInfo {
372 cid: commit_res.commit_cid.to_string(),
373 rev: commit_res.rev,
374 },
375 results,
376 }),
377 )
378 .into_response()
379}