this repo has no description
1use super::validation::validate_record;
2use super::write::has_verified_notification_channel;
3use crate::api::repo::record::utils::{commit_and_log, RecordOp};
4use crate::repo::tracking::TrackingBlockStore;
5use crate::state::AppState;
6use axum::{
7 extract::State,
8 http::StatusCode,
9 response::{IntoResponse, Response},
10 Json,
11};
12use cid::Cid;
13use jacquard::types::{integer::LimitedU32, string::{Nsid, Tid}};
14use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::str::FromStr;
18use std::sync::Arc;
19use tracing::error;
/// Upper bound on the number of operations accepted in one `applyWrites` request;
/// larger batches are rejected with `400 InvalidRequest` before any repo work starts.
const MAX_BATCH_WRITES: usize = 200;
21#[derive(Deserialize)]
22#[serde(tag = "$type")]
23pub enum WriteOp {
24 #[serde(rename = "com.atproto.repo.applyWrites#create")]
25 Create {
26 collection: String,
27 rkey: Option<String>,
28 value: serde_json::Value,
29 },
30 #[serde(rename = "com.atproto.repo.applyWrites#update")]
31 Update {
32 collection: String,
33 rkey: String,
34 value: serde_json::Value,
35 },
36 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
37 Delete { collection: String, rkey: String },
38}
39#[derive(Deserialize)]
40#[serde(rename_all = "camelCase")]
41pub struct ApplyWritesInput {
42 pub repo: String,
43 pub validate: Option<bool>,
44 pub writes: Vec<WriteOp>,
45 pub swap_commit: Option<String>,
46}
47#[derive(Serialize)]
48#[serde(tag = "$type")]
49pub enum WriteResult {
50 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
51 CreateResult { uri: String, cid: String },
52 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
53 UpdateResult { uri: String, cid: String },
54 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
55 DeleteResult {},
56}
57#[derive(Serialize)]
58pub struct ApplyWritesOutput {
59 pub commit: CommitInfo,
60 pub results: Vec<WriteResult>,
61}
62#[derive(Serialize)]
63pub struct CommitInfo {
64 pub cid: String,
65 pub rev: String,
66}
67pub async fn apply_writes(
68 State(state): State<AppState>,
69 headers: axum::http::HeaderMap,
70 Json(input): Json<ApplyWritesInput>,
71) -> Response {
72 let token = match crate::auth::extract_bearer_token_from_header(
73 headers.get("Authorization").and_then(|h| h.to_str().ok())
74 ) {
75 Some(t) => t,
76 None => {
77 return (
78 StatusCode::UNAUTHORIZED,
79 Json(json!({"error": "AuthenticationRequired"})),
80 )
81 .into_response();
82 }
83 };
84 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
85 Ok(user) => user,
86 Err(_) => {
87 return (
88 StatusCode::UNAUTHORIZED,
89 Json(json!({"error": "AuthenticationFailed"})),
90 )
91 .into_response();
92 }
93 };
94 let did = auth_user.did;
95 if input.repo != did {
96 return (
97 StatusCode::FORBIDDEN,
98 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
99 )
100 .into_response();
101 }
102 match has_verified_notification_channel(&state.db, &did).await {
103 Ok(true) => {}
104 Ok(false) => {
105 return (
106 StatusCode::FORBIDDEN,
107 Json(json!({
108 "error": "AccountNotVerified",
109 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records"
110 })),
111 )
112 .into_response();
113 }
114 Err(e) => {
115 error!("DB error checking notification channels: {}", e);
116 return (
117 StatusCode::INTERNAL_SERVER_ERROR,
118 Json(json!({"error": "InternalError"})),
119 )
120 .into_response();
121 }
122 }
123 if input.writes.is_empty() {
124 return (
125 StatusCode::BAD_REQUEST,
126 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
127 )
128 .into_response();
129 }
130 if input.writes.len() > MAX_BATCH_WRITES {
131 return (
132 StatusCode::BAD_REQUEST,
133 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
134 )
135 .into_response();
136 }
137 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
138 .fetch_optional(&state.db)
139 .await
140 {
141 Ok(Some(id)) => id,
142 _ => {
143 return (
144 StatusCode::INTERNAL_SERVER_ERROR,
145 Json(json!({"error": "InternalError", "message": "User not found"})),
146 )
147 .into_response();
148 }
149 };
150 let root_cid_str: String =
151 match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
152 .fetch_optional(&state.db)
153 .await
154 {
155 Ok(Some(cid_str)) => cid_str,
156 _ => {
157 return (
158 StatusCode::INTERNAL_SERVER_ERROR,
159 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
160 )
161 .into_response();
162 }
163 };
164 let current_root_cid = match Cid::from_str(&root_cid_str) {
165 Ok(c) => c,
166 Err(_) => {
167 return (
168 StatusCode::INTERNAL_SERVER_ERROR,
169 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
170 )
171 .into_response();
172 }
173 };
174 if let Some(swap_commit) = &input.swap_commit {
175 if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
176 return (
177 StatusCode::CONFLICT,
178 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
179 )
180 .into_response();
181 }
182 }
183 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
184 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
185 Ok(Some(b)) => b,
186 _ => {
187 return (
188 StatusCode::INTERNAL_SERVER_ERROR,
189 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
190 )
191 .into_response()
192 }
193 };
194 let commit = match Commit::from_cbor(&commit_bytes) {
195 Ok(c) => c,
196 _ => {
197 return (
198 StatusCode::INTERNAL_SERVER_ERROR,
199 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
200 )
201 .into_response()
202 }
203 };
204 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
205 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
206 let mut results: Vec<WriteResult> = Vec::new();
207 let mut ops: Vec<RecordOp> = Vec::new();
208 let mut modified_keys: Vec<String> = Vec::new();
209 for write in &input.writes {
210 match write {
211 WriteOp::Create {
212 collection,
213 rkey,
214 value,
215 } => {
216 if input.validate.unwrap_or(true) {
217 if let Err(err_response) = validate_record(value, collection) {
218 return err_response;
219 }
220 }
221 let rkey = rkey
222 .clone()
223 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
224 let mut record_bytes = Vec::new();
225 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
226 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
227 }
228 let record_cid = match tracking_store.put(&record_bytes).await {
229 Ok(c) => c,
230 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
231 };
232 let collection_nsid = match collection.parse::<Nsid>() {
233 Ok(n) => n,
234 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
235 };
236 let key = format!("{}/{}", collection_nsid, rkey);
237 modified_keys.push(key.clone());
238 mst = match mst.add(&key, record_cid).await {
239 Ok(m) => m,
240 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
241 };
242 let uri = format!("at://{}/{}/{}", did, collection, rkey);
243 results.push(WriteResult::CreateResult {
244 uri,
245 cid: record_cid.to_string(),
246 });
247 ops.push(RecordOp::Create {
248 collection: collection.clone(),
249 rkey,
250 cid: record_cid,
251 });
252 }
253 WriteOp::Update {
254 collection,
255 rkey,
256 value,
257 } => {
258 if input.validate.unwrap_or(true) {
259 if let Err(err_response) = validate_record(value, collection) {
260 return err_response;
261 }
262 }
263 let mut record_bytes = Vec::new();
264 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
265 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
266 }
267 let record_cid = match tracking_store.put(&record_bytes).await {
268 Ok(c) => c,
269 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
270 };
271 let collection_nsid = match collection.parse::<Nsid>() {
272 Ok(n) => n,
273 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
274 };
275 let key = format!("{}/{}", collection_nsid, rkey);
276 modified_keys.push(key.clone());
277 let prev_record_cid = mst.get(&key).await.ok().flatten();
278 mst = match mst.update(&key, record_cid).await {
279 Ok(m) => m,
280 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
281 };
282 let uri = format!("at://{}/{}/{}", did, collection, rkey);
283 results.push(WriteResult::UpdateResult {
284 uri,
285 cid: record_cid.to_string(),
286 });
287 ops.push(RecordOp::Update {
288 collection: collection.clone(),
289 rkey: rkey.clone(),
290 cid: record_cid,
291 prev: prev_record_cid,
292 });
293 }
294 WriteOp::Delete { collection, rkey } => {
295 let collection_nsid = match collection.parse::<Nsid>() {
296 Ok(n) => n,
297 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
298 };
299 let key = format!("{}/{}", collection_nsid, rkey);
300 modified_keys.push(key.clone());
301 let prev_record_cid = mst.get(&key).await.ok().flatten();
302 mst = match mst.delete(&key).await {
303 Ok(m) => m,
304 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
305 };
306 results.push(WriteResult::DeleteResult {});
307 ops.push(RecordOp::Delete {
308 collection: collection.clone(),
309 rkey: rkey.clone(),
310 prev: prev_record_cid,
311 });
312 }
313 }
314 }
315 let new_mst_root = match mst.persist().await {
316 Ok(c) => c,
317 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
318 };
319 let mut relevant_blocks = std::collections::BTreeMap::new();
320 for key in &modified_keys {
321 if let Err(_) = mst.blocks_for_path(key, &mut relevant_blocks).await {
322 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
323 }
324 if let Err(_) = original_mst.blocks_for_path(key, &mut relevant_blocks).await {
325 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
326 }
327 }
328 let mut written_cids = tracking_store.get_all_relevant_cids();
329 for cid in relevant_blocks.keys() {
330 if !written_cids.contains(cid) {
331 written_cids.push(*cid);
332 }
333 }
334 let written_cids_str = written_cids
335 .iter()
336 .map(|c| c.to_string())
337 .collect::<Vec<_>>();
338 let commit_res = match commit_and_log(
339 &state,
340 &did,
341 user_id,
342 Some(current_root_cid),
343 Some(commit.data),
344 new_mst_root,
345 ops,
346 &written_cids_str,
347 )
348 .await
349 {
350 Ok(res) => res,
351 Err(e) => {
352 error!("Commit failed: {}", e);
353 return (
354 StatusCode::INTERNAL_SERVER_ERROR,
355 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
356 )
357 .into_response();
358 }
359 };
360 (
361 StatusCode::OK,
362 Json(ApplyWritesOutput {
363 commit: CommitInfo {
364 cid: commit_res.commit_cid.to_string(),
365 rev: commit_res.rev,
366 },
367 results,
368 }),
369 )
370 .into_response()
371}