tangled
alpha
login
or
join now
bad-example.com
/
microcosm-links
7
fork
atom
APIs for links and references in the ATmosphere
7
fork
atom
overview
issues
pulls
pipelines
consolidate oauth handling
bad-example.com
8 months ago
95c60ce3
5340be12
+258
-137
8 changed files
expand all
collapse all
unified
split
Cargo.lock
who-am-i
Cargo.toml
src
dns_resolver.rs
expiring_task_map.rs
identity_resolver.rs
lib.rs
oauth.rs
server.rs
+1
Cargo.lock
···
5024
5024
"rand 0.9.1",
5025
5025
"serde",
5026
5026
"serde_json",
5027
5027
+
"thiserror 2.0.12",
5027
5028
"tokio",
5028
5029
"tokio-util",
5029
5030
"url",
+1
who-am-i/Cargo.toml
···
20
20
rand = "0.9.1"
21
21
serde = { version = "1.0.219", features = ["derive"] }
22
22
serde_json = "1.0.140"
23
23
+
thiserror = "2.0.12"
23
24
tokio = { version = "1.45.1", features = ["full", "macros"] }
24
25
tokio-util = "0.7.15"
25
26
url = "2.5.4"
-34
who-am-i/src/dns_resolver.rs
···
1
1
-
// originally from weaver: https://github.com/rsform/weaver/blob/ee08213a85e09889b9bd66beceecee92ac025801/crates/weaver-common/src/resolver.rs
2
2
-
// MPL 2.0: https://github.com/rsform/weaver/blob/ee08213a85e09889b9bd66beceecee92ac025801/LICENSE
3
3
-
4
4
-
use atrium_identity::handle::DnsTxtResolver;
5
5
-
use hickory_resolver::TokioResolver;
6
6
-
7
7
-
pub struct HickoryDnsTxtResolver {
8
8
-
resolver: TokioResolver,
9
9
-
}
10
10
-
11
11
-
impl Default for HickoryDnsTxtResolver {
12
12
-
fn default() -> Self {
13
13
-
Self {
14
14
-
resolver: TokioResolver::builder_tokio()
15
15
-
.expect("failed to create resolver")
16
16
-
.build(),
17
17
-
}
18
18
-
}
19
19
-
}
20
20
-
21
21
-
impl DnsTxtResolver for HickoryDnsTxtResolver {
22
22
-
async fn resolve(
23
23
-
&self,
24
24
-
query: &str,
25
25
-
) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
26
26
-
Ok(self
27
27
-
.resolver
28
28
-
.txt_lookup(query)
29
29
-
.await?
30
30
-
.iter()
31
31
-
.map(|txt| txt.to_string())
32
32
-
.collect())
33
33
-
}
34
34
-
}
+17
-2
who-am-i/src/expiring_task_map.rs
···
6
6
use tokio::time::sleep;
7
7
use tokio_util::sync::{CancellationToken, DropGuard};
8
8
9
9
-
#[derive(Clone)]
10
9
pub struct ExpiringTaskMap<T>(TaskMap<T>);
10
10
+
11
11
+
/// need to manually implement clone because T is allowed to not be clone
12
12
+
impl<T> Clone for ExpiringTaskMap<T> {
13
13
+
fn clone(&self) -> Self {
14
14
+
Self(self.0.clone())
15
15
+
}
16
16
+
}
11
17
12
18
impl<T: Send + 'static> ExpiringTaskMap<T> {
13
19
pub fn new(expiration: Duration) -> Self {
···
58
64
}
59
65
}
60
66
61
61
-
#[derive(Clone)]
62
67
struct TaskMap<T> {
63
68
map: Arc<DashMap<String, (DropGuard, JoinHandle<T>)>>,
64
69
expiration: Duration,
65
70
}
71
71
+
72
72
+
/// need to manually implement clone because T is allowed to not be clone
73
73
+
impl<T> Clone for TaskMap<T> {
74
74
+
fn clone(&self) -> Self {
75
75
+
Self {
76
76
+
map: self.map.clone(),
77
77
+
expiration: self.expiration,
78
78
+
}
79
79
+
}
80
80
+
}
-22
who-am-i/src/identity_resolver.rs
···
1
1
-
use atrium_api::types::string::Did;
2
2
-
use atrium_common::resolver::Resolver;
3
3
-
use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
4
4
-
use atrium_oauth::DefaultHttpClient;
5
5
-
use std::sync::Arc;
6
6
-
7
7
-
pub async fn resolve_identity(did: String) -> Option<String> {
8
8
-
let http_client = Arc::new(DefaultHttpClient::default());
9
9
-
let resolver = CommonDidResolver::new(CommonDidResolverConfig {
10
10
-
plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
11
11
-
http_client: Arc::clone(&http_client),
12
12
-
});
13
13
-
let doc = resolver.resolve(&Did::new(did).unwrap()).await.unwrap(); // TODO: this is only half the resolution? or is atrium checking dns?
14
14
-
// tokio::time::sleep(std::time::Duration::from_secs(2)).await;
15
15
-
doc.also_known_as.and_then(|mut aka| {
16
16
-
if aka.is_empty() {
17
17
-
None
18
18
-
} else {
19
19
-
Some(aka.remove(0))
20
20
-
}
21
21
-
})
22
22
-
}
+1
-5
who-am-i/src/lib.rs
···
1
1
-
mod dns_resolver;
2
1
mod expiring_task_map;
3
3
-
mod identity_resolver;
4
2
mod oauth;
5
3
mod server;
6
4
7
7
-
pub use dns_resolver::HickoryDnsTxtResolver;
8
5
pub use expiring_task_map::ExpiringTaskMap;
9
9
-
pub use identity_resolver::resolve_identity;
10
10
-
pub use oauth::{Client, authorize, client};
6
6
+
pub use oauth::{OAuth, OauthCallbackParams, ResolveHandleError};
11
7
pub use server::serve;
+197
-47
who-am-i/src/oauth.rs
···
1
1
-
use crate::HickoryDnsTxtResolver;
1
1
+
use atrium_api::{agent::SessionManager, types::string::Did};
2
2
+
use atrium_common::resolver::Resolver;
2
3
use atrium_identity::{
3
4
did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL},
4
4
-
handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig},
5
5
+
handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver},
5
6
};
6
7
use atrium_oauth::{
7
7
-
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient,
8
8
-
OAuthClientConfig, OAuthResolverConfig, Scope,
8
8
+
AtprotoLocalhostClientMetadata, AuthorizeOptions, CallbackParams, DefaultHttpClient,
9
9
+
KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope,
9
10
store::{session::MemorySessionStore, state::MemoryStateStore},
10
11
};
12
12
+
use hickory_resolver::TokioResolver;
13
13
+
use serde::Deserialize;
11
14
use std::sync::Arc;
15
15
+
use thiserror::Error;
12
16
13
13
-
pub type Client = OAuthClient<
17
17
+
const READONLY_SCOPE: [Scope; 1] = [Scope::Known(KnownScope::Atproto)];
18
18
+
19
19
+
#[derive(Debug, Deserialize)]
20
20
+
pub struct CallbackErrorParams {
21
21
+
error: String,
22
22
+
error_description: Option<String>,
23
23
+
#[allow(dead_code)]
24
24
+
state: Option<String>, // TODO: we _should_ use state to associate the auth request but how to do that with atrium is unclear
25
25
+
iss: Option<String>,
26
26
+
}
27
27
+
28
28
+
#[derive(Debug, Deserialize)]
29
29
+
#[serde(untagged)]
30
30
+
pub enum OauthCallbackParams {
31
31
+
Granted(CallbackParams),
32
32
+
Failed(CallbackErrorParams),
33
33
+
}
34
34
+
35
35
+
type Client = OAuthClient<
14
36
MemoryStateStore,
15
37
MemorySessionStore,
16
38
CommonDidResolver<DefaultHttpClient>,
17
39
AtprotoHandleResolver<HickoryDnsTxtResolver, DefaultHttpClient>,
18
40
>;
19
41
20
20
-
pub fn client() -> Client {
21
21
-
let http_client = Arc::new(DefaultHttpClient::default());
22
22
-
let config = OAuthClientConfig {
23
23
-
client_metadata: AtprotoLocalhostClientMetadata {
24
24
-
redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]),
25
25
-
scopes: Some(vec![Scope::Known(KnownScope::Atproto)]),
26
26
-
},
27
27
-
keys: None,
28
28
-
resolver: OAuthResolverConfig {
29
29
-
did_resolver: CommonDidResolver::new(CommonDidResolverConfig {
30
30
-
plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
31
31
-
http_client: Arc::clone(&http_client),
32
32
-
}),
33
33
-
handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig {
34
34
-
dns_txt_resolver: HickoryDnsTxtResolver::default(),
35
35
-
http_client: Arc::clone(&http_client),
36
36
-
}),
37
37
-
authorization_server_metadata: Default::default(),
38
38
-
protected_resource_metadata: Default::default(),
39
39
-
},
40
40
-
// A store for saving state data while the user is being redirected to the authorization server.
41
41
-
state_store: MemoryStateStore::default(),
42
42
-
// A store for saving session data.
43
43
-
session_store: MemorySessionStore::default(),
44
44
-
};
45
45
-
let Ok(client) = OAuthClient::new(config) else {
46
46
-
panic!("failed to create oauth client");
47
47
-
};
48
48
-
client
42
42
+
#[derive(Clone)]
43
43
+
pub struct OAuth {
44
44
+
client: Arc<Client>,
45
45
+
did_resolver: Arc<CommonDidResolver<DefaultHttpClient>>,
49
46
}
50
47
51
51
-
pub async fn authorize(client: &Client, handle: &str) -> String {
52
52
-
let Ok(url) = client
53
53
-
.authorize(
54
54
-
handle,
55
55
-
AuthorizeOptions {
56
56
-
scopes: vec![Scope::Known(KnownScope::Atproto)],
57
57
-
..Default::default()
48
48
+
#[derive(Debug, Error)]
49
49
+
#[error(transparent)]
50
50
+
pub struct AuthSetupError(#[from] atrium_oauth::Error);
51
51
+
52
52
+
#[derive(Debug, Error)]
53
53
+
#[error(transparent)]
54
54
+
pub struct AuthStartError(#[from] atrium_oauth::Error);
55
55
+
56
56
+
#[derive(Debug, Error)]
57
57
+
pub enum AuthCompleteError {
58
58
+
#[error("the user denied request: {description:?} (from {issuer:?})")]
59
59
+
Denied {
60
60
+
description: Option<String>,
61
61
+
issuer: Option<String>,
62
62
+
},
63
63
+
#[error(
64
64
+
"the request was denied for another reason: {error}: {description:?} (from {issuer:?})"
65
65
+
)]
66
66
+
Failed {
67
67
+
error: String,
68
68
+
description: Option<String>,
69
69
+
issuer: Option<String>,
70
70
+
},
71
71
+
#[error("failed to complete oauth callback: {0}")]
72
72
+
CallbackFailed(atrium_oauth::Error),
73
73
+
#[error("the authorized session did not contain a DID")]
74
74
+
NoDid,
75
75
+
}
76
76
+
77
77
+
#[derive(Debug, Error)]
78
78
+
pub enum ResolveHandleError {
79
79
+
#[error("failed to resolve: {0}")]
80
80
+
ResolutionFailed(#[from] atrium_identity::Error),
81
81
+
#[error("identity resolved but no handle found for user")]
82
82
+
NoHandle,
83
83
+
#[error("found handle {0:?} but it appears invalid: {1}")]
84
84
+
InvalidHandle(String, &'static str),
85
85
+
}
86
86
+
87
87
+
impl OAuth {
88
88
+
pub fn new() -> Result<Self, AuthSetupError> {
89
89
+
let http_client = Arc::new(DefaultHttpClient::default());
90
90
+
let did_resolver = || {
91
91
+
CommonDidResolver::new(CommonDidResolverConfig {
92
92
+
plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(),
93
93
+
http_client: http_client.clone(),
94
94
+
})
95
95
+
};
96
96
+
let client_config = OAuthClientConfig {
97
97
+
client_metadata: AtprotoLocalhostClientMetadata {
98
98
+
redirect_uris: Some(vec![String::from("http://127.0.0.1:9997/authorized")]),
99
99
+
scopes: Some(READONLY_SCOPE.to_vec()),
58
100
},
59
59
-
)
60
60
-
.await
61
61
-
else {
62
62
-
panic!("failed to authorize");
63
63
-
};
64
64
-
url
101
101
+
keys: None,
102
102
+
resolver: OAuthResolverConfig {
103
103
+
did_resolver: did_resolver(),
104
104
+
handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig {
105
105
+
dns_txt_resolver: HickoryDnsTxtResolver::default(),
106
106
+
http_client: Arc::clone(&http_client),
107
107
+
}),
108
108
+
authorization_server_metadata: Default::default(),
109
109
+
protected_resource_metadata: Default::default(),
110
110
+
},
111
111
+
state_store: MemoryStateStore::default(),
112
112
+
session_store: MemorySessionStore::default(),
113
113
+
};
114
114
+
115
115
+
let client = OAuthClient::new(client_config)?;
116
116
+
117
117
+
Ok(Self {
118
118
+
client: Arc::new(client),
119
119
+
did_resolver: Arc::new(did_resolver()),
120
120
+
})
121
121
+
}
122
122
+
123
123
+
pub async fn begin(&self, handle: &str) -> Result<String, AuthStartError> {
124
124
+
let auth_opts = AuthorizeOptions {
125
125
+
scopes: READONLY_SCOPE.to_vec(),
126
126
+
..Default::default()
127
127
+
};
128
128
+
Ok(self.client.authorize(handle, auth_opts).await?)
129
129
+
}
130
130
+
131
131
+
/// Finally, resolve the oauth flow to a verified DID
132
132
+
pub async fn complete(&self, params: OauthCallbackParams) -> Result<Did, AuthCompleteError> {
133
133
+
let params = match params {
134
134
+
OauthCallbackParams::Granted(params) => params,
135
135
+
OauthCallbackParams::Failed(p) if p.error == "access_denied" => {
136
136
+
return Err(AuthCompleteError::Denied {
137
137
+
description: p.error_description.clone(),
138
138
+
issuer: p.iss.clone(),
139
139
+
});
140
140
+
}
141
141
+
OauthCallbackParams::Failed(p) => {
142
142
+
return Err(AuthCompleteError::Failed {
143
143
+
error: p.error.clone(),
144
144
+
description: p.error_description.clone(),
145
145
+
issuer: p.iss.clone(),
146
146
+
});
147
147
+
}
148
148
+
};
149
149
+
let (session, _) = self
150
150
+
.client
151
151
+
.callback(params)
152
152
+
.await
153
153
+
.map_err(AuthCompleteError::CallbackFailed)?;
154
154
+
let Some(did) = session.did().await else {
155
155
+
return Err(AuthCompleteError::NoDid);
156
156
+
};
157
157
+
Ok(did)
158
158
+
}
159
159
+
160
160
+
pub async fn resolve_handle(&self, did: Did) -> Result<String, ResolveHandleError> {
161
161
+
// TODO: this is only half the resolution? or is atrium checking dns?
162
162
+
let doc = self.did_resolver.resolve(&did).await?;
163
163
+
let Some(aka) = doc.also_known_as else {
164
164
+
return Err(ResolveHandleError::NoHandle);
165
165
+
};
166
166
+
let Some(at_uri_handle) = aka.first() else {
167
167
+
return Err(ResolveHandleError::NoHandle);
168
168
+
};
169
169
+
if aka.len() > 1 {
170
170
+
eprintln!("more than one handle found for {did:?}");
171
171
+
}
172
172
+
let Some(bare_handle) = at_uri_handle.strip_prefix("at://") else {
173
173
+
return Err(ResolveHandleError::InvalidHandle(
174
174
+
at_uri_handle.to_string(),
175
175
+
"did not start with 'at://'",
176
176
+
));
177
177
+
};
178
178
+
if bare_handle.is_empty() {
179
179
+
return Err(ResolveHandleError::InvalidHandle(
180
180
+
bare_handle.to_string(),
181
181
+
"empty handle",
182
182
+
));
183
183
+
}
184
184
+
Ok(bare_handle.to_string())
185
185
+
}
186
186
+
}
187
187
+
188
188
+
pub struct HickoryDnsTxtResolver {
189
189
+
resolver: TokioResolver,
190
190
+
}
191
191
+
192
192
+
impl Default for HickoryDnsTxtResolver {
193
193
+
fn default() -> Self {
194
194
+
Self {
195
195
+
resolver: TokioResolver::builder_tokio()
196
196
+
.expect("failed to create resolver")
197
197
+
.build(),
198
198
+
}
199
199
+
}
200
200
+
}
201
201
+
202
202
+
impl DnsTxtResolver for HickoryDnsTxtResolver {
203
203
+
async fn resolve(
204
204
+
&self,
205
205
+
query: &str,
206
206
+
) -> core::result::Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
207
207
+
Ok(self
208
208
+
.resolver
209
209
+
.txt_lookup(query)
210
210
+
.await?
211
211
+
.iter()
212
212
+
.map(|txt| txt.to_string())
213
213
+
.collect())
214
214
+
}
65
215
}
+41
-27
who-am-i/src/server.rs
···
1
1
-
use atrium_api::agent::SessionManager;
2
2
-
use atrium_oauth::CallbackParams;
1
1
+
use atrium_api::types::string::Did;
3
2
use axum::{
4
3
Router,
5
4
extract::{FromRef, Query, State},
···
20
19
use tokio_util::sync::CancellationToken;
21
20
use url::Url;
22
21
23
23
-
use crate::{Client, ExpiringTaskMap, authorize, client, resolve_identity};
22
22
+
use crate::{ExpiringTaskMap, OAuth, OauthCallbackParams, ResolveHandleError};
24
23
25
24
const FAVICON: &[u8] = include_bytes!("../static/favicon.ico");
26
25
const INDEX_HTML: &str = include_str!("../static/index.html");
···
34
33
pub key: Key,
35
34
pub one_clicks: Arc<HashSet<String>>,
36
35
pub engine: AppEngine,
37
37
-
pub client: Arc<Client>,
38
38
-
pub resolving: ExpiringTaskMap<Option<String>>,
36
36
+
pub oauth: Arc<OAuth>,
37
37
+
pub resolving: ExpiringTaskMap<Result<String, ResolveHandleError>>,
39
38
pub shutdown: CancellationToken,
40
39
}
41
40
···
62
61
// clients have to pick up their identity-resolving tasks within this period
63
62
let task_pickup_expiration = Duration::from_secs(15);
64
63
64
64
+
let oauth = OAuth::new().unwrap();
65
65
+
65
66
let state = AppState {
66
67
engine: Engine::new(hbs),
67
68
key: Key::from(app_secret.as_bytes()), // TODO: via config
68
69
one_clicks: Arc::new(HashSet::from_iter(one_click)),
69
69
-
client: Arc::new(client()),
70
70
+
oauth: Arc::new(oauth),
70
71
resolving: ExpiringTaskMap::new(task_pickup_expiration),
71
72
shutdown: shutdown.clone(),
72
73
};
···
92
93
93
94
async fn prompt(
94
95
State(AppState {
96
96
+
one_clicks,
95
97
engine,
96
96
-
one_clicks,
98
98
+
oauth,
97
99
resolving,
98
100
shutdown,
99
101
..
···
118
120
.into_response();
119
121
}
120
122
if let Some(did) = jar.get(DID_COOKIE_KEY) {
121
121
-
let did = did.value_trimmed().to_string();
123
123
+
let Ok(did) = Did::new(did.value_trimmed().to_string()) else {
124
124
+
return "did from cookie failed to parse".into_response();
125
125
+
};
122
126
123
123
-
let task_shutdown = shutdown.child_token();
124
124
-
let fetch_key = resolving.dispatch(resolve_identity(did.clone()), task_shutdown);
127
127
+
let fetch_key = resolving.dispatch(
128
128
+
{
129
129
+
let oauth = oauth.clone();
130
130
+
let did = did.clone();
131
131
+
async move { oauth.resolve_handle(did.clone()).await }
132
132
+
},
133
133
+
shutdown.child_token(),
134
134
+
);
125
135
126
136
RenderHtml(
127
137
"prompt-known",
···
157
167
let Some(task_handle) = resolving.take(¶ms.fetch_key) else {
158
168
return "oops, task does not exist or is gone".into_response();
159
169
};
160
160
-
if let Some(handle) = task_handle.await.unwrap() {
161
161
-
// TODO: get active state etc.
162
162
-
// ...but also, that's a bsky thing?
163
163
-
let Some(handle) = handle.strip_prefix("at://") else {
164
164
-
return "hmm, handle did not start with at://".into_response();
165
165
-
};
170
170
+
if let Ok(handle) = task_handle.await.unwrap() {
166
171
Json(json!({ "handle": handle })).into_response()
167
172
} else {
168
173
"no handle?".into_response()
···
174
179
handle: String,
175
180
}
176
181
async fn start_oauth(
177
177
-
State(state): State<AppState>,
182
182
+
State(AppState { oauth, .. }): State<AppState>,
178
183
Query(params): Query<BeginOauthParams>,
179
184
jar: SignedCookieJar,
180
185
) -> (SignedCookieJar, Redirect) {
181
186
// if any existing session was active, clear it first
182
187
let jar = jar.remove(DID_COOKIE_KEY);
183
188
184
184
-
let auth_url = authorize(&state.client, ¶ms.handle).await;
189
189
+
let auth_url = oauth.begin(¶ms.handle).await.unwrap();
185
190
(jar, Redirect::to(&auth_url))
186
191
}
187
192
188
193
async fn complete_oauth(
189
189
-
State(state): State<AppState>,
190
190
-
Query(params): Query<CallbackParams>,
194
194
+
State(AppState {
195
195
+
engine,
196
196
+
resolving,
197
197
+
oauth,
198
198
+
shutdown,
199
199
+
..
200
200
+
}): State<AppState>,
201
201
+
Query(params): Query<OauthCallbackParams>,
191
202
jar: SignedCookieJar,
192
203
) -> (SignedCookieJar, impl IntoResponse) {
193
193
-
let Ok((oauth_session, _)) = state.client.callback(params).await else {
204
204
+
let Ok(did) = oauth.complete(params).await else {
194
205
panic!("failed to do client callback");
195
206
};
196
196
-
let did = oauth_session.did().await.expect("a did to be present");
197
207
198
208
let cookie = Cookie::build((DID_COOKIE_KEY, did.to_string()))
199
209
.http_only(true)
···
203
213
204
214
let jar = jar.add(cookie);
205
215
206
206
-
let task_shutdown = state.shutdown.child_token();
207
207
-
let fetch_key = state
208
208
-
.resolving
209
209
-
.dispatch(resolve_identity(did.to_string()), task_shutdown);
216
216
+
let fetch_key = resolving.dispatch(
217
217
+
{
218
218
+
let oauth = oauth.clone();
219
219
+
let did = did.clone();
220
220
+
async move { oauth.resolve_handle(did.clone()).await }
221
221
+
},
222
222
+
shutdown.child_token(),
223
223
+
);
210
224
211
225
(
212
226
jar,
213
227
RenderHtml(
214
228
"authorized",
215
215
-
state.engine,
229
229
+
engine,
216
230
json!({
217
231
"did": did,
218
232
"fetch_key": fetch_key,