this repo has no description
1use crate::api::read_after_write::{
2 extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
3 get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript,
4};
5use crate::state::AppState;
6use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{IntoResponse, Response},
10 Json,
11};
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use tracing::warn;
16#[derive(Deserialize)]
17pub struct GetPostThreadParams {
18 pub uri: String,
19 pub depth: Option<u32>,
20 #[serde(rename = "parentHeight")]
21 pub parent_height: Option<u32>,
22}
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25pub struct ThreadViewPost {
26 #[serde(rename = "$type")]
27 pub thread_type: Option<String>,
28 pub post: PostView,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub parent: Option<Box<ThreadNode>>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub replies: Option<Vec<ThreadNode>>,
33 #[serde(flatten)]
34 pub extra: HashMap<String, Value>,
35}
36#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(untagged)]
38pub enum ThreadNode {
39 Post(ThreadViewPost),
40 NotFound(ThreadNotFound),
41 Blocked(ThreadBlocked),
42}
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(rename_all = "camelCase")]
45pub struct ThreadNotFound {
46 #[serde(rename = "$type")]
47 pub thread_type: String,
48 pub uri: String,
49 pub not_found: bool,
50}
51#[derive(Debug, Clone, Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53pub struct ThreadBlocked {
54 #[serde(rename = "$type")]
55 pub thread_type: String,
56 pub uri: String,
57 pub blocked: bool,
58 pub author: Value,
59}
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct PostThreadOutput {
62 pub thread: ThreadNode,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub threadgate: Option<Value>,
65}
66const MAX_THREAD_DEPTH: usize = 10;
67fn add_replies_to_thread(
68 thread: &mut ThreadViewPost,
69 local_posts: &[RecordDescript<PostRecord>],
70 author_did: &str,
71 author_handle: &str,
72 depth: usize,
73) {
74 if depth >= MAX_THREAD_DEPTH {
75 return;
76 }
77 let thread_uri = &thread.post.uri;
78 let replies: Vec<_> = local_posts
79 .iter()
80 .filter(|p| {
81 p.record
82 .reply
83 .as_ref()
84 .and_then(|r| r.get("parent"))
85 .and_then(|parent| parent.get("uri"))
86 .and_then(|u| u.as_str())
87 == Some(thread_uri)
88 })
89 .map(|p| {
90 let post_view = format_local_post(p, author_did, author_handle, None);
91 ThreadNode::Post(ThreadViewPost {
92 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
93 post: post_view,
94 parent: None,
95 replies: None,
96 extra: HashMap::new(),
97 })
98 })
99 .collect();
100 if !replies.is_empty() {
101 match &mut thread.replies {
102 Some(existing) => existing.extend(replies),
103 None => thread.replies = Some(replies),
104 }
105 }
106 if let Some(ref mut existing_replies) = thread.replies {
107 for reply in existing_replies.iter_mut() {
108 if let ThreadNode::Post(reply_thread) = reply {
109 add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1);
110 }
111 }
112 }
113}
114pub async fn get_post_thread(
115 State(state): State<AppState>,
116 headers: axum::http::HeaderMap,
117 Query(params): Query<GetPostThreadParams>,
118) -> Response {
119 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
120 let auth_did = if let Some(h) = auth_header {
121 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
122 match crate::auth::validate_bearer_token(&state.db, &token).await {
123 Ok(user) => Some(user.did),
124 Err(_) => None,
125 }
126 } else {
127 None
128 }
129 } else {
130 None
131 };
132 let mut query_params = HashMap::new();
133 query_params.insert("uri".to_string(), params.uri.clone());
134 if let Some(depth) = params.depth {
135 query_params.insert("depth".to_string(), depth.to_string());
136 }
137 if let Some(parent_height) = params.parent_height {
138 query_params.insert("parentHeight".to_string(), parent_height.to_string());
139 }
140 let proxy_result =
141 match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await {
142 Ok(r) => r,
143 Err(e) => return e,
144 };
145 if proxy_result.status == StatusCode::NOT_FOUND {
146 return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
147 }
148 if !proxy_result.status.is_success() {
149 return (proxy_result.status, proxy_result.body).into_response();
150 }
151 let rev = match extract_repo_rev(&proxy_result.headers) {
152 Some(r) => r,
153 None => return (proxy_result.status, proxy_result.body).into_response(),
154 };
155 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
156 Ok(t) => t,
157 Err(e) => {
158 warn!("Failed to parse post thread response: {:?}", e);
159 return (proxy_result.status, proxy_result.body).into_response();
160 }
161 };
162 let requester_did = match auth_did {
163 Some(d) => d,
164 None => return (StatusCode::OK, Json(thread_output)).into_response(),
165 };
166 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
167 Ok(r) => r,
168 Err(e) => {
169 warn!("Failed to get local records: {}", e);
170 return (proxy_result.status, proxy_result.body).into_response();
171 }
172 };
173 if local_records.posts.is_empty() {
174 return (StatusCode::OK, Json(thread_output)).into_response();
175 }
176 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
177 .fetch_optional(&state.db)
178 .await
179 {
180 Ok(Some(h)) => h,
181 Ok(None) => requester_did.clone(),
182 Err(e) => {
183 warn!("Database error fetching handle: {:?}", e);
184 requester_did.clone()
185 }
186 };
187 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
188 add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0);
189 }
190 let lag = get_local_lag(&local_records);
191 format_munged_response(thread_output, lag)
192}
193async fn handle_not_found(
194 state: &AppState,
195 uri: &str,
196 auth_did: Option<String>,
197 headers: &axum::http::HeaderMap,
198) -> Response {
199 let rev = match extract_repo_rev(headers) {
200 Some(r) => r,
201 None => {
202 return (
203 StatusCode::NOT_FOUND,
204 Json(json!({"error": "NotFound", "message": "Post not found"})),
205 )
206 .into_response()
207 }
208 };
209 let requester_did = match auth_did {
210 Some(d) => d,
211 None => {
212 return (
213 StatusCode::NOT_FOUND,
214 Json(json!({"error": "NotFound", "message": "Post not found"})),
215 )
216 .into_response()
217 }
218 };
219 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
220 if uri_parts.len() != 3 {
221 return (
222 StatusCode::NOT_FOUND,
223 Json(json!({"error": "NotFound", "message": "Post not found"})),
224 )
225 .into_response();
226 }
227 let post_did = uri_parts[0];
228 if post_did != requester_did {
229 return (
230 StatusCode::NOT_FOUND,
231 Json(json!({"error": "NotFound", "message": "Post not found"})),
232 )
233 .into_response();
234 }
235 let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
236 Ok(r) => r,
237 Err(_) => {
238 return (
239 StatusCode::NOT_FOUND,
240 Json(json!({"error": "NotFound", "message": "Post not found"})),
241 )
242 .into_response()
243 }
244 };
245 let local_post = local_records.posts.iter().find(|p| p.uri == uri);
246 let local_post = match local_post {
247 Some(p) => p,
248 None => {
249 return (
250 StatusCode::NOT_FOUND,
251 Json(json!({"error": "NotFound", "message": "Post not found"})),
252 )
253 .into_response()
254 }
255 };
256 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
257 .fetch_optional(&state.db)
258 .await
259 {
260 Ok(Some(h)) => h,
261 Ok(None) => requester_did.clone(),
262 Err(e) => {
263 warn!("Database error fetching handle: {:?}", e);
264 requester_did.clone()
265 }
266 };
267 let post_view = format_local_post(
268 local_post,
269 &requester_did,
270 &handle,
271 local_records.profile.as_ref(),
272 );
273 let thread = PostThreadOutput {
274 thread: ThreadNode::Post(ThreadViewPost {
275 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
276 post: post_view,
277 parent: None,
278 replies: None,
279 extra: HashMap::new(),
280 }),
281 threadgate: None,
282 };
283 let lag = get_local_lag(&local_records);
284 format_munged_response(thread, lag)
285}