This repository has no description.
1use super::super::{OAuthError, TokenData};
2use super::helpers::{from_json, to_json};
3use chrono::{DateTime, Utc};
4use sqlx::PgPool;
5
6pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> {
7 let client_auth_json = to_json(&data.client_auth)?;
8 let parameters_json = to_json(&data.parameters)?;
9 let row = sqlx::query!(
10 r#"
11 INSERT INTO oauth_token
12 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
13 device_id, parameters, details, code, current_refresh_token, scope)
14 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
15 RETURNING id
16 "#,
17 data.did,
18 data.token_id,
19 data.created_at,
20 data.updated_at,
21 data.expires_at,
22 data.client_id,
23 client_auth_json,
24 data.device_id,
25 parameters_json,
26 data.details,
27 data.code,
28 data.current_refresh_token,
29 data.scope,
30 )
31 .fetch_one(pool)
32 .await?;
33 Ok(row.id)
34}
35
36pub async fn get_token_by_id(
37 pool: &PgPool,
38 token_id: &str,
39) -> Result<Option<TokenData>, OAuthError> {
40 let row = sqlx::query!(
41 r#"
42 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
43 device_id, parameters, details, code, current_refresh_token, scope
44 FROM oauth_token
45 WHERE token_id = $1
46 "#,
47 token_id
48 )
49 .fetch_optional(pool)
50 .await?;
51 match row {
52 Some(r) => Ok(Some(TokenData {
53 did: r.did,
54 token_id: r.token_id,
55 created_at: r.created_at,
56 updated_at: r.updated_at,
57 expires_at: r.expires_at,
58 client_id: r.client_id,
59 client_auth: from_json(r.client_auth)?,
60 device_id: r.device_id,
61 parameters: from_json(r.parameters)?,
62 details: r.details,
63 code: r.code,
64 current_refresh_token: r.current_refresh_token,
65 scope: r.scope,
66 })),
67 None => Ok(None),
68 }
69}
70
71pub async fn get_token_by_refresh_token(
72 pool: &PgPool,
73 refresh_token: &str,
74) -> Result<Option<(i32, TokenData)>, OAuthError> {
75 let row = sqlx::query!(
76 r#"
77 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
78 device_id, parameters, details, code, current_refresh_token, scope
79 FROM oauth_token
80 WHERE current_refresh_token = $1
81 "#,
82 refresh_token
83 )
84 .fetch_optional(pool)
85 .await?;
86 match row {
87 Some(r) => Ok(Some((
88 r.id,
89 TokenData {
90 did: r.did,
91 token_id: r.token_id,
92 created_at: r.created_at,
93 updated_at: r.updated_at,
94 expires_at: r.expires_at,
95 client_id: r.client_id,
96 client_auth: from_json(r.client_auth)?,
97 device_id: r.device_id,
98 parameters: from_json(r.parameters)?,
99 details: r.details,
100 code: r.code,
101 current_refresh_token: r.current_refresh_token,
102 scope: r.scope,
103 },
104 ))),
105 None => Ok(None),
106 }
107}
108
109pub async fn rotate_token(
110 pool: &PgPool,
111 old_db_id: i32,
112 new_token_id: &str,
113 new_refresh_token: &str,
114 new_expires_at: DateTime<Utc>,
115) -> Result<(), OAuthError> {
116 let mut tx = pool.begin().await?;
117 let old_refresh = sqlx::query_scalar!(
118 r#"
119 SELECT current_refresh_token FROM oauth_token WHERE id = $1
120 "#,
121 old_db_id
122 )
123 .fetch_one(&mut *tx)
124 .await?;
125 if let Some(old_rt) = old_refresh {
126 sqlx::query!(
127 r#"
128 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
129 VALUES ($1, $2)
130 "#,
131 old_rt,
132 old_db_id
133 )
134 .execute(&mut *tx)
135 .await?;
136 }
137 sqlx::query!(
138 r#"
139 UPDATE oauth_token
140 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
141 WHERE id = $1
142 "#,
143 old_db_id,
144 new_token_id,
145 new_refresh_token,
146 new_expires_at
147 )
148 .execute(&mut *tx)
149 .await?;
150 tx.commit().await?;
151 Ok(())
152}
153
154pub async fn check_refresh_token_used(
155 pool: &PgPool,
156 refresh_token: &str,
157) -> Result<Option<i32>, OAuthError> {
158 let row = sqlx::query_scalar!(
159 r#"
160 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
161 "#,
162 refresh_token
163 )
164 .fetch_optional(pool)
165 .await?;
166 Ok(row)
167}
168
169pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
170 sqlx::query!(
171 r#"
172 DELETE FROM oauth_token WHERE token_id = $1
173 "#,
174 token_id
175 )
176 .execute(pool)
177 .await?;
178 Ok(())
179}
180
181pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
182 sqlx::query!(
183 r#"
184 DELETE FROM oauth_token WHERE id = $1
185 "#,
186 db_id
187 )
188 .execute(pool)
189 .await?;
190 Ok(())
191}
192
193pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> {
194 let rows = sqlx::query!(
195 r#"
196 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
197 device_id, parameters, details, code, current_refresh_token, scope
198 FROM oauth_token
199 WHERE did = $1
200 "#,
201 did
202 )
203 .fetch_all(pool)
204 .await?;
205 let mut tokens = Vec::with_capacity(rows.len());
206 for r in rows {
207 tokens.push(TokenData {
208 did: r.did,
209 token_id: r.token_id,
210 created_at: r.created_at,
211 updated_at: r.updated_at,
212 expires_at: r.expires_at,
213 client_id: r.client_id,
214 client_auth: from_json(r.client_auth)?,
215 device_id: r.device_id,
216 parameters: from_json(r.parameters)?,
217 details: r.details,
218 code: r.code,
219 current_refresh_token: r.current_refresh_token,
220 scope: r.scope,
221 });
222 }
223 Ok(tokens)
224}
225
226pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
227 let count = sqlx::query_scalar!(
228 r#"
229 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
230 "#,
231 did
232 )
233 .fetch_one(pool)
234 .await?;
235 Ok(count)
236}
237
238pub async fn delete_oldest_tokens_for_user(
239 pool: &PgPool,
240 did: &str,
241 keep_count: i64,
242) -> Result<u64, OAuthError> {
243 let result = sqlx::query!(
244 r#"
245 DELETE FROM oauth_token
246 WHERE id IN (
247 SELECT id FROM oauth_token
248 WHERE did = $1
249 ORDER BY updated_at ASC
250 OFFSET $2
251 )
252 "#,
253 did,
254 keep_count
255 )
256 .execute(pool)
257 .await?;
258 Ok(result.rows_affected())
259}
260
261const MAX_TOKENS_PER_USER: i64 = 100;
262
263pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
264 let count = count_tokens_for_user(pool, did).await?;
265 if count > MAX_TOKENS_PER_USER {
266 let to_keep = MAX_TOKENS_PER_USER - 1;
267 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
268 }
269 Ok(())
270}
271
272pub async fn revoke_tokens_for_client(
273 pool: &PgPool,
274 did: &str,
275 client_id: &str,
276) -> Result<u64, OAuthError> {
277 let result = sqlx::query!(
278 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2",
279 did,
280 client_id
281 )
282 .execute(pool)
283 .await?;
284 Ok(result.rows_affected())
285}