this repo has no description
1use crate::api::read_after_write::{
2 extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
3 get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript,
4};
5use crate::state::AppState;
6use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{IntoResponse, Response},
10 Json,
11};
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use tracing::warn;
16
17#[derive(Deserialize)]
18pub struct GetPostThreadParams {
19 pub uri: String,
20 pub depth: Option<u32>,
21 #[serde(rename = "parentHeight")]
22 pub parent_height: Option<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct ThreadViewPost {
28 #[serde(rename = "$type")]
29 pub thread_type: Option<String>,
30 pub post: PostView,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub parent: Option<Box<ThreadNode>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub replies: Option<Vec<ThreadNode>>,
35 #[serde(flatten)]
36 pub extra: HashMap<String, Value>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(untagged)]
41pub enum ThreadNode {
42 Post(ThreadViewPost),
43 NotFound(ThreadNotFound),
44 Blocked(ThreadBlocked),
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(rename_all = "camelCase")]
49pub struct ThreadNotFound {
50 #[serde(rename = "$type")]
51 pub thread_type: String,
52 pub uri: String,
53 pub not_found: bool,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ThreadBlocked {
59 #[serde(rename = "$type")]
60 pub thread_type: String,
61 pub uri: String,
62 pub blocked: bool,
63 pub author: Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct PostThreadOutput {
68 pub thread: ThreadNode,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub threadgate: Option<Value>,
71}
72
73const MAX_THREAD_DEPTH: usize = 10;
74
75fn add_replies_to_thread(
76 thread: &mut ThreadViewPost,
77 local_posts: &[RecordDescript<PostRecord>],
78 author_did: &str,
79 author_handle: &str,
80 depth: usize,
81) {
82 if depth >= MAX_THREAD_DEPTH {
83 return;
84 }
85 let thread_uri = &thread.post.uri;
86 let replies: Vec<_> = local_posts
87 .iter()
88 .filter(|p| {
89 p.record
90 .reply
91 .as_ref()
92 .and_then(|r| r.get("parent"))
93 .and_then(|parent| parent.get("uri"))
94 .and_then(|u| u.as_str())
95 == Some(thread_uri)
96 })
97 .map(|p| {
98 let post_view = format_local_post(p, author_did, author_handle, None);
99 ThreadNode::Post(ThreadViewPost {
100 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
101 post: post_view,
102 parent: None,
103 replies: None,
104 extra: HashMap::new(),
105 })
106 })
107 .collect();
108 if !replies.is_empty() {
109 match &mut thread.replies {
110 Some(existing) => existing.extend(replies),
111 None => thread.replies = Some(replies),
112 }
113 }
114 if let Some(ref mut existing_replies) = thread.replies {
115 for reply in existing_replies.iter_mut() {
116 if let ThreadNode::Post(reply_thread) = reply {
117 add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1);
118 }
119 }
120 }
121}
122
123pub async fn get_post_thread(
124 State(state): State<AppState>,
125 headers: axum::http::HeaderMap,
126 Query(params): Query<GetPostThreadParams>,
127) -> Response {
128 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
129 let auth_user = if let Some(h) = auth_header {
130 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
131 crate::auth::validate_bearer_token(&state.db, &token).await.ok()
132 } else {
133 None
134 }
135 } else {
136 None
137 };
138 let auth_did = auth_user.as_ref().map(|u| u.did.clone());
139 let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
140 let mut query_params = HashMap::new();
141 query_params.insert("uri".to_string(), params.uri.clone());
142 if let Some(depth) = params.depth {
143 query_params.insert("depth".to_string(), depth.to_string());
144 }
145 if let Some(parent_height) = params.parent_height {
146 query_params.insert("parentHeight".to_string(), parent_height.to_string());
147 }
148 let proxy_result =
149 match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_did.as_deref().unwrap_or(""), auth_key_bytes.as_deref()).await {
150 Ok(r) => r,
151 Err(e) => return e,
152 };
153 if proxy_result.status == StatusCode::NOT_FOUND {
154 return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
155 }
156 if !proxy_result.status.is_success() {
157 return proxy_result.into_response();
158 }
159 let rev = match extract_repo_rev(&proxy_result.headers) {
160 Some(r) => r,
161 None => return proxy_result.into_response(),
162 };
163 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
164 Ok(t) => t,
165 Err(e) => {
166 warn!("Failed to parse post thread response: {:?}", e);
167 return proxy_result.into_response();
168 }
169 };
170 let requester_did = match auth_did {
171 Some(d) => d,
172 None => return (StatusCode::OK, Json(thread_output)).into_response(),
173 };
174 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
175 Ok(r) => r,
176 Err(e) => {
177 warn!("Failed to get local records: {}", e);
178 return proxy_result.into_response();
179 }
180 };
181 if local_records.posts.is_empty() {
182 return (StatusCode::OK, Json(thread_output)).into_response();
183 }
184 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
185 .fetch_optional(&state.db)
186 .await
187 {
188 Ok(Some(h)) => h,
189 Ok(None) => requester_did.clone(),
190 Err(e) => {
191 warn!("Database error fetching handle: {:?}", e);
192 requester_did.clone()
193 }
194 };
195 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
196 add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0);
197 }
198 let lag = get_local_lag(&local_records);
199 format_munged_response(thread_output, lag)
200}
201
202async fn handle_not_found(
203 state: &AppState,
204 uri: &str,
205 auth_did: Option<String>,
206 headers: &axum::http::HeaderMap,
207) -> Response {
208 let rev = match extract_repo_rev(headers) {
209 Some(r) => r,
210 None => {
211 return (
212 StatusCode::NOT_FOUND,
213 Json(json!({"error": "NotFound", "message": "Post not found"})),
214 )
215 .into_response()
216 }
217 };
218 let requester_did = match auth_did {
219 Some(d) => d,
220 None => {
221 return (
222 StatusCode::NOT_FOUND,
223 Json(json!({"error": "NotFound", "message": "Post not found"})),
224 )
225 .into_response()
226 }
227 };
228 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
229 if uri_parts.len() != 3 {
230 return (
231 StatusCode::NOT_FOUND,
232 Json(json!({"error": "NotFound", "message": "Post not found"})),
233 )
234 .into_response();
235 }
236 let post_did = uri_parts[0];
237 if post_did != requester_did {
238 return (
239 StatusCode::NOT_FOUND,
240 Json(json!({"error": "NotFound", "message": "Post not found"})),
241 )
242 .into_response();
243 }
244 let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
245 Ok(r) => r,
246 Err(_) => {
247 return (
248 StatusCode::NOT_FOUND,
249 Json(json!({"error": "NotFound", "message": "Post not found"})),
250 )
251 .into_response()
252 }
253 };
254 let local_post = local_records.posts.iter().find(|p| p.uri == uri);
255 let local_post = match local_post {
256 Some(p) => p,
257 None => {
258 return (
259 StatusCode::NOT_FOUND,
260 Json(json!({"error": "NotFound", "message": "Post not found"})),
261 )
262 .into_response()
263 }
264 };
265 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
266 .fetch_optional(&state.db)
267 .await
268 {
269 Ok(Some(h)) => h,
270 Ok(None) => requester_did.clone(),
271 Err(e) => {
272 warn!("Database error fetching handle: {:?}", e);
273 requester_did.clone()
274 }
275 };
276 let post_view = format_local_post(
277 local_post,
278 &requester_did,
279 &handle,
280 local_records.profile.as_ref(),
281 );
282 let thread = PostThreadOutput {
283 thread: ThreadNode::Post(ThreadViewPost {
284 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
285 post: post_view,
286 parent: None,
287 replies: None,
288 extra: HashMap::new(),
289 }),
290 threadgate: None,
291 };
292 let lag = get_local_lag(&local_records);
293 format_munged_response(thread, lag)
294}