// tuiter 2006
// at main · 187 lines · 5.0 kB

1package main 2 3import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "crypto/sha256" 9 "database/sql" 10 "encoding/base64" 11 "encoding/json" 12 "errors" 13 "fmt" 14 "io" 15 "time" 16 17 _ "modernc.org/sqlite" 18 19 "github.com/bluesky-social/indigo/atproto/auth/oauth" 20 "github.com/bluesky-social/indigo/atproto/syntax" 21) 22 23type sqliteStore struct { 24 db *sql.DB 25 key []byte 26} 27 28func deriveKeyFromEnv(raw string) ([]byte, error) { 29 if raw == "" { 30 return nil, errors.New("SESSION_DB_KEY is required") 31 } 32 if decoded, err := base64.StdEncoding.DecodeString(raw); err == nil && len(decoded) >= 16 { 33 if len(decoded) == 32 { 34 return decoded, nil 35 } 36 h := sha256.Sum256(decoded) 37 return h[:], nil 38 } 39 h := sha256.Sum256([]byte(raw)) 40 return h[:], nil 41} 42 43func encryptBlob(key, plaintext []byte) ([]byte, error) { 44 block, err := aes.NewCipher(key) 45 if err != nil { 46 return nil, err 47 } 48 gcm, err := cipher.NewGCM(block) 49 if err != nil { 50 return nil, err 51 } 52 nonce := make([]byte, gcm.NonceSize()) 53 if _, err := io.ReadFull(rand.Reader, nonce); err != nil { 54 return nil, err 55 } 56 ct := gcm.Seal(nil, nonce, plaintext, nil) 57 out := make([]byte, 0, len(nonce)+len(ct)) 58 out = append(out, nonce...) 59 out = append(out, ct...) 
60 return out, nil 61} 62 63func decryptBlob(key, blob []byte) ([]byte, error) { 64 block, err := aes.NewCipher(key) 65 if err != nil { 66 return nil, err 67 } 68 gcm, err := cipher.NewGCM(block) 69 if err != nil { 70 return nil, err 71 } 72 n := gcm.NonceSize() 73 if len(blob) < n { 74 return nil, errors.New("ciphertext too short") 75 } 76 nonce := blob[:n] 77 ct := blob[n:] 78 pt, err := gcm.Open(nil, nonce, ct, nil) 79 if err != nil { 80 return nil, err 81 } 82 return pt, nil 83} 84 85func NewSQLiteStore(dbPath string, key []byte) (*sqliteStore, error) { 86 db, err := sql.Open("sqlite", dbPath) 87 if err != nil { 88 return nil, err 89 } 90 if _, err := db.Exec("PRAGMA busy_timeout = 5000"); err != nil { 91 _ = db.Close() 92 return nil, err 93 } 94 schema := []string{`CREATE TABLE IF NOT EXISTS sessions( 95 session_id TEXT PRIMARY KEY, 96 did TEXT, 97 data BLOB, 98 updated_at INTEGER 99 );`, `CREATE TABLE IF NOT EXISTS auth_requests( 100 state TEXT PRIMARY KEY, 101 data BLOB, 102 updated_at INTEGER 103 );`} 104 for _, s := range schema { 105 if _, err := db.Exec(s); err != nil { 106 _ = db.Close() 107 return nil, err 108 } 109 } 110 return &sqliteStore{db: db, key: key}, nil 111} 112 113func (s *sqliteStore) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { 114 data, err := json.Marshal(sess) 115 if err != nil { 116 return err 117 } 118 enc, err := encryptBlob(s.key, data) 119 if err != nil { 120 return err 121 } 122 _, err = s.db.ExecContext(ctx, `INSERT INTO sessions(session_id, did, data, updated_at) VALUES (?, ?, ?, ?) ON CONFLICT(session_id) DO UPDATE SET did=excluded.did, data=excluded.data, updated_at=excluded.updated_at`, sess.SessionID, sess.AccountDID.String(), enc, time.Now().Unix()) 123 return err 124} 125 126func (s *sqliteStore) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { 127 row := s.db.QueryRowContext(ctx, `SELECT data FROM sessions WHERE session_id = ? 
AND did = ?`, sessionID, did.String()) 128 var blob []byte 129 if err := row.Scan(&blob); err != nil { 130 if errors.Is(err, sql.ErrNoRows) { 131 return nil, fmt.Errorf("session not found") 132 } 133 return nil, err 134 } 135 pt, err := decryptBlob(s.key, blob) 136 if err != nil { 137 return nil, err 138 } 139 var out oauth.ClientSessionData 140 if err := json.Unmarshal(pt, &out); err != nil { 141 return nil, err 142 } 143 return &out, nil 144} 145 146func (s *sqliteStore) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { 147 _, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE session_id = ? AND did = ?`, sessionID, did.String()) 148 return err 149} 150 151func (s *sqliteStore) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { 152 data, err := json.Marshal(info) 153 if err != nil { 154 return err 155 } 156 enc, err := encryptBlob(s.key, data) 157 if err != nil { 158 return err 159 } 160 _, err = s.db.ExecContext(ctx, `INSERT INTO auth_requests(state, data, updated_at) VALUES (?, ?, ?) 
ON CONFLICT(state) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at`, info.State, enc, time.Now().Unix()) 161 return err 162} 163 164func (s *sqliteStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { 165 row := s.db.QueryRowContext(ctx, `SELECT data FROM auth_requests WHERE state = ?`, state) 166 var blob []byte 167 if err := row.Scan(&blob); err != nil { 168 if errors.Is(err, sql.ErrNoRows) { 169 return nil, fmt.Errorf("auth request not found") 170 } 171 return nil, err 172 } 173 pt, err := decryptBlob(s.key, blob) 174 if err != nil { 175 return nil, err 176 } 177 var out oauth.AuthRequestData 178 if err := json.Unmarshal(pt, &out); err != nil { 179 return nil, err 180 } 181 return &out, nil 182} 183 184func (s *sqliteStore) DeleteAuthRequestInfo(ctx context.Context, state string) error { 185 _, err := s.db.ExecContext(ctx, `DELETE FROM auth_requests WHERE state = ?`, state) 186 return err 187}