tuiter 2006
1package main
2
3import (
4 "context"
5 "crypto/aes"
6 "crypto/cipher"
7 "crypto/rand"
8 "crypto/sha256"
9 "database/sql"
10 "encoding/base64"
11 "encoding/json"
12 "errors"
13 "fmt"
14 "io"
15 "time"
16
17 _ "modernc.org/sqlite"
18
19 "github.com/bluesky-social/indigo/atproto/auth/oauth"
20 "github.com/bluesky-social/indigo/atproto/syntax"
21)
22
// sqliteStore keeps OAuth client sessions and pending authorization requests
// in two SQLite tables (sessions, auth_requests). Every record is stored as a
// JSON document encrypted by encryptBlob under key.
type sqliteStore struct {
	// db is the handle obtained from sql.Open in NewSQLiteStore.
	db *sql.DB
	// key is the AES key passed to encryptBlob and decryptBlob.
	key []byte
}
27
// deriveKeyFromEnv converts the raw SESSION_DB_KEY value into a 32-byte key.
//
// Resolution order:
//   - an empty value is rejected with an error;
//   - standard base64 that decodes to exactly 32 bytes is used verbatim;
//   - standard base64 that decodes to 16 or more (but not 32) bytes is
//     hashed with SHA-256;
//   - anything else is treated as text and SHA-256 of raw is returned.
//
// NOTE(review): a passphrase that happens to be valid base64 is decoded
// instead of hashed as text, and unsalted SHA-256 is not a password-stretching
// KDF, so a high-entropy random value is assumed. Changing either rule would
// change derived keys and make existing rows undecryptable.
func deriveKeyFromEnv(raw string) ([]byte, error) {
	if len(raw) == 0 {
		return nil, errors.New("SESSION_DB_KEY is required")
	}

	// By default the textual value itself is the hash input.
	material := []byte(raw)
	decoded, decodeErr := base64.StdEncoding.DecodeString(raw)
	switch {
	case decodeErr != nil || len(decoded) < 16:
		// Not usable as binary key material; keep hashing raw.
	case len(decoded) == 32:
		// Already the right size for AES-256.
		return decoded, nil
	default:
		material = decoded
	}

	digest := sha256.Sum256(material)
	return digest[:], nil
}
42
// encryptBlob seals plaintext with AES-GCM under key and returns
// nonce || ciphertext || tag. A fresh nonce of the AEAD's standard size is
// read from crypto/rand on every call.
//
// key must be 16, 24, or 32 bytes; otherwise the error from aes.NewCipher is
// returned. No additional authenticated data is supplied, so the result is
// not cryptographically bound to the row it is later stored under.
func encryptBlob(key, plaintext []byte) ([]byte, error) {
	blockCipher, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, err
	}

	// Reserve room for the whole output: the nonce fills the prefix and Seal
	// appends ciphertext+tag directly after it.
	nonceSize := aead.NonceSize()
	nonce := make([]byte, nonceSize, nonceSize+len(plaintext)+aead.Overhead())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, err
	}
	return aead.Seal(nonce, nonce, plaintext, nil), nil
}
62
// decryptBlob reverses encryptBlob. It splits blob into the leading GCM nonce
// and the trailing ciphertext+tag, then authenticates and decrypts under key.
//
// It returns the error "ciphertext too short" when blob is shorter than one
// nonce. It returns (nil, err) when key is not a valid AES key or when GCM
// authentication fails, for example after tampering or when a different key
// was used to encrypt.
func decryptBlob(key, blob []byte) ([]byte, error) {
	blockCipher, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, err
	}

	nonceSize := aead.NonceSize()
	if nonceSize > len(blob) {
		return nil, errors.New("ciphertext too short")
	}
	nonce, sealed := blob[:nonceSize], blob[nonceSize:]

	plaintext, err := aead.Open(nil, nonce, sealed, nil)
	if err != nil {
		return nil, err
	}
	return plaintext, nil
}
84
85func NewSQLiteStore(dbPath string, key []byte) (*sqliteStore, error) {
86 db, err := sql.Open("sqlite", dbPath)
87 if err != nil {
88 return nil, err
89 }
90 if _, err := db.Exec("PRAGMA busy_timeout = 5000"); err != nil {
91 _ = db.Close()
92 return nil, err
93 }
94 schema := []string{`CREATE TABLE IF NOT EXISTS sessions(
95 session_id TEXT PRIMARY KEY,
96 did TEXT,
97 data BLOB,
98 updated_at INTEGER
99 );`, `CREATE TABLE IF NOT EXISTS auth_requests(
100 state TEXT PRIMARY KEY,
101 data BLOB,
102 updated_at INTEGER
103 );`}
104 for _, s := range schema {
105 if _, err := db.Exec(s); err != nil {
106 _ = db.Close()
107 return nil, err
108 }
109 }
110 return &sqliteStore{db: db, key: key}, nil
111}
112
113func (s *sqliteStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
114 data, err := json.Marshal(sess)
115 if err != nil {
116 return err
117 }
118 enc, err := encryptBlob(s.key, data)
119 if err != nil {
120 return err
121 }
122 _, err = s.db.ExecContext(ctx, `INSERT INTO sessions(session_id, did, data, updated_at) VALUES (?, ?, ?, ?) ON CONFLICT(session_id) DO UPDATE SET did=excluded.did, data=excluded.data, updated_at=excluded.updated_at`, sess.SessionID, sess.AccountDID.String(), enc, time.Now().Unix())
123 return err
124}
125
126func (s *sqliteStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
127 row := s.db.QueryRowContext(ctx, `SELECT data FROM sessions WHERE session_id = ? AND did = ?`, sessionID, did.String())
128 var blob []byte
129 if err := row.Scan(&blob); err != nil {
130 if errors.Is(err, sql.ErrNoRows) {
131 return nil, fmt.Errorf("session not found")
132 }
133 return nil, err
134 }
135 pt, err := decryptBlob(s.key, blob)
136 if err != nil {
137 return nil, err
138 }
139 var out oauth.ClientSessionData
140 if err := json.Unmarshal(pt, &out); err != nil {
141 return nil, err
142 }
143 return &out, nil
144}
145
146func (s *sqliteStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
147 _, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE session_id = ? AND did = ?`, sessionID, did.String())
148 return err
149}
150
151func (s *sqliteStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
152 data, err := json.Marshal(info)
153 if err != nil {
154 return err
155 }
156 enc, err := encryptBlob(s.key, data)
157 if err != nil {
158 return err
159 }
160 _, err = s.db.ExecContext(ctx, `INSERT INTO auth_requests(state, data, updated_at) VALUES (?, ?, ?) ON CONFLICT(state) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at`, info.State, enc, time.Now().Unix())
161 return err
162}
163
164func (s *sqliteStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
165 row := s.db.QueryRowContext(ctx, `SELECT data FROM auth_requests WHERE state = ?`, state)
166 var blob []byte
167 if err := row.Scan(&blob); err != nil {
168 if errors.Is(err, sql.ErrNoRows) {
169 return nil, fmt.Errorf("auth request not found")
170 }
171 return nil, err
172 }
173 pt, err := decryptBlob(s.key, blob)
174 if err != nil {
175 return nil, err
176 }
177 var out oauth.AuthRequestData
178 if err := json.Unmarshal(pt, &out); err != nil {
179 return nil, err
180 }
181 return &out, nil
182}
183
184func (s *sqliteStore) DeleteAuthRequestInfo(ctx context.Context, state string) error {
185 _, err := s.db.ExecContext(ctx, `DELETE FROM auth_requests WHERE state = ?`, state)
186 return err
187}