APIs for links and references in the ATmosphere

consolidate oauth handling

+258 -137
+1
Cargo.lock
··· 5024 5024 "rand 0.9.1", 5025 5025 "serde", 5026 5026 "serde_json", 5027 + "thiserror 2.0.12", 5027 5028 "tokio", 5028 5029 "tokio-util", 5029 5030 "url",
+1
who-am-i/Cargo.toml
··· 20 20 rand = "0.9.1" 21 21 serde = { version = "1.0.219", features = ["derive"] } 22 22 serde_json = "1.0.140" 23 + thiserror = "2.0.12" 23 24 tokio = { version = "1.45.1", features = ["full", "macros"] } 24 25 tokio-util = "0.7.15" 25 26 url = "2.5.4"
-34
who-am-i/src/dns_resolver.rs
··· 1 - // originally from weaver: https://github.com/rsform/weaver/blob/ee08213a85e09889b9bd66beceecee92ac025801/crates/weaver-common/src/resolver.rs 2 - // MPL 2.0: https://github.com/rsform/weaver/blob/ee08213a85e09889b9bd66beceecee92ac025801/LICENSE 3 - 4 - use atrium_identity::handle::DnsTxtResolver; 5 - use hickory_resolver::TokioResolver; 6 - 7 - pub struct HickoryDnsTxtResolver { 8 - resolver: TokioResolver, 9 - } 10 - 11 - impl Default for HickoryDnsTxtResolver { 12 - fn default() -> Self { 13 - Self { 14 - resolver: TokioResolver::builder_tokio() 15 - .expect("failed to create resolver") 16 - .build(), 17 - } 18 - } 19 - } 20 - 21 - impl DnsTxtResolver for HickoryDnsTxtResolver { 22 - async fn resolve( 23 - &self, 24 - query: &str, 25 - ) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> { 26 - Ok(self 27 - .resolver 28 - .txt_lookup(query) 29 - .await? 30 - .iter() 31 - .map(|txt| txt.to_string()) 32 - .collect()) 33 - } 34 - }
+17 -2
who-am-i/src/expiring_task_map.rs
··· 6 6 use tokio::time::sleep; 7 7 use tokio_util::sync::{CancellationToken, DropGuard}; 8 8 9 - #[derive(Clone)] 10 9 pub struct ExpiringTaskMap<T>(TaskMap<T>); 10 + 11 + /// need to manually implement clone because T is allowed to not be clone 12 + impl<T> Clone for ExpiringTaskMap<T> { 13 + fn clone(&self) -> Self { 14 + Self(self.0.clone()) 15 + } 16 + } 11 17 12 18 impl<T: Send + 'static> ExpiringTaskMap<T> { 13 19 pub fn new(expiration: Duration) -> Self { ··· 58 64 } 59 65 } 60 66 61 - #[derive(Clone)] 62 67 struct TaskMap<T> { 63 68 map: Arc<DashMap<String, (DropGuard, JoinHandle<T>)>>, 64 69 expiration: Duration, 65 70 } 71 + 72 + /// need to manually implement clone because T is allowed to not be clone 73 + impl<T> Clone for TaskMap<T> { 74 + fn clone(&self) -> Self { 75 + Self { 76 + map: self.map.clone(), 77 + expiration: self.expiration, 78 + } 79 + } 80 + }
-22
who-am-i/src/identity_resolver.rs
··· 1 - use atrium_api::types::string::Did; 2 - use atrium_common::resolver::Resolver; 3 - use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; 4 - use atrium_oauth::DefaultHttpClient; 5 - use std::sync::Arc; 6 - 7 - pub async fn resolve_identity(did: String) -> Option<String> { 8 - let http_client = Arc::new(DefaultHttpClient::default()); 9 - let resolver = CommonDidResolver::new(CommonDidResolverConfig { 10 - plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 11 - http_client: Arc::clone(&http_client), 12 - }); 13 - let doc = resolver.resolve(&Did::new(did).unwrap()).await.unwrap(); // TODO: this is only half the resolution? or is atrium checking dns? 14 - // tokio::time::sleep(std::time::Duration::from_secs(2)).await; 15 - doc.also_known_as.and_then(|mut aka| { 16 - if aka.is_empty() { 17 - None 18 - } else { 19 - Some(aka.remove(0)) 20 - } 21 - }) 22 - }
+1 -5
who-am-i/src/lib.rs
··· 1 - mod dns_resolver; 2 1 mod expiring_task_map; 3 - mod identity_resolver; 4 2 mod oauth; 5 3 mod server; 6 4 7 - pub use dns_resolver::HickoryDnsTxtResolver; 8 5 pub use expiring_task_map::ExpiringTaskMap; 9 - pub use identity_resolver::resolve_identity; 10 - pub use oauth::{Client, authorize, client}; 6 + pub use oauth::{OAuth, OauthCallbackParams, ResolveHandleError}; 11 7 pub use server::serve;
+197 -47
who-am-i/src/oauth.rs
··· 1 - use crate::HickoryDnsTxtResolver; 1 + use atrium_api::{agent::SessionManager, types::string::Did}; 2 + use atrium_common::resolver::Resolver; 2 3 use atrium_identity::{ 3 4 did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}, 4 - handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}, 5 + handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver}, 5 6 }; 6 7 use atrium_oauth::{ 7 - AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient, 8 - OAuthClientConfig, OAuthResolverConfig, Scope, 8 + AtprotoLocalhostClientMetadata, AuthorizeOptions, CallbackParams, DefaultHttpClient, 9 + KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope, 9 10 store::{session::MemorySessionStore, state::MemoryStateStore}, 10 11 }; 12 + use hickory_resolver::TokioResolver; 13 + use serde::Deserialize; 11 14 use std::sync::Arc; 15 + use thiserror::Error; 12 16 13 - pub type Client = OAuthClient< 17 + const READONLY_SCOPE: [Scope; 1] = [Scope::Known(KnownScope::Atproto)]; 18 + 19 + #[derive(Debug, Deserialize)] 20 + pub struct CallbackErrorParams { 21 + error: String, 22 + error_description: Option<String>, 23 + #[allow(dead_code)] 24 + state: Option<String>, // TODO: we _should_ use state to associate the auth request but how to do that with atrium is unclear 25 + iss: Option<String>, 26 + } 27 + 28 + #[derive(Debug, Deserialize)] 29 + #[serde(untagged)] 30 + pub enum OauthCallbackParams { 31 + Granted(CallbackParams), 32 + Failed(CallbackErrorParams), 33 + } 34 + 35 + type Client = OAuthClient< 14 36 MemoryStateStore, 15 37 MemorySessionStore, 16 38 CommonDidResolver<DefaultHttpClient>, 17 39 AtprotoHandleResolver<HickoryDnsTxtResolver, DefaultHttpClient>, 18 40 >; 19 41 20 - pub fn client() -> Client { 21 - let http_client = Arc::new(DefaultHttpClient::default()); 22 - let config = OAuthClientConfig { 23 - client_metadata: AtprotoLocalhostClientMetadata { 24 - redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]), 25 - scopes: Some(vec![Scope::Known(KnownScope::Atproto)]), 26 - }, 27 - keys: None, 28 - resolver: OAuthResolverConfig { 29 - did_resolver: CommonDidResolver::new(CommonDidResolverConfig { 30 - plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 31 - http_client: Arc::clone(&http_client), 32 - }), 33 - handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 34 - dns_txt_resolver: HickoryDnsTxtResolver::default(), 35 - http_client: Arc::clone(&http_client), 36 - }), 37 - authorization_server_metadata: Default::default(), 38 - protected_resource_metadata: Default::default(), 39 - }, 40 - // A store for saving state data while the user is being redirected to the authorization server. 41 - state_store: MemoryStateStore::default(), 42 - // A store for saving session data. 43 - session_store: MemorySessionStore::default(), 44 - }; 45 - let Ok(client) = OAuthClient::new(config) else { 46 - panic!("failed to create oauth client"); 47 - }; 48 - client 42 + #[derive(Clone)] 43 + pub struct OAuth { 44 + client: Arc<Client>, 45 + did_resolver: Arc<CommonDidResolver<DefaultHttpClient>>, 49 46 } 50 47 51 - pub async fn authorize(client: &Client, handle: &str) -> String { 52 - let Ok(url) = client 53 - .authorize( 54 - handle, 55 - AuthorizeOptions { 56 - scopes: vec![Scope::Known(KnownScope::Atproto)], 57 - ..Default::default() 48 + #[derive(Debug, Error)] 49 + #[error(transparent)] 50 + pub struct AuthSetupError(#[from] atrium_oauth::Error); 51 + 52 + #[derive(Debug, Error)] 53 + #[error(transparent)] 54 + pub struct AuthStartError(#[from] atrium_oauth::Error); 55 + 56 + #[derive(Debug, Error)] 57 + pub enum AuthCompleteError { 58 + #[error("the user denied request: {description:?} (from {issuer:?})")] 59 + Denied { 60 + description: Option<String>, 61 + issuer: Option<String>, 62 + }, 63 + #[error( 64 + "the request was denied for another reason: {error}: {description:?} (from {issuer:?})" 65 + )] 66 + Failed { 67 + error: String, 68 + description: Option<String>, 69 + issuer: Option<String>, 70 + }, 71 + #[error("failed to complete oauth callback: {0}")] 72 + CallbackFailed(atrium_oauth::Error), 73 + #[error("the authorized session did not contain a DID")] 74 + NoDid, 75 + } 76 + 77 + #[derive(Debug, Error)] 78 + pub enum ResolveHandleError { 79 + #[error("failed to resolve: {0}")] 80 + ResolutionFailed(#[from] atrium_identity::Error), 81 + #[error("identity resolved but no handle found for user")] 82 + NoHandle, 83 + #[error("found handle {0:?} but it appears invalid: {1}")] 84 + InvalidHandle(String, &'static str), 85 + } 86 + 87 + impl OAuth { 88 + pub fn new() -> Result<Self, AuthSetupError> { 89 + let http_client = Arc::new(DefaultHttpClient::default()); 90 + let did_resolver = || { 91 + CommonDidResolver::new(CommonDidResolverConfig { 92 + plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 93 + http_client: http_client.clone(), 94 + }) 95 + }; 96 + let client_config = OAuthClientConfig { 97 + client_metadata: AtprotoLocalhostClientMetadata { 98 + redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]), 99 + scopes: Some(READONLY_SCOPE.to_vec()), 58 100 }, 59 - ) 60 - .await 61 - else { 62 - panic!("failed to authorize"); 63 - }; 64 - url 101 + keys: None, 102 + resolver: OAuthResolverConfig { 103 + did_resolver: did_resolver(), 104 + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 105 + dns_txt_resolver: HickoryDnsTxtResolver::default(), 106 + http_client: Arc::clone(&http_client), 107 + }), 108 + authorization_server_metadata: Default::default(), 109 + protected_resource_metadata: Default::default(), 110 + }, 111 + state_store: MemoryStateStore::default(), 112 + session_store: MemorySessionStore::default(), 113 + }; 114 + 115 + let client = OAuthClient::new(client_config)?; 116 + 117 + Ok(Self { 118 + client: Arc::new(client), 119 + did_resolver: Arc::new(did_resolver()), 120 + }) 121 + } 122 + 123 + pub async fn begin(&self, handle: &str) -> Result<String, AuthStartError> { 124 + let auth_opts = AuthorizeOptions { 125 + scopes: READONLY_SCOPE.to_vec(), 126 + ..Default::default() 127 + }; 128 + Ok(self.client.authorize(handle, auth_opts).await?) 129 + } 130 + 131 + /// Finally, resolve the oauth flow to a verified DID 132 + pub async fn complete(&self, params: OauthCallbackParams) -> Result<Did, AuthCompleteError> { 133 + let params = match params { 134 + OauthCallbackParams::Granted(params) => params, 135 + OauthCallbackParams::Failed(p) if p.error == "access_denied" => { 136 + return Err(AuthCompleteError::Denied { 137 + description: p.error_description.clone(), 138 + issuer: p.iss.clone(), 139 + }); 140 + } 141 + OauthCallbackParams::Failed(p) => { 142 + return Err(AuthCompleteError::Failed { 143 + error: p.error.clone(), 144 + description: p.error_description.clone(), 145 + issuer: p.iss.clone(), 146 + }); 147 + } 148 + }; 149 + let (session, _) = self 150 + .client 151 + .callback(params) 152 + .await 153 + .map_err(AuthCompleteError::CallbackFailed)?; 154 + let Some(did) = session.did().await else { 155 + return Err(AuthCompleteError::NoDid); 156 + }; 157 + Ok(did) 158 + } 159 + 160 + pub async fn resolve_handle(&self, did: Did) -> Result<String, ResolveHandleError> { 161 + // TODO: this is only half the resolution? or is atrium checking dns? 162 + let doc = self.did_resolver.resolve(&did).await?; 163 + let Some(aka) = doc.also_known_as else { 164 + return Err(ResolveHandleError::NoHandle); 165 + }; 166 + let Some(at_uri_handle) = aka.first() else { 167 + return Err(ResolveHandleError::NoHandle); 168 + }; 169 + if aka.len() > 1 { 170 + eprintln!("more than one handle found for {did:?}"); 171 + } 172 + let Some(bare_handle) = at_uri_handle.strip_prefix("at://") else { 173 + return Err(ResolveHandleError::InvalidHandle( 174 + at_uri_handle.to_string(), 175 + "did not start with 'at://'", 176 + )); 177 + }; 178 + if bare_handle.is_empty() { 179 + return Err(ResolveHandleError::InvalidHandle( 180 + bare_handle.to_string(), 181 + "empty handle", 182 + )); 183 + } 184 + Ok(bare_handle.to_string()) 185 + } 186 + } 187 + 188 + pub struct HickoryDnsTxtResolver { 189 + resolver: TokioResolver, 190 + } 191 + 192 + impl Default for HickoryDnsTxtResolver { 193 + fn default() -> Self { 194 + Self { 195 + resolver: TokioResolver::builder_tokio() 196 + .expect("failed to create resolver") 197 + .build(), 198 + } 199 + } 200 + } 201 + 202 + impl DnsTxtResolver for HickoryDnsTxtResolver { 203 + async fn resolve( 204 + &self, 205 + query: &str, 206 + ) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> { 207 + Ok(self 208 + .resolver 209 + .txt_lookup(query) 210 + .await? 211 + .iter() 212 + .map(|txt| txt.to_string()) 213 + .collect()) 214 + } 65 215 }
+41 -27
who-am-i/src/server.rs
··· 1 - use atrium_api::agent::SessionManager; 2 - use atrium_oauth::CallbackParams; 1 + use atrium_api::types::string::Did; 3 2 use axum::{ 4 3 Router, 5 4 extract::{FromRef, Query, State}, ··· 20 19 use tokio_util::sync::CancellationToken; 21 20 use url::Url; 22 21 23 - use crate::{Client, ExpiringTaskMap, authorize, client, resolve_identity}; 22 + use crate::{ExpiringTaskMap, OAuth, OauthCallbackParams, ResolveHandleError}; 24 23 25 24 const FAVICON: &[u8] = include_bytes!("../static/favicon.ico"); 26 25 const INDEX_HTML: &str = include_str!("../static/index.html"); ··· 34 33 pub key: Key, 35 34 pub one_clicks: Arc<HashSet<String>>, 36 35 pub engine: AppEngine, 37 - pub client: Arc<Client>, 38 - pub resolving: ExpiringTaskMap<Option<String>>, 36 + pub oauth: Arc<OAuth>, 37 + pub resolving: ExpiringTaskMap<Result<String, ResolveHandleError>>, 39 38 pub shutdown: CancellationToken, 40 39 } 41 40 ··· 62 61 // clients have to pick up their identity-resolving tasks within this period 63 62 let task_pickup_expiration = Duration::from_secs(15); 64 63 64 + let oauth = OAuth::new().unwrap(); 65 + 65 66 let state = AppState { 66 67 engine: Engine::new(hbs), 67 68 key: Key::from(app_secret.as_bytes()), // TODO: via config 68 69 one_clicks: Arc::new(HashSet::from_iter(one_click)), 69 - client: Arc::new(client()), 70 + oauth: Arc::new(oauth), 70 71 resolving: ExpiringTaskMap::new(task_pickup_expiration), 71 72 shutdown: shutdown.clone(), 72 73 }; ··· 92 93 93 94 async fn prompt( 94 95 State(AppState { 96 + one_clicks, 95 97 engine, 96 - one_clicks, 98 + oauth, 97 99 resolving, 98 100 shutdown, 99 101 .. ··· 118 120 .into_response(); 119 121 } 120 122 if let Some(did) = jar.get(DID_COOKIE_KEY) { 121 - let did = did.value_trimmed().to_string(); 123 + let Ok(did) = Did::new(did.value_trimmed().to_string()) else { 124 + return "did from cookie failed to parse".into_response(); 125 + }; 122 126 123 - let task_shutdown = shutdown.child_token(); 124 - let fetch_key = resolving.dispatch(resolve_identity(did.clone()), task_shutdown); 127 + let fetch_key = resolving.dispatch( 128 + { 129 + let oauth = oauth.clone(); 130 + let did = did.clone(); 131 + async move { oauth.resolve_handle(did.clone()).await } 132 + }, 133 + shutdown.child_token(), 134 + ); 125 135 126 136 RenderHtml( 127 137 "prompt-known", ··· 157 167 let Some(task_handle) = resolving.take(&params.fetch_key) else { 158 168 return "oops, task does not exist or is gone".into_response(); 159 169 }; 160 - if let Some(handle) = task_handle.await.unwrap() { 161 - // TODO: get active state etc. 162 - // ...but also, that's a bsky thing? 163 - let Some(handle) = handle.strip_prefix("at://") else { 164 - return "hmm, handle did not start with at://".into_response(); 165 - }; 170 + if let Ok(handle) = task_handle.await.unwrap() { 166 171 Json(json!({ "handle": handle })).into_response() 167 172 } else { 168 173 "no handle?".into_response() ··· 174 179 handle: String, 175 180 } 176 181 async fn start_oauth( 177 - State(state): State<AppState>, 182 + State(AppState { oauth, .. }): State<AppState>, 178 183 Query(params): Query<BeginOauthParams>, 179 184 jar: SignedCookieJar, 180 185 ) -> (SignedCookieJar, Redirect) { 181 186 // if any existing session was active, clear it first 182 187 let jar = jar.remove(DID_COOKIE_KEY); 183 188 184 - let auth_url = authorize(&state.client, &params.handle).await; 189 + let auth_url = oauth.begin(&params.handle).await.unwrap(); 185 190 (jar, Redirect::to(&auth_url)) 186 191 } 187 192 188 193 async fn complete_oauth( 189 - State(state): State<AppState>, 190 - Query(params): Query<CallbackParams>, 194 + State(AppState { 195 + engine, 196 + resolving, 197 + oauth, 198 + shutdown, 199 + .. 200 + }): State<AppState>, 201 + Query(params): Query<OauthCallbackParams>, 191 202 jar: SignedCookieJar, 192 203 ) -> (SignedCookieJar, impl IntoResponse) { 193 - let Ok((oauth_session, _)) = state.client.callback(params).await else { 204 + let Ok(did) = oauth.complete(params).await else { 194 205 panic!("failed to do client callback"); 195 206 }; 196 - let did = oauth_session.did().await.expect("a did to be present"); 197 207 198 208 let cookie = Cookie::build((DID_COOKIE_KEY, did.to_string())) 199 209 .http_only(true) ··· 203 213 204 214 let jar = jar.add(cookie); 205 215 206 - let task_shutdown = state.shutdown.child_token(); 207 - let fetch_key = state 208 - .resolving 209 - .dispatch(resolve_identity(did.to_string()), task_shutdown); 216 + let fetch_key = resolving.dispatch( 217 + { 218 + let oauth = oauth.clone(); 219 + let did = did.clone(); 220 + async move { oauth.resolve_handle(did.clone()).await } 221 + }, 222 + shutdown.child_token(), 223 + ); 210 224 211 225 ( 212 226 jar, 213 227 RenderHtml( 214 228 "authorized", 215 - state.engine, 229 + engine, 216 230 json!({ 217 231 "did": did, 218 232 "fetch_key": fetch_key,