this repo has no description
1use crate::api::read_after_write::{
2 PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post,
3 format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry,
4};
5use crate::state::AppState;
6use axum::{
7 Json,
8 extract::{Query, State},
9 http::StatusCode,
10 response::{IntoResponse, Response},
11};
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use std::collections::HashMap;
15use tracing::warn;
16
17#[derive(Deserialize)]
18pub struct GetPostThreadParams {
19 pub uri: String,
20 pub depth: Option<u32>,
21 #[serde(rename = "parentHeight")]
22 pub parent_height: Option<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct ThreadViewPost {
28 #[serde(rename = "$type")]
29 pub thread_type: Option<String>,
30 pub post: PostView,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub parent: Option<Box<ThreadNode>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub replies: Option<Vec<ThreadNode>>,
35 #[serde(flatten)]
36 pub extra: HashMap<String, Value>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(untagged)]
41pub enum ThreadNode {
42 Post(Box<ThreadViewPost>),
43 NotFound(ThreadNotFound),
44 Blocked(ThreadBlocked),
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(rename_all = "camelCase")]
49pub struct ThreadNotFound {
50 #[serde(rename = "$type")]
51 pub thread_type: String,
52 pub uri: String,
53 pub not_found: bool,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ThreadBlocked {
59 #[serde(rename = "$type")]
60 pub thread_type: String,
61 pub uri: String,
62 pub blocked: bool,
63 pub author: Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct PostThreadOutput {
68 pub thread: ThreadNode,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub threadgate: Option<Value>,
71}
72
73const MAX_THREAD_DEPTH: usize = 10;
74
75fn add_replies_to_thread(
76 thread: &mut ThreadViewPost,
77 local_posts: &[RecordDescript<PostRecord>],
78 author_did: &str,
79 author_handle: &str,
80 depth: usize,
81) {
82 if depth >= MAX_THREAD_DEPTH {
83 return;
84 }
85 let thread_uri = &thread.post.uri;
86 let replies: Vec<_> = local_posts
87 .iter()
88 .filter(|p| {
89 p.record
90 .reply
91 .as_ref()
92 .and_then(|r| r.get("parent"))
93 .and_then(|parent| parent.get("uri"))
94 .and_then(|u| u.as_str())
95 == Some(thread_uri)
96 })
97 .map(|p| {
98 let post_view = format_local_post(p, author_did, author_handle, None);
99 ThreadNode::Post(Box::new(ThreadViewPost {
100 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
101 post: post_view,
102 parent: None,
103 replies: None,
104 extra: HashMap::new(),
105 }))
106 })
107 .collect();
108 if !replies.is_empty() {
109 match &mut thread.replies {
110 Some(existing) => existing.extend(replies),
111 None => thread.replies = Some(replies),
112 }
113 }
114 if let Some(ref mut existing_replies) = thread.replies {
115 for reply in existing_replies.iter_mut() {
116 if let ThreadNode::Post(reply_thread) = reply {
117 add_replies_to_thread(
118 reply_thread,
119 local_posts,
120 author_did,
121 author_handle,
122 depth + 1,
123 );
124 }
125 }
126 }
127}
128
129pub async fn get_post_thread(
130 State(state): State<AppState>,
131 headers: axum::http::HeaderMap,
132 Query(params): Query<GetPostThreadParams>,
133) -> Response {
134 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
135 let auth_user = if let Some(h) = auth_header {
136 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
137 crate::auth::validate_bearer_token(&state.db, &token)
138 .await
139 .ok()
140 } else {
141 None
142 }
143 } else {
144 None
145 };
146 let auth_did = auth_user.as_ref().map(|u| u.did.clone());
147 let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
148 let mut query_params = HashMap::new();
149 query_params.insert("uri".to_string(), params.uri.clone());
150 if let Some(depth) = params.depth {
151 query_params.insert("depth".to_string(), depth.to_string());
152 }
153 if let Some(parent_height) = params.parent_height {
154 query_params.insert("parentHeight".to_string(), parent_height.to_string());
155 }
156 let proxy_result = match proxy_to_appview_via_registry(
157 &state,
158 "app.bsky.feed.getPostThread",
159 &query_params,
160 auth_did.as_deref().unwrap_or(""),
161 auth_key_bytes.as_deref(),
162 )
163 .await
164 {
165 Ok(r) => r,
166 Err(e) => return e,
167 };
168 if proxy_result.status == StatusCode::NOT_FOUND {
169 return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
170 }
171 if !proxy_result.status.is_success() {
172 return proxy_result.into_response();
173 }
174 let rev = match extract_repo_rev(&proxy_result.headers) {
175 Some(r) => r,
176 None => return proxy_result.into_response(),
177 };
178 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
179 Ok(t) => t,
180 Err(e) => {
181 warn!("Failed to parse post thread response: {:?}", e);
182 return proxy_result.into_response();
183 }
184 };
185 let requester_did = match auth_did {
186 Some(d) => d,
187 None => return (StatusCode::OK, Json(thread_output)).into_response(),
188 };
189 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
190 Ok(r) => r,
191 Err(e) => {
192 warn!("Failed to get local records: {}", e);
193 return proxy_result.into_response();
194 }
195 };
196 if local_records.posts.is_empty() {
197 return (StatusCode::OK, Json(thread_output)).into_response();
198 }
199 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
200 .fetch_optional(&state.db)
201 .await
202 {
203 Ok(Some(h)) => h,
204 Ok(None) => requester_did.clone(),
205 Err(e) => {
206 warn!("Database error fetching handle: {:?}", e);
207 requester_did.clone()
208 }
209 };
210 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
211 add_replies_to_thread(
212 thread_post,
213 &local_records.posts,
214 &requester_did,
215 &handle,
216 0,
217 );
218 }
219 let lag = get_local_lag(&local_records);
220 format_munged_response(thread_output, lag)
221}
222
223async fn handle_not_found(
224 state: &AppState,
225 uri: &str,
226 auth_did: Option<String>,
227 headers: &axum::http::HeaderMap,
228) -> Response {
229 let rev = match extract_repo_rev(headers) {
230 Some(r) => r,
231 None => {
232 return (
233 StatusCode::NOT_FOUND,
234 Json(json!({"error": "NotFound", "message": "Post not found"})),
235 )
236 .into_response();
237 }
238 };
239 let requester_did = match auth_did {
240 Some(d) => d,
241 None => {
242 return (
243 StatusCode::NOT_FOUND,
244 Json(json!({"error": "NotFound", "message": "Post not found"})),
245 )
246 .into_response();
247 }
248 };
249 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
250 if uri_parts.len() != 3 {
251 return (
252 StatusCode::NOT_FOUND,
253 Json(json!({"error": "NotFound", "message": "Post not found"})),
254 )
255 .into_response();
256 }
257 let post_did = uri_parts[0];
258 if post_did != requester_did {
259 return (
260 StatusCode::NOT_FOUND,
261 Json(json!({"error": "NotFound", "message": "Post not found"})),
262 )
263 .into_response();
264 }
265 let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
266 Ok(r) => r,
267 Err(_) => {
268 return (
269 StatusCode::NOT_FOUND,
270 Json(json!({"error": "NotFound", "message": "Post not found"})),
271 )
272 .into_response();
273 }
274 };
275 let local_post = local_records.posts.iter().find(|p| p.uri == uri);
276 let local_post = match local_post {
277 Some(p) => p,
278 None => {
279 return (
280 StatusCode::NOT_FOUND,
281 Json(json!({"error": "NotFound", "message": "Post not found"})),
282 )
283 .into_response();
284 }
285 };
286 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
287 .fetch_optional(&state.db)
288 .await
289 {
290 Ok(Some(h)) => h,
291 Ok(None) => requester_did.clone(),
292 Err(e) => {
293 warn!("Database error fetching handle: {:?}", e);
294 requester_did.clone()
295 }
296 };
297 let post_view = format_local_post(
298 local_post,
299 &requester_did,
300 &handle,
301 local_records.profile.as_ref(),
302 );
303 let thread = PostThreadOutput {
304 thread: ThreadNode::Post(Box::new(ThreadViewPost {
305 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
306 post: post_view,
307 parent: None,
308 replies: None,
309 extra: HashMap::new(),
310 })),
311 threadgate: None,
312 };
313 let lag = get_local_lag(&local_records);
314 format_munged_response(thread, lag)
315}