this repo has no description
1use crate::api::read_after_write::{ 2 extract_repo_rev, format_local_post, format_munged_response, get_local_lag, 3 get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript, 4}; 5use crate::state::AppState; 6use axum::{ 7 extract::{Query, State}, 8 http::StatusCode, 9 response::{IntoResponse, Response}, 10 Json, 11}; 12use serde::{Deserialize, Serialize}; 13use serde_json::{json, Value}; 14use std::collections::HashMap; 15use tracing::warn; 16 17#[derive(Deserialize)] 18pub struct GetPostThreadParams { 19 pub uri: String, 20 pub depth: Option<u32>, 21 #[serde(rename = "parentHeight")] 22 pub parent_height: Option<u32>, 23} 24 25#[derive(Debug, Clone, Serialize, Deserialize)] 26#[serde(rename_all = "camelCase")] 27pub struct ThreadViewPost { 28 #[serde(rename = "$type")] 29 pub thread_type: Option<String>, 30 pub post: PostView, 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub parent: Option<Box<ThreadNode>>, 33 #[serde(skip_serializing_if = "Option::is_none")] 34 pub replies: Option<Vec<ThreadNode>>, 35 #[serde(flatten)] 36 pub extra: HashMap<String, Value>, 37} 38 39#[derive(Debug, Clone, Serialize, Deserialize)] 40#[serde(untagged)] 41pub enum ThreadNode { 42 Post(ThreadViewPost), 43 NotFound(ThreadNotFound), 44 Blocked(ThreadBlocked), 45} 46 47#[derive(Debug, Clone, Serialize, Deserialize)] 48#[serde(rename_all = "camelCase")] 49pub struct ThreadNotFound { 50 #[serde(rename = "$type")] 51 pub thread_type: String, 52 pub uri: String, 53 pub not_found: bool, 54} 55 56#[derive(Debug, Clone, Serialize, Deserialize)] 57#[serde(rename_all = "camelCase")] 58pub struct ThreadBlocked { 59 #[serde(rename = "$type")] 60 pub thread_type: String, 61 pub uri: String, 62 pub blocked: bool, 63 pub author: Value, 64} 65 66#[derive(Debug, Clone, Serialize, Deserialize)] 67pub struct PostThreadOutput { 68 pub thread: ThreadNode, 69 #[serde(skip_serializing_if = "Option::is_none")] 70 pub threadgate: Option<Value>, 71} 72 73const MAX_THREAD_DEPTH: usize = 10; 74 75fn add_replies_to_thread( 76 thread: &mut ThreadViewPost, 77 local_posts: &[RecordDescript<PostRecord>], 78 author_did: &str, 79 author_handle: &str, 80 depth: usize, 81) { 82 if depth >= MAX_THREAD_DEPTH { 83 return; 84 } 85 86 let thread_uri = &thread.post.uri; 87 88 let replies: Vec<_> = local_posts 89 .iter() 90 .filter(|p| { 91 p.record 92 .reply 93 .as_ref() 94 .and_then(|r| r.get("parent")) 95 .and_then(|parent| parent.get("uri")) 96 .and_then(|u| u.as_str()) 97 == Some(thread_uri) 98 }) 99 .map(|p| { 100 let post_view = format_local_post(p, author_did, author_handle, None); 101 ThreadNode::Post(ThreadViewPost { 102 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 103 post: post_view, 104 parent: None, 105 replies: None, 106 extra: HashMap::new(), 107 }) 108 }) 109 .collect(); 110 111 if !replies.is_empty() { 112 match &mut thread.replies { 113 Some(existing) => existing.extend(replies), 114 None => thread.replies = Some(replies), 115 } 116 } 117 118 if let Some(ref mut existing_replies) = thread.replies { 119 for reply in existing_replies.iter_mut() { 120 if let ThreadNode::Post(reply_thread) = reply { 121 add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1); 122 } 123 } 124 } 125} 126 127pub async fn get_post_thread( 128 State(state): State<AppState>, 129 headers: axum::http::HeaderMap, 130 Query(params): Query<GetPostThreadParams>, 131) -> Response { 132 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 133 134 let auth_did = if let Some(h) = auth_header { 135 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 136 match crate::auth::validate_bearer_token(&state.db, &token).await { 137 Ok(user) => Some(user.did), 138 Err(_) => None, 139 } 140 } else { 141 None 142 } 143 } else { 144 None 145 }; 146 147 let mut query_params = HashMap::new(); 148 query_params.insert("uri".to_string(), params.uri.clone()); 149 if let Some(depth) = params.depth { 150 query_params.insert("depth".to_string(), depth.to_string()); 151 } 152 if let Some(parent_height) = params.parent_height { 153 query_params.insert("parentHeight".to_string(), parent_height.to_string()); 154 } 155 156 let proxy_result = 157 match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await { 158 Ok(r) => r, 159 Err(e) => return e, 160 }; 161 162 if proxy_result.status == StatusCode::NOT_FOUND { 163 return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await; 164 } 165 166 if !proxy_result.status.is_success() { 167 return (proxy_result.status, proxy_result.body).into_response(); 168 } 169 170 let rev = match extract_repo_rev(&proxy_result.headers) { 171 Some(r) => r, 172 None => return (proxy_result.status, proxy_result.body).into_response(), 173 }; 174 175 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) { 176 Ok(t) => t, 177 Err(e) => { 178 warn!("Failed to parse post thread response: {:?}", e); 179 return (proxy_result.status, proxy_result.body).into_response(); 180 } 181 }; 182 183 let requester_did = match auth_did { 184 Some(d) => d, 185 None => return (StatusCode::OK, Json(thread_output)).into_response(), 186 }; 187 188 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { 189 Ok(r) => r, 190 Err(e) => { 191 warn!("Failed to get local records: {}", e); 192 return (proxy_result.status, proxy_result.body).into_response(); 193 } 194 }; 195 196 if local_records.posts.is_empty() { 197 return (StatusCode::OK, Json(thread_output)).into_response(); 198 } 199 200 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 201 .fetch_optional(&state.db) 202 .await 203 { 204 Ok(Some(h)) => h, 205 Ok(None) => requester_did.clone(), 206 Err(e) => { 207 warn!("Database error fetching handle: {:?}", e); 208 requester_did.clone() 209 } 210 }; 211 212 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread { 213 add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0); 214 } 215 216 let lag = get_local_lag(&local_records); 217 format_munged_response(thread_output, lag) 218} 219 220async fn handle_not_found( 221 state: &AppState, 222 uri: &str, 223 auth_did: Option<String>, 224 headers: &axum::http::HeaderMap, 225) -> Response { 226 let rev = match extract_repo_rev(headers) { 227 Some(r) => r, 228 None => { 229 return ( 230 StatusCode::NOT_FOUND, 231 Json(json!({"error": "NotFound", "message": "Post not found"})), 232 ) 233 .into_response() 234 } 235 }; 236 237 let requester_did = match auth_did { 238 Some(d) => d, 239 None => { 240 return ( 241 StatusCode::NOT_FOUND, 242 Json(json!({"error": "NotFound", "message": "Post not found"})), 243 ) 244 .into_response() 245 } 246 }; 247 248 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect(); 249 if uri_parts.len() != 3 { 250 return ( 251 StatusCode::NOT_FOUND, 252 Json(json!({"error": "NotFound", "message": "Post not found"})), 253 ) 254 .into_response(); 255 } 256 257 let post_did = uri_parts[0]; 258 if post_did != requester_did { 259 return ( 260 StatusCode::NOT_FOUND, 261 Json(json!({"error": "NotFound", "message": "Post not found"})), 262 ) 263 .into_response(); 264 } 265 266 let local_records = match get_records_since_rev(state, &requester_did, &rev).await { 267 Ok(r) => r, 268 Err(_) => { 269 return ( 270 StatusCode::NOT_FOUND, 271 Json(json!({"error": "NotFound", "message": "Post not found"})), 272 ) 273 .into_response() 274 } 275 }; 276 277 let local_post = local_records.posts.iter().find(|p| p.uri == uri); 278 279 let local_post = match local_post { 280 Some(p) => p, 281 None => { 282 return ( 283 StatusCode::NOT_FOUND, 284 Json(json!({"error": "NotFound", "message": "Post not found"})), 285 ) 286 .into_response() 287 } 288 }; 289 290 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 291 .fetch_optional(&state.db) 292 .await 293 { 294 Ok(Some(h)) => h, 295 Ok(None) => requester_did.clone(), 296 Err(e) => { 297 warn!("Database error fetching handle: {:?}", e); 298 requester_did.clone() 299 } 300 }; 301 302 let post_view = format_local_post( 303 local_post, 304 &requester_did, 305 &handle, 306 local_records.profile.as_ref(), 307 ); 308 309 let thread = PostThreadOutput { 310 thread: ThreadNode::Post(ThreadViewPost { 311 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 312 post: post_view, 313 parent: None, 314 replies: None, 315 extra: HashMap::new(), 316 }), 317 threadgate: None, 318 }; 319 320 let lag = get_local_lag(&local_records); 321 format_munged_response(thread, lag) 322}