This repository has no description.
1use chrono::{DateTime, Utc};
2use sqlx::PgPool;
3use super::super::{OAuthError, TokenData};
4use super::helpers::{from_json, to_json};
5
6pub async fn create_token(
7 pool: &PgPool,
8 data: &TokenData,
9) -> Result<i32, OAuthError> {
10 let client_auth_json = to_json(&data.client_auth)?;
11 let parameters_json = to_json(&data.parameters)?;
12 let row = sqlx::query!(
13 r#"
14 INSERT INTO oauth_token
15 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
16 device_id, parameters, details, code, current_refresh_token, scope)
17 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
18 RETURNING id
19 "#,
20 data.did,
21 data.token_id,
22 data.created_at,
23 data.updated_at,
24 data.expires_at,
25 data.client_id,
26 client_auth_json,
27 data.device_id,
28 parameters_json,
29 data.details,
30 data.code,
31 data.current_refresh_token,
32 data.scope,
33 )
34 .fetch_one(pool)
35 .await?;
36 Ok(row.id)
37}
38
39pub async fn get_token_by_id(
40 pool: &PgPool,
41 token_id: &str,
42) -> Result<Option<TokenData>, OAuthError> {
43 let row = sqlx::query!(
44 r#"
45 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
46 device_id, parameters, details, code, current_refresh_token, scope
47 FROM oauth_token
48 WHERE token_id = $1
49 "#,
50 token_id
51 )
52 .fetch_optional(pool)
53 .await?;
54 match row {
55 Some(r) => Ok(Some(TokenData {
56 did: r.did,
57 token_id: r.token_id,
58 created_at: r.created_at,
59 updated_at: r.updated_at,
60 expires_at: r.expires_at,
61 client_id: r.client_id,
62 client_auth: from_json(r.client_auth)?,
63 device_id: r.device_id,
64 parameters: from_json(r.parameters)?,
65 details: r.details,
66 code: r.code,
67 current_refresh_token: r.current_refresh_token,
68 scope: r.scope,
69 })),
70 None => Ok(None),
71 }
72}
73
74pub async fn get_token_by_refresh_token(
75 pool: &PgPool,
76 refresh_token: &str,
77) -> Result<Option<(i32, TokenData)>, OAuthError> {
78 let row = sqlx::query!(
79 r#"
80 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
81 device_id, parameters, details, code, current_refresh_token, scope
82 FROM oauth_token
83 WHERE current_refresh_token = $1
84 "#,
85 refresh_token
86 )
87 .fetch_optional(pool)
88 .await?;
89 match row {
90 Some(r) => Ok(Some((
91 r.id,
92 TokenData {
93 did: r.did,
94 token_id: r.token_id,
95 created_at: r.created_at,
96 updated_at: r.updated_at,
97 expires_at: r.expires_at,
98 client_id: r.client_id,
99 client_auth: from_json(r.client_auth)?,
100 device_id: r.device_id,
101 parameters: from_json(r.parameters)?,
102 details: r.details,
103 code: r.code,
104 current_refresh_token: r.current_refresh_token,
105 scope: r.scope,
106 },
107 ))),
108 None => Ok(None),
109 }
110}
111
112pub async fn rotate_token(
113 pool: &PgPool,
114 old_db_id: i32,
115 new_token_id: &str,
116 new_refresh_token: &str,
117 new_expires_at: DateTime<Utc>,
118) -> Result<(), OAuthError> {
119 let mut tx = pool.begin().await?;
120 let old_refresh = sqlx::query_scalar!(
121 r#"
122 SELECT current_refresh_token FROM oauth_token WHERE id = $1
123 "#,
124 old_db_id
125 )
126 .fetch_one(&mut *tx)
127 .await?;
128 if let Some(old_rt) = old_refresh {
129 sqlx::query!(
130 r#"
131 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
132 VALUES ($1, $2)
133 "#,
134 old_rt,
135 old_db_id
136 )
137 .execute(&mut *tx)
138 .await?;
139 }
140 sqlx::query!(
141 r#"
142 UPDATE oauth_token
143 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
144 WHERE id = $1
145 "#,
146 old_db_id,
147 new_token_id,
148 new_refresh_token,
149 new_expires_at
150 )
151 .execute(&mut *tx)
152 .await?;
153 tx.commit().await?;
154 Ok(())
155}
156
157pub async fn check_refresh_token_used(
158 pool: &PgPool,
159 refresh_token: &str,
160) -> Result<Option<i32>, OAuthError> {
161 let row = sqlx::query_scalar!(
162 r#"
163 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
164 "#,
165 refresh_token
166 )
167 .fetch_optional(pool)
168 .await?;
169 Ok(row)
170}
171
172pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
173 sqlx::query!(
174 r#"
175 DELETE FROM oauth_token WHERE token_id = $1
176 "#,
177 token_id
178 )
179 .execute(pool)
180 .await?;
181 Ok(())
182}
183
184pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
185 sqlx::query!(
186 r#"
187 DELETE FROM oauth_token WHERE id = $1
188 "#,
189 db_id
190 )
191 .execute(pool)
192 .await?;
193 Ok(())
194}
195
196pub async fn list_tokens_for_user(
197 pool: &PgPool,
198 did: &str,
199) -> Result<Vec<TokenData>, OAuthError> {
200 let rows = sqlx::query!(
201 r#"
202 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
203 device_id, parameters, details, code, current_refresh_token, scope
204 FROM oauth_token
205 WHERE did = $1
206 "#,
207 did
208 )
209 .fetch_all(pool)
210 .await?;
211 let mut tokens = Vec::with_capacity(rows.len());
212 for r in rows {
213 tokens.push(TokenData {
214 did: r.did,
215 token_id: r.token_id,
216 created_at: r.created_at,
217 updated_at: r.updated_at,
218 expires_at: r.expires_at,
219 client_id: r.client_id,
220 client_auth: from_json(r.client_auth)?,
221 device_id: r.device_id,
222 parameters: from_json(r.parameters)?,
223 details: r.details,
224 code: r.code,
225 current_refresh_token: r.current_refresh_token,
226 scope: r.scope,
227 });
228 }
229 Ok(tokens)
230}
231
232pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
233 let count = sqlx::query_scalar!(
234 r#"
235 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
236 "#,
237 did
238 )
239 .fetch_one(pool)
240 .await?;
241 Ok(count)
242}
243
244pub async fn delete_oldest_tokens_for_user(
245 pool: &PgPool,
246 did: &str,
247 keep_count: i64,
248) -> Result<u64, OAuthError> {
249 let result = sqlx::query!(
250 r#"
251 DELETE FROM oauth_token
252 WHERE id IN (
253 SELECT id FROM oauth_token
254 WHERE did = $1
255 ORDER BY updated_at ASC
256 OFFSET $2
257 )
258 "#,
259 did,
260 keep_count
261 )
262 .execute(pool)
263 .await?;
264 Ok(result.rows_affected())
265}
266
267const MAX_TOKENS_PER_USER: i64 = 100;
268
269pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
270 let count = count_tokens_for_user(pool, did).await?;
271 if count > MAX_TOKENS_PER_USER {
272 let to_keep = MAX_TOKENS_PER_USER - 1;
273 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
274 }
275 Ok(())
276}