A fork of https://github.com/teal-fm/piper
1package db
2
3import (
4 "database/sql"
5 "encoding/json"
6 "fmt"
7 "log"
8 "os"
9 "path/filepath"
10 "time"
11
12 _ "github.com/mattn/go-sqlite3"
13 "github.com/teal-fm/piper/models"
14)
15
// DB wraps the application's SQLite connection pool together with a logger.
// The embedded *sql.DB exposes the usual Exec/Query/QueryRow methods; logger
// is created by New and writes to stdout with a "db: " prefix.
type DB struct {
	*sql.DB
	logger *log.Logger
}

// New opens the SQLite database at dbPath and verifies the connection.
//
// If dbPath has a parent directory other than "." or "/", that directory
// (and any missing ancestors) is created first with mode 0755, subject to
// the process umask. sql.Open does not connect by itself, so the connection
// is checked with Ping. If Ping fails, the pool is closed before returning
// so the handle is not leaked.
//
// It returns a ready *DB, or a nil *DB and a wrapped error if the directory
// cannot be created or the database cannot be opened or reached.
func New(dbPath string) (*DB, error) {
	dir := filepath.Dir(dbPath)
	if dir != "." && dir != "/" {
		// Octal literal: a decimal 755 would mean mode 0o1363, leaving the
		// owner unable to list the directory.
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return nil, fmt.Errorf("creating database directory %q: %w", dir, err)
		}
	}

	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return nil, fmt.Errorf("opening database %q: %w", dbPath, err)
	}

	if err := db.Ping(); err != nil {
		// Best effort: the Ping failure is the error worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("pinging database %q: %w", dbPath, err)
	}

	logger := log.New(os.Stdout, "db: ", log.LstdFlags|log.Lmsgprefix)
	return &DB{db, logger}, nil
}
40
41func (db *DB) Initialize() error {
42 _, err := db.Exec(`
43 CREATE TABLE IF NOT EXISTS users (
44 id INTEGER PRIMARY KEY AUTOINCREMENT,
45 username TEXT, -- Made nullable, might not have username initially
46 email TEXT UNIQUE, -- Made nullable
47 atproto_did TEXT UNIQUE, -- Atproto DID (identifier)
48 most_recent_at_session_id TEXT, -- Most recent oAuth session id
49 spotify_id TEXT UNIQUE, -- Spotify specific ID
50 access_token TEXT, -- Spotify access token
51 refresh_token TEXT, -- Spotify refresh token
52 token_expiry TIMESTAMP, -- Spotify token expiry
53 lastfm_username TEXT, -- Last.fm username
54 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Use default
55 updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -- Use default
56 )`)
57 if err != nil {
58 return err
59 }
60
61 _, err = db.Exec(`
62 CREATE TABLE IF NOT EXISTS tracks (
63 id INTEGER PRIMARY KEY AUTOINCREMENT,
64 user_id INTEGER NOT NULL,
65 name TEXT NOT NULL,
66 recording_mbid TEXT, -- Added
67 artist TEXT NOT NULL, -- should be JSONB in PostgreSQL if we ever switch
68 album TEXT NOT NULL,
69 release_mbid TEXT, -- Added
70 url TEXT NOT NULL,
71 timestamp TIMESTAMP,
72 duration_ms INTEGER,
73 progress_ms INTEGER,
74 service_base_url TEXT,
75 isrc TEXT,
76 has_stamped BOOLEAN,
77 FOREIGN KEY (user_id) REFERENCES users(id)
78 )`)
79 if err != nil {
80 return err
81 }
82
83 _, err = db.Exec(`
84 CREATE TABLE IF NOT EXISTS atproto_state (
85 id INTEGER PRIMARY KEY AUTOINCREMENT,
86 state TEXT NOT NULL,
87 authserver_url TEXT,
88 account_did TEXT,
89 scopes TEXT,
90 request_uri TEXT,
91 authserver_token_endpoint TEXT,
92 authserver_revocation_endpoint TEXT,
93 pkce_verifier TEXT,
94 dpop_authserver_nonce TEXT,
95 dpop_privatekey_multibase TEXT,
96 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
97 );
98 CREATE INDEX IF NOT EXISTS atproto_state_state ON atproto_state(state);
99
100`)
101 if err != nil {
102 return err
103 }
104
105 _, err = db.Exec(`
106 CREATE TABLE IF NOT EXISTS atproto_sessions (
107 id INTEGER PRIMARY KEY AUTOINCREMENT,
108 look_up_key TEXT NOT NULL,
109 account_did TEXT,
110 session_id TEXT,
111 host_url TEXT,
112 authserver_url TEXT,
113 authserver_token_endpoint TEXT,
114 authserver_revocation_endpoint TEXT,
115 scopes TEXT,
116 access_token TEXT,
117 refresh_token TEXT,
118 dpop_authserver_nonce TEXT,
119 dpop_host_nonce TEXT,
120 dpop_privatekey_multibase TEXT,
121 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
122 );
123 CREATE INDEX IF NOT EXISTS idx_atproto_sessions_look_up_key ON atproto_sessions(look_up_key);
124`)
125 if err != nil {
126 return err
127 }
128
129 // Add columns recording_mbid and release_mbid to tracks table if they don't exist
130 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN recording_mbid TEXT`)
131 if err != nil && err.Error() != "duplicate column name: recording_mbid" {
132 // Handle errors other than 'duplicate column'
133 return err
134 }
135 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN release_mbid TEXT`)
136 if err != nil && err.Error() != "duplicate column name: release_mbid" {
137 // Handle errors other than 'duplicate column'
138 return err
139 }
140
141 return nil
142}
143
144// create user without spotify id
145func (db *DB) CreateUser(user *models.User) (int64, error) {
146 now := time.Now().UTC()
147
148 result, err := db.Exec(`
149 INSERT INTO users (username, email, created_at, updated_at)
150 VALUES (?, ?, ?, ?)`,
151 user.Username, user.Email, now, now)
152
153 if err != nil {
154 return 0, err
155 }
156
157 return result.LastInsertId()
158}
159
160// add spotify session to user, returning the updated user
161func (db *DB) AddSpotifySession(userID int64, username, email, spotifyId, accessToken, refreshToken string, tokenExpiry time.Time) (*models.User, error) {
162 now := time.Now().UTC()
163
164 _, err := db.Exec(`
165 UPDATE users SET username = ?, email = ?, spotify_id = ?, access_token = ?, refresh_token = ?, token_expiry = ?, created_at = ?, updated_at = ?
166 WHERE id == ?
167 `,
168 username, email, spotifyId, accessToken, refreshToken, tokenExpiry, now, now, userID)
169 if err != nil {
170 return nil, err
171 }
172
173 user, err := db.GetUserByID(userID)
174 if err != nil {
175 return nil, err
176 }
177
178 return user, err
179}
180
181func (db *DB) GetUserByID(ID int64) (*models.User, error) {
182 user := &models.User{}
183
184 err := db.QueryRow(`
185 SELECT id,
186 username,
187 email,
188 atproto_did,
189 most_recent_at_session_id,
190 spotify_id,
191 access_token,
192 refresh_token,
193 token_expiry,
194 lastfm_username,
195 created_at,
196 updated_at
197 FROM users WHERE id = ?`, ID).Scan(
198 &user.ID, &user.Username, &user.Email, &user.ATProtoDID, &user.MostRecentAtProtoSessionID, &user.SpotifyID,
199 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry,
200 &user.LastFMUsername,
201 &user.CreatedAt, &user.UpdatedAt)
202
203 if err == sql.ErrNoRows {
204 return nil, nil
205 }
206
207 if err != nil {
208 return nil, err
209 }
210
211 return user, nil
212}
213
214func (db *DB) GetUserBySpotifyID(spotifyID string) (*models.User, error) {
215 user := &models.User{}
216
217 err := db.QueryRow(`
218 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, lastfm_username, created_at, updated_at
219 FROM users WHERE spotify_id = ?`, spotifyID).Scan(
220 &user.ID, &user.Username, &user.Email, &user.SpotifyID,
221 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry,
222 &user.LastFMUsername,
223 &user.CreatedAt, &user.UpdatedAt)
224
225 if err == sql.ErrNoRows {
226 return nil, nil
227 }
228
229 if err != nil {
230 return nil, err
231 }
232
233 return user, nil
234}
235
236func (db *DB) UpdateUserToken(userID int64, accessToken, refreshToken string, expiry time.Time) error {
237 now := time.Now().UTC()
238
239 _, err := db.Exec(`
240 UPDATE users
241 SET access_token = ?, refresh_token = ?, token_expiry = ?, updated_at = ?
242 WHERE id = ?`,
243 accessToken, refreshToken, expiry, now, userID)
244
245 return err
246}
247
248func (db *DB) SaveTrack(userID int64, track *models.Track) (int64, error) {
249 // marshal artist json
250 artistString := ""
251 if len(track.Artist) > 0 {
252 bytes, err := json.Marshal(track.Artist)
253 if err != nil {
254 return 0, err
255 } else {
256 artistString = string(bytes)
257 }
258 }
259
260 var trackID int64
261
262 err := db.QueryRow(`
263 INSERT INTO tracks (user_id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped)
264 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
265 RETURNING id`,
266 userID, track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp,
267 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped).Scan(&trackID)
268
269 return trackID, err
270}
271
272func (db *DB) UpdateTrack(trackID int64, track *models.Track) error {
273 // marshal artist json
274 artistString := ""
275 if len(track.Artist) > 0 {
276 bytes, err := json.Marshal(track.Artist)
277 if err != nil {
278 return err
279 } else {
280 artistString = string(bytes)
281 }
282 }
283
284 _, err := db.Exec(`
285 UPDATE tracks
286 SET name = ?,
287 recording_mbid = ?,
288 artist = ?,
289 album = ?,
290 release_mbid = ?,
291 url = ?,
292 timestamp = ?,
293 duration_ms = ?,
294 progress_ms = ?,
295 service_base_url = ?,
296 isrc = ?,
297 has_stamped = ?
298 WHERE id = ?`,
299 track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp,
300 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped,
301 trackID)
302
303 return err
304}
305
306func (db *DB) GetRecentTracks(userID int64, limit int) ([]*models.Track, error) {
307 rows, err := db.Query(`
308 SELECT id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped
309 FROM tracks
310 WHERE user_id = ?
311 ORDER BY timestamp DESC
312 LIMIT ?`, userID, limit)
313
314 if err != nil {
315 return nil, err
316 }
317 defer rows.Close()
318
319 var tracks []*models.Track
320
321 for rows.Next() {
322 var artistString string
323 track := &models.Track{}
324 err := rows.Scan(
325 &track.PlayID,
326 &track.Name,
327 &track.RecordingMBID, // Scan new field
328 &artistString, // scan to be unmarshaled later
329 &track.Album,
330 &track.ReleaseMBID, // Scan new field
331 &track.URL,
332 &track.Timestamp,
333 &track.DurationMs,
334 &track.ProgressMs,
335 &track.ServiceBaseUrl,
336 &track.ISRC,
337 &track.HasStamped,
338 )
339
340 if err != nil {
341 return nil, err
342 }
343
344 // unmarshal artist json
345 var artists []models.Artist
346 err = json.Unmarshal([]byte(artistString), &artists)
347 if err != nil {
348 // fallback to previous format
349 artists = []models.Artist{{Name: artistString}}
350 }
351 track.Artist = artists
352 tracks = append(tracks, track)
353 }
354
355 return tracks, nil
356}
357
358// SpotifyQueryMapping maps Spotify sql query results to user structs
359func SpotifyQueryMapping(rows *sql.Rows) ([]*models.User, error) {
360
361 var users []*models.User
362
363 for rows.Next() {
364 user := &models.User{}
365 err := rows.Scan(
366 &user.ID, &user.Username, &user.Email, &user.SpotifyID,
367 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry,
368 &user.CreatedAt, &user.UpdatedAt)
369 if err != nil {
370 return nil, err
371 }
372 users = append(users, user)
373 }
374
375 return users, nil
376}
377
378func (db *DB) GetUsersWithExpiredTokens() ([]*models.User, error) {
379 rows, err := db.Query(`
380 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at
381 FROM users
382 WHERE refresh_token IS NOT NULL AND token_expiry < ?
383 ORDER BY id`, time.Now().UTC())
384
385 if err != nil {
386 return nil, err
387 }
388 defer rows.Close()
389
390 return SpotifyQueryMapping(rows)
391
392}
393
394func (db *DB) GetAllActiveUsers() ([]*models.User, error) {
395 rows, err := db.Query(`
396 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at
397 FROM users
398 WHERE access_token IS NOT NULL
399 ORDER BY id`)
400
401 if err != nil {
402 return nil, err
403 }
404 defer rows.Close()
405
406 return SpotifyQueryMapping(rows)
407}
408
409func (db *DB) GetAllActiveUsersWithUnExpiredTokens() ([]*models.User, error) {
410 rows, err := db.Query(`
411 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at
412 FROM users
413 WHERE access_token IS NOT NULL AND token_expiry > ?
414 ORDER BY id`, time.Now().UTC())
415
416 if err != nil {
417 return nil, err
418 }
419 defer rows.Close()
420
421 return SpotifyQueryMapping(rows)
422}
423
424// debug to view current user's information
425// put everything in an 'any' type
426func (db *DB) DebugViewUserInformation(userID int64) (map[string]any, error) {
427 // Use Query instead of QueryRow to get access to column names and ensure only one row is processed
428 rows, err := db.Query(`
429 SELECT *
430 FROM users
431 WHERE id = ? LIMIT 1`, userID)
432 if err != nil {
433 return nil, fmt.Errorf("query failed: %w", err)
434 }
435 defer rows.Close()
436
437 // Get column names
438 cols, err := rows.Columns()
439 if err != nil {
440 return nil, fmt.Errorf("failed to get columns: %w", err)
441 }
442
443 // Check if there's a row to process
444 if !rows.Next() {
445 if err := rows.Err(); err != nil {
446 // Error during rows.Next() or preparing the result set
447 return nil, fmt.Errorf("error checking for row: %w", err)
448 }
449 // No rows found, which is a valid outcome but might be considered an error in some contexts.
450 // Returning sql.ErrNoRows is conventional.
451 return nil, sql.ErrNoRows
452 }
453
454 // Prepare scan arguments: pointers to interface{} slices
455 values := make([]any, len(cols))
456 scanArgs := make([]any, len(cols))
457 for i := range values {
458 scanArgs[i] = &values[i]
459 }
460
461 // Scan the row values
462 err = rows.Scan(scanArgs...)
463 if err != nil {
464 return nil, fmt.Errorf("failed to scan row: %w", err)
465 }
466
467 // Check for errors that might have occurred during iteration (after Scan)
468 if err := rows.Err(); err != nil {
469 return nil, fmt.Errorf("error after scanning row: %w", err)
470 }
471
472 // Create the result map
473 resultMap := make(map[string]any, len(cols))
474 for i, colName := range cols {
475 val := values[i]
476 // SQLite often returns []byte for TEXT columns, convert to string for usability.
477 // Also handle potential nil values appropriately.
478 if b, ok := val.([]byte); ok {
479 resultMap[colName] = string(b)
480 } else {
481 resultMap[colName] = val // Keep nil as nil, numbers as numbers, etc.
482 }
483 }
484
485 return resultMap, nil
486}
487
488func (db *DB) GetLastKnownTimestamp(userID int64) (*time.Time, error) {
489 var lastTimestamp time.Time
490 err := db.QueryRow(`
491 SELECT timestamp
492 FROM tracks
493 WHERE user_id = ?
494 ORDER BY timestamp DESC
495 LIMIT 1`, userID).Scan(&lastTimestamp)
496
497 if err != nil {
498 if err == sql.ErrNoRows {
499 return nil, nil
500 }
501 return nil, fmt.Errorf("failed to query last scrobble timestamp for user %d: %w", userID, err)
502 }
503
504 return &lastTimestamp, nil
505}
506
507//