this repo has no description
1use super::super::{OAuthError, TokenData};
2use super::helpers::{from_json, to_json};
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> {
7 let client_auth_json = to_json(&data.client_auth)?;
8 let parameters_json = to_json(&data.parameters)?;
9 let row = sqlx::query!(
10 r#"
11 INSERT INTO oauth_token
12 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
13 device_id, parameters, details, code, current_refresh_token, scope)
14 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
15 RETURNING id
16 "#,
17 data.did,
18 data.token_id,
19 data.created_at,
20 data.updated_at,
21 data.expires_at,
22 data.client_id,
23 client_auth_json,
24 data.device_id,
25 parameters_json,
26 data.details,
27 data.code,
28 data.current_refresh_token,
29 data.scope,
30 )
31 .fetch_one(pool)
32 .await?;
33 Ok(row.id)
34}
35
36pub async fn get_token_by_id(
37 pool: &PgPool,
38 token_id: &str,
39) -> Result<Option<TokenData>, OAuthError> {
40 let row = sqlx::query!(
41 r#"
42 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
43 device_id, parameters, details, code, current_refresh_token, scope
44 FROM oauth_token
45 WHERE token_id = $1
46 "#,
47 token_id
48 )
49 .fetch_optional(pool)
50 .await?;
51 match row {
52 Some(r) => Ok(Some(TokenData {
53 did: r.did,
54 token_id: r.token_id,
55 created_at: r.created_at,
56 updated_at: r.updated_at,
57 expires_at: r.expires_at,
58 client_id: r.client_id,
59 client_auth: from_json(r.client_auth)?,
60 device_id: r.device_id,
61 parameters: from_json(r.parameters)?,
62 details: r.details,
63 code: r.code,
64 current_refresh_token: r.current_refresh_token,
65 scope: r.scope,
66 })),
67 None => Ok(None),
68 }
69}
70
71pub async fn get_token_by_refresh_token(
72 pool: &PgPool,
73 refresh_token: &str,
74) -> Result<Option<(i32, TokenData)>, OAuthError> {
75 let row = sqlx::query!(
76 r#"
77 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
78 device_id, parameters, details, code, current_refresh_token, scope
79 FROM oauth_token
80 WHERE current_refresh_token = $1
81 "#,
82 refresh_token
83 )
84 .fetch_optional(pool)
85 .await?;
86 match row {
87 Some(r) => Ok(Some((
88 r.id,
89 TokenData {
90 did: r.did,
91 token_id: r.token_id,
92 created_at: r.created_at,
93 updated_at: r.updated_at,
94 expires_at: r.expires_at,
95 client_id: r.client_id,
96 client_auth: from_json(r.client_auth)?,
97 device_id: r.device_id,
98 parameters: from_json(r.parameters)?,
99 details: r.details,
100 code: r.code,
101 current_refresh_token: r.current_refresh_token,
102 scope: r.scope,
103 },
104 ))),
105 None => Ok(None),
106 }
107}
108
109pub async fn rotate_token(
110 pool: &PgPool,
111 old_db_id: i32,
112 new_token_id: &str,
113 new_refresh_token: &str,
114 new_expires_at: DateTime<Utc>,
115) -> Result<(), OAuthError> {
116 let mut tx = pool.begin().await?;
117 let old_refresh = sqlx::query_scalar!(
118 r#"
119 SELECT current_refresh_token FROM oauth_token WHERE id = $1
120 "#,
121 old_db_id
122 )
123 .fetch_one(&mut *tx)
124 .await?;
125 if let Some(ref old_rt) = old_refresh {
126 sqlx::query!(
127 r#"
128 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
129 VALUES ($1, $2)
130 "#,
131 old_rt,
132 old_db_id
133 )
134 .execute(&mut *tx)
135 .await?;
136 }
137 sqlx::query!(
138 r#"
139 UPDATE oauth_token
140 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW(),
141 previous_refresh_token = $5, rotated_at = NOW()
142 WHERE id = $1
143 "#,
144 old_db_id,
145 new_token_id,
146 new_refresh_token,
147 new_expires_at,
148 old_refresh
149 )
150 .execute(&mut *tx)
151 .await?;
152 tx.commit().await?;
153 Ok(())
154}
155
156pub async fn check_refresh_token_used(
157 pool: &PgPool,
158 refresh_token: &str,
159) -> Result<Option<i32>, OAuthError> {
160 let row = sqlx::query_scalar!(
161 r#"
162 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
163 "#,
164 refresh_token
165 )
166 .fetch_optional(pool)
167 .await?;
168 Ok(row)
169}
170
171const REFRESH_GRACE_PERIOD_SECS: i64 = 60;
172
173pub async fn get_token_by_previous_refresh_token(
174 pool: &PgPool,
175 refresh_token: &str,
176) -> Result<Option<(i32, TokenData)>, OAuthError> {
177 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS);
178 let row = sqlx::query!(
179 r#"
180 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
181 device_id, parameters, details, code, current_refresh_token, scope
182 FROM oauth_token
183 WHERE previous_refresh_token = $1 AND rotated_at > $2
184 "#,
185 refresh_token,
186 grace_cutoff
187 )
188 .fetch_optional(pool)
189 .await?;
190 match row {
191 Some(r) => Ok(Some((
192 r.id,
193 TokenData {
194 did: r.did,
195 token_id: r.token_id,
196 created_at: r.created_at,
197 updated_at: r.updated_at,
198 expires_at: r.expires_at,
199 client_id: r.client_id,
200 client_auth: from_json(r.client_auth)?,
201 device_id: r.device_id,
202 parameters: from_json(r.parameters)?,
203 details: r.details,
204 code: r.code,
205 current_refresh_token: r.current_refresh_token,
206 scope: r.scope,
207 },
208 ))),
209 None => Ok(None),
210 }
211}
212
213pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
214 sqlx::query!(
215 r#"
216 DELETE FROM oauth_token WHERE token_id = $1
217 "#,
218 token_id
219 )
220 .execute(pool)
221 .await?;
222 Ok(())
223}
224
225pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
226 sqlx::query!(
227 r#"
228 DELETE FROM oauth_token WHERE id = $1
229 "#,
230 db_id
231 )
232 .execute(pool)
233 .await?;
234 Ok(())
235}
236
237pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> {
238 let rows = sqlx::query!(
239 r#"
240 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
241 device_id, parameters, details, code, current_refresh_token, scope
242 FROM oauth_token
243 WHERE did = $1
244 "#,
245 did
246 )
247 .fetch_all(pool)
248 .await?;
249 let mut tokens = Vec::with_capacity(rows.len());
250 for r in rows {
251 tokens.push(TokenData {
252 did: r.did,
253 token_id: r.token_id,
254 created_at: r.created_at,
255 updated_at: r.updated_at,
256 expires_at: r.expires_at,
257 client_id: r.client_id,
258 client_auth: from_json(r.client_auth)?,
259 device_id: r.device_id,
260 parameters: from_json(r.parameters)?,
261 details: r.details,
262 code: r.code,
263 current_refresh_token: r.current_refresh_token,
264 scope: r.scope,
265 });
266 }
267 Ok(tokens)
268}
269
270pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
271 let count = sqlx::query_scalar!(
272 r#"
273 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
274 "#,
275 did
276 )
277 .fetch_one(pool)
278 .await?;
279 Ok(count)
280}
281
282pub async fn delete_oldest_tokens_for_user(
283 pool: &PgPool,
284 did: &str,
285 keep_count: i64,
286) -> Result<u64, OAuthError> {
287 let result = sqlx::query!(
288 r#"
289 DELETE FROM oauth_token
290 WHERE id IN (
291 SELECT id FROM oauth_token
292 WHERE did = $1
293 ORDER BY updated_at ASC
294 OFFSET $2
295 )
296 "#,
297 did,
298 keep_count
299 )
300 .execute(pool)
301 .await?;
302 Ok(result.rows_affected())
303}
304
305const MAX_TOKENS_PER_USER: i64 = 100;
306
307pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
308 let count = count_tokens_for_user(pool, did).await?;
309 if count > MAX_TOKENS_PER_USER {
310 let to_keep = MAX_TOKENS_PER_USER - 1;
311 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
312 }
313 Ok(())
314}
315
316pub async fn revoke_tokens_for_client(
317 pool: &PgPool,
318 did: &str,
319 client_id: &str,
320) -> Result<u64, OAuthError> {
321 let result = sqlx::query!(
322 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
323 did,
324 client_id
325 )
326 .execute(pool)
327 .await?;
328 Ok(result.rows_affected())
329}