this repo has no description
1use chrono::{DateTime, Utc};
2use sqlx::PgPool;
3use super::super::{OAuthError, TokenData};
4use super::helpers::{from_json, to_json};
5pub async fn create_token(
6 pool: &PgPool,
7 data: &TokenData,
8) -> Result<i32, OAuthError> {
9 let client_auth_json = to_json(&data.client_auth)?;
10 let parameters_json = to_json(&data.parameters)?;
11 let row = sqlx::query!(
12 r#"
13 INSERT INTO oauth_token
14 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
15 device_id, parameters, details, code, current_refresh_token, scope)
16 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
17 RETURNING id
18 "#,
19 data.did,
20 data.token_id,
21 data.created_at,
22 data.updated_at,
23 data.expires_at,
24 data.client_id,
25 client_auth_json,
26 data.device_id,
27 parameters_json,
28 data.details,
29 data.code,
30 data.current_refresh_token,
31 data.scope,
32 )
33 .fetch_one(pool)
34 .await?;
35 Ok(row.id)
36}
37pub async fn get_token_by_id(
38 pool: &PgPool,
39 token_id: &str,
40) -> Result<Option<TokenData>, OAuthError> {
41 let row = sqlx::query!(
42 r#"
43 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
44 device_id, parameters, details, code, current_refresh_token, scope
45 FROM oauth_token
46 WHERE token_id = $1
47 "#,
48 token_id
49 )
50 .fetch_optional(pool)
51 .await?;
52 match row {
53 Some(r) => Ok(Some(TokenData {
54 did: r.did,
55 token_id: r.token_id,
56 created_at: r.created_at,
57 updated_at: r.updated_at,
58 expires_at: r.expires_at,
59 client_id: r.client_id,
60 client_auth: from_json(r.client_auth)?,
61 device_id: r.device_id,
62 parameters: from_json(r.parameters)?,
63 details: r.details,
64 code: r.code,
65 current_refresh_token: r.current_refresh_token,
66 scope: r.scope,
67 })),
68 None => Ok(None),
69 }
70}
71pub async fn get_token_by_refresh_token(
72 pool: &PgPool,
73 refresh_token: &str,
74) -> Result<Option<(i32, TokenData)>, OAuthError> {
75 let row = sqlx::query!(
76 r#"
77 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
78 device_id, parameters, details, code, current_refresh_token, scope
79 FROM oauth_token
80 WHERE current_refresh_token = $1
81 "#,
82 refresh_token
83 )
84 .fetch_optional(pool)
85 .await?;
86 match row {
87 Some(r) => Ok(Some((
88 r.id,
89 TokenData {
90 did: r.did,
91 token_id: r.token_id,
92 created_at: r.created_at,
93 updated_at: r.updated_at,
94 expires_at: r.expires_at,
95 client_id: r.client_id,
96 client_auth: from_json(r.client_auth)?,
97 device_id: r.device_id,
98 parameters: from_json(r.parameters)?,
99 details: r.details,
100 code: r.code,
101 current_refresh_token: r.current_refresh_token,
102 scope: r.scope,
103 },
104 ))),
105 None => Ok(None),
106 }
107}
108pub async fn rotate_token(
109 pool: &PgPool,
110 old_db_id: i32,
111 new_token_id: &str,
112 new_refresh_token: &str,
113 new_expires_at: DateTime<Utc>,
114) -> Result<(), OAuthError> {
115 let mut tx = pool.begin().await?;
116 let old_refresh = sqlx::query_scalar!(
117 r#"
118 SELECT current_refresh_token FROM oauth_token WHERE id = $1
119 "#,
120 old_db_id
121 )
122 .fetch_one(&mut *tx)
123 .await?;
124 if let Some(old_rt) = old_refresh {
125 sqlx::query!(
126 r#"
127 INSERT INTO oauth_used_refresh_token (refresh_token, token_id)
128 VALUES ($1, $2)
129 "#,
130 old_rt,
131 old_db_id
132 )
133 .execute(&mut *tx)
134 .await?;
135 }
136 sqlx::query!(
137 r#"
138 UPDATE oauth_token
139 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW()
140 WHERE id = $1
141 "#,
142 old_db_id,
143 new_token_id,
144 new_refresh_token,
145 new_expires_at
146 )
147 .execute(&mut *tx)
148 .await?;
149 tx.commit().await?;
150 Ok(())
151}
152pub async fn check_refresh_token_used(
153 pool: &PgPool,
154 refresh_token: &str,
155) -> Result<Option<i32>, OAuthError> {
156 let row = sqlx::query_scalar!(
157 r#"
158 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1
159 "#,
160 refresh_token
161 )
162 .fetch_optional(pool)
163 .await?;
164 Ok(row)
165}
166pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> {
167 sqlx::query!(
168 r#"
169 DELETE FROM oauth_token WHERE token_id = $1
170 "#,
171 token_id
172 )
173 .execute(pool)
174 .await?;
175 Ok(())
176}
177pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> {
178 sqlx::query!(
179 r#"
180 DELETE FROM oauth_token WHERE id = $1
181 "#,
182 db_id
183 )
184 .execute(pool)
185 .await?;
186 Ok(())
187}
188pub async fn list_tokens_for_user(
189 pool: &PgPool,
190 did: &str,
191) -> Result<Vec<TokenData>, OAuthError> {
192 let rows = sqlx::query!(
193 r#"
194 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth,
195 device_id, parameters, details, code, current_refresh_token, scope
196 FROM oauth_token
197 WHERE did = $1
198 "#,
199 did
200 )
201 .fetch_all(pool)
202 .await?;
203 let mut tokens = Vec::with_capacity(rows.len());
204 for r in rows {
205 tokens.push(TokenData {
206 did: r.did,
207 token_id: r.token_id,
208 created_at: r.created_at,
209 updated_at: r.updated_at,
210 expires_at: r.expires_at,
211 client_id: r.client_id,
212 client_auth: from_json(r.client_auth)?,
213 device_id: r.device_id,
214 parameters: from_json(r.parameters)?,
215 details: r.details,
216 code: r.code,
217 current_refresh_token: r.current_refresh_token,
218 scope: r.scope,
219 });
220 }
221 Ok(tokens)
222}
223pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> {
224 let count = sqlx::query_scalar!(
225 r#"
226 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1
227 "#,
228 did
229 )
230 .fetch_one(pool)
231 .await?;
232 Ok(count)
233}
234pub async fn delete_oldest_tokens_for_user(
235 pool: &PgPool,
236 did: &str,
237 keep_count: i64,
238) -> Result<u64, OAuthError> {
239 let result = sqlx::query!(
240 r#"
241 DELETE FROM oauth_token
242 WHERE id IN (
243 SELECT id FROM oauth_token
244 WHERE did = $1
245 ORDER BY updated_at ASC
246 OFFSET $2
247 )
248 "#,
249 did,
250 keep_count
251 )
252 .execute(pool)
253 .await?;
254 Ok(result.rows_affected())
255}
256const MAX_TOKENS_PER_USER: i64 = 100;
257pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> {
258 let count = count_tokens_for_user(pool, did).await?;
259 if count > MAX_TOKENS_PER_USER {
260 let to_keep = MAX_TOKENS_PER_USER - 1;
261 delete_oldest_tokens_for_user(pool, did, to_keep).await?;
262 }
263 Ok(())
264}