this repo has no description
1use super::super::{OAuthError, TokenData};
2use super::helpers::{from_json, to_json};
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> {
7 let client_auth_json = to_json(&data.client_auth)?;
8 let parameters_json = to_json(&data.parameters)?;
9 let row = sqlx::query!(
10 r#"
11 INSERT INTO oauth_token
12 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
13 device_id, parameters, details, code, current_refresh_token, scope, controller_did)
14 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
15 RETURNING id
16 "#,
17 data.did,
18 data.token_id,
19 data.created_at,
20 data.updated_at,
21 data.expires_at,
22 data.client_id,
23 client_auth_json,
24 data.device_id,
25 parameters_json,
26 data.details,
27 data.code,
28 data.current_refresh_token,
29 data.scope,
30 data.controller_did,
31 )
32 .fetch_one(pool)
33 .await?;
34 Ok(row.id)
35}
36
37pub async fn get_token_by_id(
38 pool: &PgPool,
39 token_id: &str,
40) -> Result<Option<TokenData>, OAuthError> {
41 let row = sqlx::query!(
42 r#"
43 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
44 device_id, parameters, details, code, current_refresh_token, scope, controller_did
45 FROM oauth_token
46 WHERE token_id = $1
47 "#,
48 token_id
49 )
50 .fetch_optional(pool)
51 .await?;
52 match row {
53 Some(r) => Ok(Some(TokenData {
54 did: r.did,
55 token_id: r.token_id,
56 created_at: r.created_at,
57 updated_at: r.updated_at,
58 expires_at: r.expires_at,
59 client_id: r.client_id,
60 client_auth: from_json(r.client_auth)?,
61 device_id: r.device_id,
62 parameters: from_json(r.parameters)?,
63 details: r.details,
64 code: r.code,
65 current_refresh_token: r.current_refresh_token,
66 scope: r.scope,
67 controller_did: r.controller_did,
68 })),
69 None => Ok(None),
70 }
71}
72
73pub async fn get_token_by_refresh_token(
74 pool: &PgPool,
75 refresh_token: &str,
76) -> Result<Option<(i32, TokenData)>, OAuthError> {
77 let row = sqlx::query!(
78 r#"
79 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
80 device_id, parameters, details, code, current_refresh_token, scope, controller_did
81 FROM oauth_token
82 WHERE current_refresh_token = $1
83 "#,
84 refresh_token
85 )
86 .fetch_optional(pool)
87 .await?;
88 match row {
89 Some(r) => Ok(Some((
90 r.id,
91 TokenData {
92 did: r.did,
93 token_id: r.token_id,
94 created_at: r.created_at,
95 updated_at: r.updated_at,
96 expires_at: r.expires_at,
97 client_id: r.client_id,
98 client_auth: from_json(r.client_auth)?,
99 device_id: r.device_id,
100 parameters: from_json(r.parameters)?,
101 details: r.details,
102 code: r.code,
103 current_refresh_token: r.current_refresh_token,
104 scope: r.scope,
105 controller_did: r.controller_did,
106 },
107 ))),
108 None => Ok(None),
109 }
110}
111
112pub async fn rotate_token(
113 pool: &PgPool,
114 old_db_id: i32,
115 new_token_id: &str,
116 new_refresh_token: &str,
117 new_expires_at: DateTime<Utc>,
118) -> Result<(), OAuthError> {
119 let mut tx = pool.begin().await?;
120 let old_refresh = sqlx::query_scalar!(
121 r#"
122 SELECT current_refresh_token FROM oauth_token WHERE id = $1
123 "#,
124 old_db_id
125 )
126 .fetch_one(&mut *tx)
127 .await?;
128 if let Some(ref old_rt) = old_refresh {
129 sqlx::query!(
130 r#"
131 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
132 VALUES ($1, $2)
133 "#,
134 old_rt,
135 old_db_id
136 )
137 .execute(&mut *tx)
138 .await?;
139 }
140 sqlx::query!(
141 r#"
142 UPDATE oauth_token
143 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW(),
144 previous_refresh_token = $5, rotated_at = NOW()
145 WHERE id = $1
146 "#,
147 old_db_id,
148 new_token_id,
149 new_refresh_token,
150 new_expires_at,
151 old_refresh
152 )
153 .execute(&mut *tx)
154 .await?;
155 tx.commit().await?;
156 Ok(())
157}
158
159pub async fn check_refresh_token_used(
160 pool: &PgPool,
161 refresh_token: &str,
162) -> Result<Option<i32>, OAuthError> {
163 let row = sqlx::query_scalar!(
164 r#"
165 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
166 "#,
167 refresh_token
168 )
169 .fetch_optional(pool)
170 .await?;
171 Ok(row)
172}
173
174const REFRESH_GRACE_PERIOD_SECS: i64 = 60;
175
176pub async fn get_token_by_previous_refresh_token(
177 pool: &PgPool,
178 refresh_token: &str,
179) -> Result<Option<(i32, TokenData)>, OAuthError> {
180 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS);
181 let row = sqlx::query!(
182 r#"
183 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
184 device_id, parameters, details, code, current_refresh_token, scope, controller_did
185 FROM oauth_token
186 WHERE previous_refresh_token = $1 AND rotated_at > $2
187 "#,
188 refresh_token,
189 grace_cutoff
190 )
191 .fetch_optional(pool)
192 .await?;
193 match row {
194 Some(r) => Ok(Some((
195 r.id,
196 TokenData {
197 did: r.did,
198 token_id: r.token_id,
199 created_at: r.created_at,
200 updated_at: r.updated_at,
201 expires_at: r.expires_at,
202 client_id: r.client_id,
203 client_auth: from_json(r.client_auth)?,
204 device_id: r.device_id,
205 parameters: from_json(r.parameters)?,
206 details: r.details,
207 code: r.code,
208 current_refresh_token: r.current_refresh_token,
209 scope: r.scope,
210 controller_did: r.controller_did,
211 },
212 ))),
213 None => Ok(None),
214 }
215}
216
217pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
218 sqlx::query!(
219 r#"
220 DELETE FROM oauth_token WHERE token_id = $1
221 "#,
222 token_id
223 )
224 .execute(pool)
225 .await?;
226 Ok(())
227}
228
229pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
230 sqlx::query!(
231 r#"
232 DELETE FROM oauth_token WHERE id = $1
233 "#,
234 db_id
235 )
236 .execute(pool)
237 .await?;
238 Ok(())
239}
240
241pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> {
242 let rows = sqlx::query!(
243 r#"
244 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
245 device_id, parameters, details, code, current_refresh_token, scope, controller_did
246 FROM oauth_token
247 WHERE did = $1
248 "#,
249 did
250 )
251 .fetch_all(pool)
252 .await?;
253 let mut tokens = Vec::with_capacity(rows.len());
254 for r in rows {
255 tokens.push(TokenData {
256 did: r.did,
257 token_id: r.token_id,
258 created_at: r.created_at,
259 updated_at: r.updated_at,
260 expires_at: r.expires_at,
261 client_id: r.client_id,
262 client_auth: from_json(r.client_auth)?,
263 device_id: r.device_id,
264 parameters: from_json(r.parameters)?,
265 details: r.details,
266 code: r.code,
267 current_refresh_token: r.current_refresh_token,
268 scope: r.scope,
269 controller_did: r.controller_did,
270 });
271 }
272 Ok(tokens)
273}
274
275pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
276 let count = sqlx::query_scalar!(
277 r#"
278 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
279 "#,
280 did
281 )
282 .fetch_one(pool)
283 .await?;
284 Ok(count)
285}
286
287pub async fn delete_oldest_tokens_for_user(
288 pool: &PgPool,
289 did: &str,
290 keep_count: i64,
291) -> Result<u64, OAuthError> {
292 let result = sqlx::query!(
293 r#"
294 DELETE FROM oauth_token
295 WHERE id IN (
296 SELECT id FROM oauth_token
297 WHERE did = $1
298 ORDER BY updated_at ASC
299 OFFSET $2
300 )
301 "#,
302 did,
303 keep_count
304 )
305 .execute(pool)
306 .await?;
307 Ok(result.rows_affected())
308}
309
310const MAX_TOKENS_PER_USER: i64 = 100;
311
312pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
313 let count = count_tokens_for_user(pool, did).await?;
314 if count > MAX_TOKENS_PER_USER {
315 let to_keep = MAX_TOKENS_PER_USER - 1;
316 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
317 }
318 Ok(())
319}
320
321pub async fn revoke_tokens_for_client(
322 pool: &PgPool,
323 did: &str,
324 client_id: &str,
325) -> Result<u64, OAuthError> {
326 let result = sqlx::query!(
327 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
328 did,
329 client_id
330 )
331 .execute(pool)
332 .await?;
333 Ok(result.rows_affected())
334}
335
336pub async fn revoke_tokens_for_controller(
337 pool: &PgPool,
338 delegated_did: &str,
339 controller_did: &str,
340) -> Result<u64, OAuthError> {
341 let result = sqlx::query!(
342 "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2",
343 delegated_did,
344 controller_did
345 )
346 .execute(pool)
347 .await?;
348 Ok(result.rows_affected())
349}