this repo has no description
1use chrono::{DateTime, Utc};
2use sqlx::PgPool;
3
4use super::super::{OAuthError, TokenData};
5use super::helpers::{from_json, to_json};
6
7pub async fn create_token(
8 pool: &PgPool,
9 data: &TokenData,
10) -> Result<i32, OAuthError> {
11 let client_auth_json = to_json(&data.client_auth)?;
12 let parameters_json = to_json(&data.parameters)?;
13
14 let row = sqlx::query!(
15 r#"
16 INSERT INTO oauth_token
17 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
18 device_id, parameters, details, code, current_refresh_token, scope)
19 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
20 RETURNING id
21 "#,
22 data.did,
23 data.token_id,
24 data.created_at,
25 data.updated_at,
26 data.expires_at,
27 data.client_id,
28 client_auth_json,
29 data.device_id,
30 parameters_json,
31 data.details,
32 data.code,
33 data.current_refresh_token,
34 data.scope,
35 )
36 .fetch_one(pool)
37 .await?;
38
39 Ok(row.id)
40}
41
42pub async fn get_token_by_id(
43 pool: &PgPool,
44 token_id: &str,
45) -> Result<Option<TokenData>, OAuthError> {
46 let row = sqlx::query!(
47 r#"
48 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
49 device_id, parameters, details, code, current_refresh_token, scope
50 FROM oauth_token
51 WHERE token_id = $1
52 "#,
53 token_id
54 )
55 .fetch_optional(pool)
56 .await?;
57
58 match row {
59 Some(r) => Ok(Some(TokenData {
60 did: r.did,
61 token_id: r.token_id,
62 created_at: r.created_at,
63 updated_at: r.updated_at,
64 expires_at: r.expires_at,
65 client_id: r.client_id,
66 client_auth: from_json(r.client_auth)?,
67 device_id: r.device_id,
68 parameters: from_json(r.parameters)?,
69 details: r.details,
70 code: r.code,
71 current_refresh_token: r.current_refresh_token,
72 scope: r.scope,
73 })),
74 None => Ok(None),
75 }
76}
77
78pub async fn get_token_by_refresh_token(
79 pool: &PgPool,
80 refresh_token: &str,
81) -> Result<Option<(i32, TokenData)>, OAuthError> {
82 let row = sqlx::query!(
83 r#"
84 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
85 device_id, parameters, details, code, current_refresh_token, scope
86 FROM oauth_token
87 WHERE current_refresh_token = $1
88 "#,
89 refresh_token
90 )
91 .fetch_optional(pool)
92 .await?;
93
94 match row {
95 Some(r) => Ok(Some((
96 r.id,
97 TokenData {
98 did: r.did,
99 token_id: r.token_id,
100 created_at: r.created_at,
101 updated_at: r.updated_at,
102 expires_at: r.expires_at,
103 client_id: r.client_id,
104 client_auth: from_json(r.client_auth)?,
105 device_id: r.device_id,
106 parameters: from_json(r.parameters)?,
107 details: r.details,
108 code: r.code,
109 current_refresh_token: r.current_refresh_token,
110 scope: r.scope,
111 },
112 ))),
113 None => Ok(None),
114 }
115}
116
117pub async fn rotate_token(
118 pool: &PgPool,
119 old_db_id: i32,
120 new_token_id: &str,
121 new_refresh_token: &str,
122 new_expires_at: DateTime<Utc>,
123) -> Result<(), OAuthError> {
124 let mut tx = pool.begin().await?;
125
126 let old_refresh = sqlx::query_scalar!(
127 r#"
128 SELECT current_refresh_token FROM oauth_token WHERE id = $1
129 "#,
130 old_db_id
131 )
132 .fetch_one(&mut *tx)
133 .await?;
134
135 if let Some(old_rt) = old_refresh {
136 sqlx::query!(
137 r#"
138 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
139 VALUES ($1, $2)
140 "#,
141 old_rt,
142 old_db_id
143 )
144 .execute(&mut *tx)
145 .await?;
146 }
147
148 sqlx::query!(
149 r#"
150 UPDATE oauth_token
151 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
152 WHERE id = $1
153 "#,
154 old_db_id,
155 new_token_id,
156 new_refresh_token,
157 new_expires_at
158 )
159 .execute(&mut *tx)
160 .await?;
161
162 tx.commit().await?;
163 Ok(())
164}
165
166pub async fn check_refresh_token_used(
167 pool: &PgPool,
168 refresh_token: &str,
169) -> Result<Option<i32>, OAuthError> {
170 let row = sqlx::query_scalar!(
171 r#"
172 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
173 "#,
174 refresh_token
175 )
176 .fetch_optional(pool)
177 .await?;
178
179 Ok(row)
180}
181
182pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
183 sqlx::query!(
184 r#"
185 DELETE FROM oauth_token WHERE token_id = $1
186 "#,
187 token_id
188 )
189 .execute(pool)
190 .await?;
191
192 Ok(())
193}
194
195pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
196 sqlx::query!(
197 r#"
198 DELETE FROM oauth_token WHERE id = $1
199 "#,
200 db_id
201 )
202 .execute(pool)
203 .await?;
204
205 Ok(())
206}
207
208pub async fn list_tokens_for_user(
209 pool: &PgPool,
210 did: &str,
211) -> Result<Vec<TokenData>, OAuthError> {
212 let rows = sqlx::query!(
213 r#"
214 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
215 device_id, parameters, details, code, current_refresh_token, scope
216 FROM oauth_token
217 WHERE did = $1
218 "#,
219 did
220 )
221 .fetch_all(pool)
222 .await?;
223
224 let mut tokens = Vec::with_capacity(rows.len());
225 for r in rows {
226 tokens.push(TokenData {
227 did: r.did,
228 token_id: r.token_id,
229 created_at: r.created_at,
230 updated_at: r.updated_at,
231 expires_at: r.expires_at,
232 client_id: r.client_id,
233 client_auth: from_json(r.client_auth)?,
234 device_id: r.device_id,
235 parameters: from_json(r.parameters)?,
236 details: r.details,
237 code: r.code,
238 current_refresh_token: r.current_refresh_token,
239 scope: r.scope,
240 });
241 }
242 Ok(tokens)
243}
244
245pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
246 let count = sqlx::query_scalar!(
247 r#"
248 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
249 "#,
250 did
251 )
252 .fetch_one(pool)
253 .await?;
254
255 Ok(count)
256}
257
258pub async fn delete_oldest_tokens_for_user(
259 pool: &PgPool,
260 did: &str,
261 keep_count: i64,
262) -> Result<u64, OAuthError> {
263 let result = sqlx::query!(
264 r#"
265 DELETE FROM oauth_token
266 WHERE id IN (
267 SELECT id FROM oauth_token
268 WHERE did = $1
269 ORDER BY updated_at ASC
270 OFFSET $2
271 )
272 "#,
273 did,
274 keep_count
275 )
276 .execute(pool)
277 .await?;
278
279 Ok(result.rows_affected())
280}
281
282const MAX_TOKENS_PER_USER: i64 = 100;
283
284pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
285 let count = count_tokens_for_user(pool, did).await?;
286 if count > MAX_TOKENS_PER_USER {
287 let to_keep = MAX_TOKENS_PER_USER - 1;
288 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
289 }
290 Ok(())
291}