// This repository has no description.
use crate::state::AppState;
use axum::{
    body::Bytes,
    extract::{Path, Query, State},
    http::{HeaderMap, HeaderValue, Method, StatusCode},
    response::{IntoResponse, Response},
};
use reqwest::Client;
use sqlx::Row;
use std::collections::HashMap;
use std::sync::OnceLock;
use tracing::{error, info, warn};

13pub async fn proxy_handler(
14 State(state): State<AppState>,
15 Path(method): Path<String>,
16 method_verb: Method,
17 headers: HeaderMap,
18 Query(params): Query<HashMap<String, String>>,
19 body: Bytes,
20) -> Response {
21 let proxy_header = headers
22 .get("atproto-proxy")
23 .and_then(|h| h.to_str().ok())
24 .map(|s| s.to_string());
25
26 let appview_url = match &proxy_header {
27 Some(url) => url.clone(),
28 None => match std::env::var("APPVIEW_URL") {
29 Ok(url) => url,
30 Err(_) => {
31 return (StatusCode::BAD_GATEWAY, "No upstream AppView configured").into_response();
32 }
33 },
34 };
35
36 let target_url = format!("{}/xrpc/{}", appview_url, method);
37
38 info!("Proxying {} request to {}", method_verb, target_url);
39
40 let client = Client::new();
41
42 let mut request_builder = client.request(method_verb, &target_url).query(¶ms);
43
44 let mut auth_header_val = headers.get("Authorization").map(|h| h.clone());
45
46 if let Some(aud) = &proxy_header {
47 if let Some(auth_val) = &auth_header_val {
48 if let Ok(token) = auth_val.to_str() {
49 let token = token.replace("Bearer ", "");
50 if let Ok(did) = crate::auth::get_did_from_token(&token) {
51 let key_row = sqlx::query("SELECT k.key_bytes FROM user_keys k JOIN users u ON k.user_id = u.id WHERE u.did = $1")
52 .bind(&did)
53 .fetch_optional(&state.db)
54 .await;
55
56 if let Ok(Some(row)) = key_row {
57 let key_bytes: Vec<u8> = row.get("key_bytes");
58 if let Ok(new_token) =
59 crate::auth::create_service_token(&did, aud, &method, &key_bytes)
60 {
61 if let Ok(val) =
62 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token))
63 {
64 auth_header_val = Some(val);
65 }
66 }
67 }
68 }
69 }
70 }
71 }
72
73 if let Some(val) = auth_header_val {
74 request_builder = request_builder.header("Authorization", val);
75 }
76
77 for (key, value) in headers.iter() {
78 if key != "host" && key != "content-length" && key != "authorization" {
79 request_builder = request_builder.header(key, value);
80 }
81 }
82
83 request_builder = request_builder.body(body);
84
85 match request_builder.send().await {
86 Ok(resp) => {
87 let status = resp.status();
88 let headers = resp.headers().clone();
89 let body = match resp.bytes().await {
90 Ok(b) => b,
91 Err(e) => {
92 error!("Error reading proxy response body: {:?}", e);
93 return (StatusCode::BAD_GATEWAY, "Error reading upstream response")
94 .into_response();
95 }
96 };
97
98 let mut response_builder = Response::builder().status(status);
99
100 for (key, value) in headers.iter() {
101 response_builder = response_builder.header(key, value);
102 }
103
104 match response_builder.body(axum::body::Body::from(body)) {
105 Ok(r) => r,
106 Err(e) => {
107 error!("Error building proxy response: {:?}", e);
108 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response()
109 }
110 }
111 }
112 Err(e) => {
113 error!("Error sending proxy request: {:?}", e);
114 if e.is_timeout() {
115 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response()
116 } else {
117 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response()
118 }
119 }
120 }
121}