A fork of https://github.com/teal-fm/piper

clean up comments + code

Natalie B 5399a52d 130c8755

+91 -165
+1 -9
config/config.go
··· 10 11 // Load initializes the configuration with viper 12 func Load() { 13 - // Load .env file if it exists 14 if err := godotenv.Load(); err != nil { 15 log.Println("No .env file found or error loading it. Using default values and environment variables.") 16 } 17 18 - // Set default configurations 19 viper.SetDefault("server.port", "8080") 20 viper.SetDefault("server.host", "localhost") 21 viper.SetDefault("callback.spotify", "http://localhost:8080/callback/spotify") ··· 30 viper.SetDefault("atproto.metadata_url", "http://localhost:8080/metadata") 31 viper.SetDefault("atproto.callback_url", "/metadata") 32 33 - // Configure Viper to read environment variables 34 viper.AutomaticEnv() 35 36 - // Replace dots with underscores for environment variables 37 viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 38 39 - // Set the config name and paths 40 viper.SetConfigName("config") 41 viper.SetConfigType("yaml") 42 viper.AddConfigPath("./config") 43 viper.AddConfigPath(".") 44 45 - // Try to read the config file 46 if err := viper.ReadInConfig(); err != nil { 47 if _, ok := err.(viper.ConfigFileNotFoundError); !ok { 48 - // It's not a "file not found" error, so it's a real error 49 log.Fatalf("Error reading config file: %v", err) 50 } 51 - // Config file not found, using defaults and environment variables 52 log.Println("Config file not found, using default values and environment variables") 53 } else { 54 log.Println("Using config file:", viper.ConfigFileUsed()) 55 } 56 57 - // Check if required values are present 58 requiredVars := []string{"spotify.client_id", "spotify.client_secret"} 59 missingVars := []string{} 60
··· 10 11 // Load initializes the configuration with viper 12 func Load() { 13 if err := godotenv.Load(); err != nil { 14 log.Println("No .env file found or error loading it. Using default values and environment variables.") 15 } 16 17 viper.SetDefault("server.port", "8080") 18 viper.SetDefault("server.host", "localhost") 19 viper.SetDefault("callback.spotify", "http://localhost:8080/callback/spotify") ··· 28 viper.SetDefault("atproto.metadata_url", "http://localhost:8080/metadata") 29 viper.SetDefault("atproto.callback_url", "/metadata") 30 31 viper.AutomaticEnv() 32 33 viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 34 35 viper.SetConfigName("config") 36 viper.SetConfigType("yaml") 37 viper.AddConfigPath("./config") 38 viper.AddConfigPath(".") 39 40 if err := viper.ReadInConfig(); err != nil { 41 if _, ok := err.(viper.ConfigFileNotFoundError); !ok { 42 log.Fatalf("Error reading config file: %v", err) 43 } 44 log.Println("Config file not found, using default values and environment variables") 45 } else { 46 log.Println("Using config file:", viper.ConfigFileUsed()) 47 } 48 49 + // check for required settings 50 requiredVars := []string{"spotify.client_id", "spotify.client_secret"} 51 missingVars := []string{} 52
+6 -10
db/atproto.go
··· 39 40 func (db *DB) GetATprotoAuthData(state string) (*models.ATprotoAuthData, error) { 41 var data models.ATprotoAuthData 42 - var dpopPrivateJWKString string // Temporary variable to hold the JSON string 43 44 err := db.QueryRow(` 45 SELECT state, did, pds_url, authserver_issuer, pkce_verifier, dpop_authserver_nonce, dpop_private_jwk ··· 52 &data.AuthServerIssuer, 53 &data.PKCEVerifier, 54 &data.DPoPAuthServerNonce, 55 - &dpopPrivateJWKString, // Scan into the temporary string 56 ) 57 if err != nil { 58 - // Return the original scan error if it occurred 59 if err == sql.ErrNoRows { 60 return nil, fmt.Errorf("no auth data found for state %s: %w", state, err) 61 } ··· 64 65 key, err := helpers.ParseJWKFromBytes([]byte(dpopPrivateJWKString)) 66 if err != nil { 67 - // Return an error if parsing fails 68 return nil, fmt.Errorf("failed to parse DPoPPrivateJWK for state %s: %w", state, err) 69 } 70 data.DPoPPrivateJWK = key 71 72 - return &data, nil // Return nil error on success 73 } 74 75 func (db *DB) FindOrCreateUserByDID(did string) (*models.User, error) { ··· 97 if idErr != nil { 98 return nil, fmt.Errorf("failed to get last insert id: %w", idErr) 99 } 100 - // Populate the user struct with the newly created user's data 101 user.ID = lastID 102 user.ATProtoDID = &did 103 user.CreatedAt = now 104 user.UpdatedAt = now 105 - return &user, nil // Return the created user and nil error 106 } else if err != nil { 107 - // Handle other potential errors from QueryRow 108 return nil, fmt.Errorf("failed to find user by DID: %w", err) 109 } 110 111 return &user, err 112 } 113 114 - // Create or update the current user's ATproto session data. 115 func (db *DB) SaveATprotoSession(tokenResp *oauth.TokenResponse) error { 116 117 expiryTime := time.Now().Add(time.Second * time.Duration(tokenResp.ExpiresIn)) ··· 141 142 rowsAffected, err := result.RowsAffected() 143 if err != nil { 144 - // Error checking RowsAffected, but the update might have succeeded 145 return fmt.Errorf("failed to check rows affected after updating atproto session for did %s: %w", tokenResp.Sub, err) 146 } 147
··· 39 40 func (db *DB) GetATprotoAuthData(state string) (*models.ATprotoAuthData, error) { 41 var data models.ATprotoAuthData 42 + var dpopPrivateJWKString string 43 44 err := db.QueryRow(` 45 SELECT state, did, pds_url, authserver_issuer, pkce_verifier, dpop_authserver_nonce, dpop_private_jwk ··· 52 &data.AuthServerIssuer, 53 &data.PKCEVerifier, 54 &data.DPoPAuthServerNonce, 55 + &dpopPrivateJWKString, 56 ) 57 if err != nil { 58 if err == sql.ErrNoRows { 59 return nil, fmt.Errorf("no auth data found for state %s: %w", state, err) 60 } ··· 63 64 key, err := helpers.ParseJWKFromBytes([]byte(dpopPrivateJWKString)) 65 if err != nil { 66 return nil, fmt.Errorf("failed to parse DPoPPrivateJWK for state %s: %w", state, err) 67 } 68 data.DPoPPrivateJWK = key 69 70 + return &data, nil 71 } 72 73 func (db *DB) FindOrCreateUserByDID(did string) (*models.User, error) { ··· 95 if idErr != nil { 96 return nil, fmt.Errorf("failed to get last insert id: %w", idErr) 97 } 98 user.ID = lastID 99 user.ATProtoDID = &did 100 user.CreatedAt = now 101 user.UpdatedAt = now 102 + return &user, nil 103 } else if err != nil { 104 return nil, fmt.Errorf("failed to find user by DID: %w", err) 105 } 106 107 return &user, err 108 } 109 110 + // create or update the current user's ATproto session data. 111 func (db *DB) SaveATprotoSession(tokenResp *oauth.TokenResponse) error { 112 113 expiryTime := time.Now().Add(time.Second * time.Duration(tokenResp.ExpiresIn)) ··· 137 138 rowsAffected, err := result.RowsAffected() 139 if err != nil { 140 + // it's possible the update succeeded here? 141 return fmt.Errorf("failed to check rows affected after updating atproto session for did %s: %w", tokenResp.Sub, err) 142 } 143
+5 -11
db/db.go
··· 11 "github.com/teal-fm/piper/models" 12 ) 13 14 - // DB is a wrapper around sql.DB 15 type DB struct { 16 *sql.DB 17 } ··· 36 } 37 38 func (db *DB) Initialize() error { 39 - // Create users table 40 _, err := db.Exec(` 41 CREATE TABLE IF NOT EXISTS users ( 42 id INTEGER PRIMARY KEY AUTOINCREMENT, ··· 59 return err 60 } 61 62 - // Create tracks table 63 _, err = db.Exec(` 64 CREATE TABLE IF NOT EXISTS tracks ( 65 id INTEGER PRIMARY KEY AUTOINCREMENT, ··· 115 return result.LastInsertId() 116 } 117 118 - // Add spotify session to user, returning the updated user 119 func (db *DB) AddSpotifySession(userID int64, username, email, spotifyId, accessToken, refreshToken string, tokenExpiry time.Time) (*models.User, error) { 120 now := time.Now() 121 ··· 191 } 192 193 func (db *DB) SaveTrack(userID int64, track *models.Track) (int64, error) { 194 - // Convert the Artist array to a string for storage 195 artistString := "" 196 if len(track.Artist) > 0 { 197 bytes, err := json.Marshal(track.Artist) ··· 215 } 216 217 func (db *DB) UpdateTrack(trackID int64, track *models.Track) error { 218 - // Convert the Artist array to a string for storage 219 - // In a production environment, you'd want to use proper JSON serialization 220 artistString := "" 221 if len(track.Artist) > 0 { 222 bytes, err := json.Marshal(track.Artist) ··· 248 } 249 250 func (db *DB) GetRecentTracks(userID int64, limit int) ([]*models.Track, error) { 251 - // convert previous-format artist strings to current-format 252 - 253 rows, err := db.Query(` 254 SELECT id, name, artist, album, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped 255 FROM tracks ··· 270 err := rows.Scan( 271 &track.PlayID, 272 &track.Name, 273 - &artistString, // Scan into a string first 274 &track.Album, 275 &track.URL, 276 &track.Timestamp, ··· 285 return nil, err 286 } 287 288 - // Convert the artist string to the Artist array structure 289 var artists []models.Artist 290 err = json.Unmarshal([]byte(artistString), &artists) 291 if err != nil {
··· 11 "github.com/teal-fm/piper/models" 12 ) 13 14 type DB struct { 15 *sql.DB 16 } ··· 35 } 36 37 func (db *DB) Initialize() error { 38 _, err := db.Exec(` 39 CREATE TABLE IF NOT EXISTS users ( 40 id INTEGER PRIMARY KEY AUTOINCREMENT, ··· 57 return err 58 } 59 60 _, err = db.Exec(` 61 CREATE TABLE IF NOT EXISTS tracks ( 62 id INTEGER PRIMARY KEY AUTOINCREMENT, ··· 112 return result.LastInsertId() 113 } 114 115 + // add spotify session to user, returning the updated user 116 func (db *DB) AddSpotifySession(userID int64, username, email, spotifyId, accessToken, refreshToken string, tokenExpiry time.Time) (*models.User, error) { 117 now := time.Now() 118 ··· 188 } 189 190 func (db *DB) SaveTrack(userID int64, track *models.Track) (int64, error) { 191 + // marshal artist json 192 artistString := "" 193 if len(track.Artist) > 0 { 194 bytes, err := json.Marshal(track.Artist) ··· 212 } 213 214 func (db *DB) UpdateTrack(trackID int64, track *models.Track) error { 215 + // marshal artist json 216 artistString := "" 217 if len(track.Artist) > 0 { 218 bytes, err := json.Marshal(track.Artist) ··· 244 } 245 246 func (db *DB) GetRecentTracks(userID int64, limit int) ([]*models.Track, error) { 247 rows, err := db.Query(` 248 SELECT id, name, artist, album, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped 249 FROM tracks ··· 264 err := rows.Scan( 265 &track.PlayID, 266 &track.Name, 267 + &artistString, // scan to be unmarshaled later 268 &track.Album, 269 &track.URL, 270 &track.Timestamp, ··· 279 return nil, err 280 } 281 282 + // unmarshal artist json 283 var artists []models.Artist 284 err = json.Unmarshal([]byte(artistString), &artists) 285 if err != nil {
+3 -7
main.go
··· 21 func home(w http.ResponseWriter, r *http.Request) { 22 w.Header().Set("Content-Type", "text/html") 23 24 - // Check if user has an active session cookie 25 cookie, err := r.Cookie("session") 26 isLoggedIn := err == nil && cookie != nil 27 - // TODO: Add logic here to fetch user details from DB using session ID 28 - // to check if Spotify is already connected, if desired for finer control. 29 - // For now, we'll just check if *any* session exists. 30 31 html := ` 32 <html> ··· 106 107 // JSON API handlers 108 109 - // jsonResponse returns a JSON response 110 func jsonResponse(w http.ResponseWriter, statusCode int, data any) { 111 w.Header().Set("Content-Type", "application/json") 112 w.WriteHeader(statusCode) ··· 115 } 116 } 117 118 - // API endpoint for current track 119 func apiCurrentTrack(spotifyService *spotify.SpotifyService) http.HandlerFunc { 120 return func(w http.ResponseWriter, r *http.Request) { 121 userID, ok := session.GetUserID(r.Context()) ··· 134 } 135 } 136 137 - // API endpoint for history 138 func apiTrackHistory(spotifyService *spotify.SpotifyService) http.HandlerFunc { 139 return func(w http.ResponseWriter, r *http.Request) { 140 userID, ok := session.GetUserID(r.Context())
··· 21 func home(w http.ResponseWriter, r *http.Request) { 22 w.Header().Set("Content-Type", "text/html") 23 24 + // check if user has an active session cookie 25 cookie, err := r.Cookie("session") 26 isLoggedIn := err == nil && cookie != nil 27 + // TODO: add logic here to fetch user details from DB using session ID 28 + // to check if Spotify is already connected 29 30 html := ` 31 <html> ··· 105 106 // JSON API handlers 107 108 func jsonResponse(w http.ResponseWriter, statusCode int, data any) { 109 w.Header().Set("Content-Type", "application/json") 110 w.WriteHeader(statusCode) ··· 113 } 114 } 115 116 func apiCurrentTrack(spotifyService *spotify.SpotifyService) http.HandlerFunc { 117 return func(w http.ResponseWriter, r *http.Request) { 118 userID, ok := session.GetUserID(r.Context()) ··· 131 } 132 } 133 134 func apiTrackHistory(spotifyService *spotify.SpotifyService) http.HandlerFunc { 135 return func(w http.ResponseWriter, r *http.Request) { 136 userID, ok := session.GetUserID(r.Context())
-1
models/atproto.go
··· 1 - // Add this struct definition to piper/models/atproto.go 2 package models 3 4 import (
··· 1 package models 2 3 import (
+19 -14
models/user.go
··· 2 3 import "time" 4 5 - // User represents a user of the application 6 type User struct { 7 - ID int64 8 - Username string 9 - Email *string // Use pointer for nullable fields 10 - SpotifyID *string // Use pointer for nullable fields 11 - AccessToken *string // Spotify Access Token 12 - RefreshToken *string // Spotify Refresh Token 13 - TokenExpiry *time.Time // Spotify Token Expiry 14 - CreatedAt time.Time 15 - UpdatedAt time.Time 16 - ATProtoDID *string // ATProto DID 17 - ATProtoAccessToken *string // ATProto Access Token 18 - ATProtoRefreshToken *string // ATProto Refresh Token 19 - ATProtoTokenExpiry *time.Time // ATProto Token Expiry 20 }
··· 2 3 import "time" 4 5 + // an end user of piper 6 type User struct { 7 + ID int64 8 + Username string 9 + Email *string 10 + 11 + // spotify information 12 + SpotifyID *string 13 + AccessToken *string 14 + RefreshToken *string 15 + TokenExpiry *time.Time 16 + 17 + // atp info 18 + ATProtoDID *string 19 + ATProtoAccessToken *string 20 + ATProtoRefreshToken *string 21 + ATProtoTokenExpiry *time.Time 22 + 23 + CreatedAt time.Time 24 + UpdatedAt time.Time 25 }
+2 -3
oauth/atproto/atproto.go
··· 1 - // Modify piper/oauth/atproto/atproto.go 2 package atproto 3 4 import ( ··· 88 return nil, fmt.Errorf("failed PAR request to %s: %w", ui.AuthServer, err) 89 } 90 91 - // Save state including generated PKCE verifier and DPoP key 92 data := &models.ATprotoAuthData{ 93 State: parResp.State, 94 DID: ui.DID, ··· 171 } 172 173 log.Printf("ATProto Callback Success: User %d (DID: %s) authenticated.", userID.ID, data.DID) 174 - return userID.ID, nil // Return the piper user ID 175 }
··· 1 package atproto 2 3 import ( ··· 87 return nil, fmt.Errorf("failed PAR request to %s: %w", ui.AuthServer, err) 88 } 89 90 + // Save state 91 data := &models.ATprotoAuthData{ 92 State: parResp.State, 93 DID: ui.DID, ··· 170 } 171 172 log.Printf("ATProto Callback Success: User %d (DID: %s) authenticated.", userID.ID, data.DID) 173 + return userID.ID, nil 174 }
+6 -14
oauth/oauth2.go
··· 1 - // Modify piper/oauth/oauth2.go 2 package oauth 3 4 import ( ··· 22 state string 23 codeVerifier string 24 codeChallenge string 25 - // Added TokenReceiver field to handle user lookup/creation based on token 26 tokenReceiver TokenReceiver 27 } 28 ··· 38 switch strings.ToLower(provider) { 39 case "spotify": 40 endpoint = spotify.Endpoint 41 - // Add other providers like Last.fm here 42 default: 43 - // Placeholder for unconfigured providers 44 log.Printf("Warning: OAuth2 provider '%s' not explicitly configured. Using placeholder endpoints.", provider) 45 endpoint = oauth2.Endpoint{ 46 - AuthURL: "https://example.com/auth", // Replace with actual endpoints if needed 47 TokenURL: "https://example.com/token", 48 } 49 } ··· 62 state: GenerateRandomState(), 63 codeVerifier: codeVerifier, 64 codeChallenge: codeChallenge, 65 - tokenReceiver: tokenReceiver, // Store the token receiver 66 } 67 } 68 69 - // generateCodeVerifier creates a random code verifier for PKCE 70 func GenerateCodeVerifier() string { 71 b := make([]byte, 64) 72 rand.Read(b) 73 return base64.RawURLEncoding.EncodeToString(b) 74 } 75 76 - // generateCodeChallenge creates a code challenge from the code verifier using S256 method 77 func GenerateCodeChallenge(verifier string) string { 78 h := sha256.New() 79 h.Write([]byte(verifier)) 80 return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) 81 } 82 83 - // HandleLogin implements the AuthService interface method. 84 func (o *OAuth2Service) HandleLogin(w http.ResponseWriter, r *http.Request) { 85 opts := []oauth2.AuthCodeOption{ 86 oauth2.SetAuthURLParam("code_challenge", o.codeChallenge), ··· 128 129 userId, hasSession := session.GetUserID(r.Context()) 130 131 - // Use the token receiver to store the token and get the user ID 132 userID, err := o.tokenReceiver.SetAccessToken(token.AccessToken, userId, hasSession) 133 if err != nil { 134 log.Printf("OAuth2 Callback Info: TokenReceiver did not return a valid user ID for token: %s...", token.AccessToken[:min(10, len(token.AccessToken))]) ··· 138 return userID, nil 139 } 140 141 - // GetToken remains unchanged 142 func (o *OAuth2Service) GetToken(code string) (*oauth2.Token, error) { 143 opts := []oauth2.AuthCodeOption{ 144 oauth2.SetAuthURLParam("code_verifier", o.codeVerifier), ··· 146 return o.config.Exchange(context.Background(), code, opts...) 147 } 148 149 - // GetClient remains unchanged 150 func (o *OAuth2Service) GetClient(token *oauth2.Token) *http.Client { 151 return o.config.Client(context.Background(), token) 152 } 153 154 - // RefreshToken remains unchanged 155 func (o *OAuth2Service) RefreshToken(token *oauth2.Token) (*oauth2.Token, error) { 156 source := o.config.TokenSource(context.Background(), token) 157 return oauth2.ReuseTokenSource(token, source).Token() 158 } 159 160 - // Helper function 161 func min(a, b int) int { 162 if a < b { 163 return a
··· 1 package oauth 2 3 import ( ··· 21 state string 22 codeVerifier string 23 codeChallenge string 24 tokenReceiver TokenReceiver 25 } 26 ··· 36 switch strings.ToLower(provider) { 37 case "spotify": 38 endpoint = spotify.Endpoint 39 default: 40 + // placeholder 41 log.Printf("Warning: OAuth2 provider '%s' not explicitly configured. Using placeholder endpoints.", provider) 42 endpoint = oauth2.Endpoint{ 43 + AuthURL: "https://example.com/auth", 44 TokenURL: "https://example.com/token", 45 } 46 } ··· 59 state: GenerateRandomState(), 60 codeVerifier: codeVerifier, 61 codeChallenge: codeChallenge, 62 + tokenReceiver: tokenReceiver, 63 } 64 } 65 66 + // generate a random code verifier, for PKCE 67 func GenerateCodeVerifier() string { 68 b := make([]byte, 64) 69 rand.Read(b) 70 return base64.RawURLEncoding.EncodeToString(b) 71 } 72 73 + // generate a code challenge for verification later 74 func GenerateCodeChallenge(verifier string) string { 75 h := sha256.New() 76 h.Write([]byte(verifier)) 77 return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) 78 } 79 80 func (o *OAuth2Service) HandleLogin(w http.ResponseWriter, r *http.Request) { 81 opts := []oauth2.AuthCodeOption{ 82 oauth2.SetAuthURLParam("code_challenge", o.codeChallenge), ··· 124 125 userId, hasSession := session.GetUserID(r.Context()) 126 127 + // store token and get uid 128 userID, err := o.tokenReceiver.SetAccessToken(token.AccessToken, userId, hasSession) 129 if err != nil { 130 log.Printf("OAuth2 Callback Info: TokenReceiver did not return a valid user ID for token: %s...", token.AccessToken[:min(10, len(token.AccessToken))]) ··· 134 return userID, nil 135 } 136 137 func (o *OAuth2Service) GetToken(code string) (*oauth2.Token, error) { 138 opts := []oauth2.AuthCodeOption{ 139 oauth2.SetAuthURLParam("code_verifier", o.codeVerifier), ··· 141 return o.config.Exchange(context.Background(), code, opts...) 142 } 143 144 func (o *OAuth2Service) GetClient(token *oauth2.Token) *http.Client { 145 return o.config.Client(context.Background(), token) 146 } 147 148 func (o *OAuth2Service) RefreshToken(token *oauth2.Token) (*oauth2.Token, error) { 149 source := o.config.TokenSource(context.Background(), token) 150 return oauth2.ReuseTokenSource(token, source).Token() 151 } 152 153 func min(a, b int) int { 154 if a < b { 155 return a
+9 -14
oauth/oauth_manager.go
··· 12 13 // manages multiple oauth client services 14 type OAuthServiceManager struct { 15 - services map[string]AuthService // Changed from *OAuth2Service to AuthService interface 16 sessionManager *session.SessionManager 17 mu sync.RWMutex 18 } 19 20 func NewOAuthServiceManager() *OAuthServiceManager { 21 return &OAuthServiceManager{ 22 - services: make(map[string]AuthService), // Initialize the new map 23 sessionManager: session.NewSessionManager(), 24 } 25 } 26 27 - // RegisterService registers any service that implements the AuthService interface. 28 func (m *OAuthServiceManager) RegisterService(name string, service AuthService) { 29 m.mu.Lock() 30 defer m.mu.Unlock() ··· 32 log.Printf("Registered auth service: %s", name) 33 } 34 35 - // GetService retrieves a registered AuthService by name. 36 func (m *OAuthServiceManager) GetService(name string) (AuthService, bool) { 37 m.mu.RLock() 38 defer m.mu.RUnlock() ··· 47 m.mu.RUnlock() 48 49 if exists { 50 - service.HandleLogin(w, r) // Call interface method 51 return 52 } 53 ··· 70 return 71 } 72 73 - // Call the service's HandleCallback, which now returns the user ID 74 - userID, err := service.HandleCallback(w, r) // Call interface method 75 76 if err != nil { 77 log.Printf("Error handling callback for service '%s': %v", serviceName, err) ··· 80 } 81 82 if userID > 0 { 83 - // Create session for the user 84 session := m.sessionManager.CreateSession(userID) 85 86 - // Set session cookie 87 m.sessionManager.SetSessionCookie(w, session) 88 89 log.Printf("Created session for user %d via service %s", userID, serviceName) 90 91 - // Redirect to homepage after successful login and session creation 92 http.Redirect(w, r, "/", http.StatusSeeOther) 93 } else { 94 log.Printf("Callback for service '%s' did not result in a valid user ID.", serviceName) 95 - // Optionally redirect to an error page or show an error message 96 - // For now, just redirecting home, but this might hide errors. 97 - // Consider adding error handling based on why userID might be 0. 98 - http.Redirect(w, r, "/", http.StatusSeeOther) // Or redirect to a login/error page 99 } 100 } 101 }
··· 12 13 // manages multiple oauth client services 14 type OAuthServiceManager struct { 15 + services map[string]AuthService 16 sessionManager *session.SessionManager 17 mu sync.RWMutex 18 } 19 20 func NewOAuthServiceManager() *OAuthServiceManager { 21 return &OAuthServiceManager{ 22 + services: make(map[string]AuthService), 23 sessionManager: session.NewSessionManager(), 24 } 25 } 26 27 + // registers any service that impls AuthService 28 func (m *OAuthServiceManager) RegisterService(name string, service AuthService) { 29 m.mu.Lock() 30 defer m.mu.Unlock() ··· 32 log.Printf("Registered auth service: %s", name) 33 } 34 35 + // get an AuthService by registered name 36 func (m *OAuthServiceManager) GetService(name string) (AuthService, bool) { 37 m.mu.RLock() 38 defer m.mu.RUnlock() ··· 47 m.mu.RUnlock() 48 49 if exists { 50 + service.HandleLogin(w, r) 51 return 52 } 53 ··· 70 return 71 } 72 73 + userID, err := service.HandleCallback(w, r) 74 75 if err != nil { 76 log.Printf("Error handling callback for service '%s': %v", serviceName, err) ··· 79 } 80 81 if userID > 0 { 82 session := m.sessionManager.CreateSession(userID) 83 84 m.sessionManager.SetSessionCookie(w, session) 85 86 log.Printf("Created session for user %d via service %s", userID, serviceName) 87 88 http.Redirect(w, r, "/", http.StatusSeeOther) 89 } else { 90 log.Printf("Callback for service '%s' did not result in a valid user ID.", serviceName) 91 + // todo: redirect to an error page 92 + // right now this just redirects home but we don't want this behaviour ideally 93 + http.Redirect(w, r, "/", http.StatusSeeOther) 94 } 95 } 96 }
+6 -10
oauth/service.go
··· 1 - // Create piper/oauth/auth_service.go 2 package oauth 3 4 import ( 5 "net/http" 6 ) 7 8 - // AuthService defines the interface for different authentication services 9 - // that can be managed by the OAuthServiceManager. 10 type AuthService interface { 11 - // HandleLogin initiates the login flow for the specific service. 12 HandleLogin(w http.ResponseWriter, r *http.Request) 13 - // HandleCallback handles the callback from the authentication provider, 14 - // processes the response (e.g., exchanges code for token), finds or creates 15 - // the user in the local system, and returns the user ID. 16 - // Returns 0 if authentication failed or user could not be determined. 17 HandleCallback(w http.ResponseWriter, r *http.Request) (int64, error) 18 } 19 20 type TokenReceiver interface { 21 - // SetAccessToken stores the access token for the user and returns the user ID. 22 - // If the user is already logged in, the current ID is provided. 23 SetAccessToken(token string, currentId int64, hasSession bool) (int64, error) 24 }
··· 1 package oauth 2 3 import ( 4 "net/http" 5 ) 6 7 type AuthService interface { 8 + // inits the login flow for the service 9 HandleLogin(w http.ResponseWriter, r *http.Request) 10 + // handles the callback for the provider. is responsible for inserting 11 + // sessions in the db 12 HandleCallback(w http.ResponseWriter, r *http.Request) (int64, error) 13 } 14 15 + // optional but recommended 16 type TokenReceiver interface { 17 + // stores the access token in the db 18 + // if there is a session, will associate the token with the session 19 SetAccessToken(token string, currentId int64, hasSession bool) (int64, error) 20 }
+24 -39
service/spotify/spotify.go
··· 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "io" 7 "log" ··· 31 } 32 33 func (s *SpotifyService) SetAccessToken(token string, userId int64, hasSession bool) (int64, error) { 34 - // Identify the user synchronously instead of in a goroutine 35 userID, err := s.identifyAndStoreUser(token, userId, hasSession) 36 if err != nil { 37 log.Printf("Error identifying and storing user: %v", err) ··· 41 } 42 43 func (s *SpotifyService) identifyAndStoreUser(token string, userId int64, hasSession bool) (int64, error) { 44 - // Get Spotify user profile 45 userProfile, err := s.fetchSpotifyProfile(token) 46 if err != nil { 47 log.Printf("Error fetching Spotify profile: %v", err) ··· 50 51 fmt.Printf("uid: %d hasSession: %t", userId, hasSession) 52 53 - // Check if user exists 54 user, err := s.DB.GetUserBySpotifyID(userProfile.ID) 55 if err != nil { 56 // This error might mean DB connection issue, not just user not found. ··· 74 } 75 } 76 } else { 77 - // Update existing user's token and expiry 78 err = s.DB.UpdateUserToken(user.ID, token, "", tokenExpiryTime) 79 if err != nil { 80 log.Printf("Error updating user token for user ID %d: %v", user.ID, err) 81 - // Consider if we should return 0 or the user ID even if update fails 82 - // Sticking to original behavior: log and continue 83 } else { 84 log.Printf("Updated token for existing user: %s (ID: %d)", user.Username, user.ID) 85 } 86 } 87 - // Keep the local 'user' object consistent (optional but good practice) 88 user.AccessToken = &token 89 user.TokenExpiry = &tokenExpiryTime 90 91 - // Store token in memory cache regardless of new/existing user 92 s.mu.Lock() 93 s.userTokens[user.ID] = token 94 s.mu.Unlock() ··· 103 Email string `json:"email"` 104 } 105 106 - // LoadAllUsers loads all active users from the database into memory 107 func (s *SpotifyService) LoadAllUsers() error { 108 users, err := s.DB.GetAllActiveUsers() 109 if err != nil { ··· 115 116 count := 0 117 for _, user := range users { 118 - // Only load users with valid tokens 119 if user.AccessToken != nil && user.TokenExpiry.After(time.Now()) { 120 s.userTokens[user.ID] = *user.AccessToken 121 count++ ··· 126 return nil 127 } 128 129 func (s *SpotifyService) RefreshToken(userID string) error { 130 s.mu.Lock() 131 defer s.mu.Unlock() ··· 139 return fmt.Errorf("no refresh token for user %s", userID) 140 } 141 142 - // Implement token refresh logic here using Spotify's token refresh endpoint 143 - // This would make a request to Spotify's token endpoint with grant_type=refresh_token 144 - 145 - // If successful, update the database and in-memory cache 146 - // we won't be now so just error out 147 - return fmt.Errorf("token refresh not implemented") 148 - // 149 - //s.userTokens[user.ID] = newToken 150 - //return nil 151 } 152 153 - // RefreshExpiredTokens attempts to refresh expired tokens 154 func (s *SpotifyService) RefreshExpiredTokens() { 155 users, err := s.DB.GetUsersWithExpiredTokens() 156 if err != nil { ··· 160 161 refreshed := 0 162 for _, user := range users { 163 - // Skip users without refresh tokens 164 if user.RefreshToken == nil { 165 continue 166 } 167 168 - // Implement token refresh logic here using Spotify's token refresh endpoint 169 - // This would make a request to Spotify's token endpoint with grant_type=refresh_token 170 171 - // If successful, update the database and in-memory cache 172 refreshed++ 173 } 174 ··· 231 return 232 } 233 234 - // Get recent tracks from database 235 tracks, err := s.DB.GetRecentTracks(userID, 20) 236 if err != nil { 237 http.Error(w, "Error retrieving track history", http.StatusInternalServerError) ··· 252 return nil, fmt.Errorf("no access token for user %d", userID) 253 } 254 255 - // Call Spotify API to get currently playing track 256 req, err := http.NewRequest("GET", "https://api.spotify.com/v1/me/player/currently-playing", nil) 257 if err != nil { 258 return nil, err ··· 266 } 267 defer resp.Body.Close() 268 269 - // No track playing 270 if resp.StatusCode == 204 { 271 return nil, nil 272 } 273 274 - // Token expired 275 if resp.StatusCode == 401 { 276 // attempt to refresh token 277 if err := s.RefreshToken(strconv.FormatInt(userID, 10)); err != nil { ··· 282 } 283 } 284 285 - // Error response 286 if resp.StatusCode != 200 { 287 body, _ := io.ReadAll(resp.Body) 288 return nil, fmt.Errorf("spotify API error: %s", body) 289 } 290 291 - // Parse response 292 var response struct { 293 Item struct { 294 Name string `json:"name"` ··· 320 return nil, err 321 } 322 323 - // Extract artist names/ids 324 var artists []models.Artist 325 for _, artist := range response.Item.Artists { 326 artists = append(artists, models.Artist{ ··· 329 }) 330 } 331 332 - // Create Track model 333 track := &models.Track{ 334 Name: response.Item.Name, 335 Artist: artists, ··· 351 defer ticker.Stop() 352 353 for range ticker.C { 354 - // Copy userIDs to avoid holding the lock too long 355 s.mu.RLock() 356 userIDs := make([]int64, 0, len(s.userTokens)) 357 for userID := range s.userTokens { ··· 359 } 360 s.mu.RUnlock() 361 362 - // Check each user's currently playing track 363 for _, userID := range userIDs { 364 track, err := s.FetchCurrentTrack(userID) 365 if err != nil { ··· 367 continue 368 } 369 370 - // No change if no track is playing 371 if track == nil { 372 continue 373 } 374 375 - // Check if this is a new track 376 s.mu.RLock() 377 currentTrack := s.userTracks[userID] 378 s.mu.RUnlock() ··· 384 } 385 } 386 387 - // If track is different or we've played more than either half of the track or 30 seconds since the start 388 - // whichever is greater 389 isNewTrack := currentTrack == nil || 390 currentTrack.Name != track.Name || 391 // just check the first one for now ··· 426 } 427 428 if isNewTrack { 429 - // Save to database 430 id, err := s.DB.SaveTrack(userID, track) 431 if err != nil { 432 log.Printf("Error saving track for user %d: %v", userID, err)
··· 2 3 import ( 4 "encoding/json" 5 + "errors" 6 "fmt" 7 "io" 8 "log" ··· 32 } 33 34 func (s *SpotifyService) SetAccessToken(token string, userId int64, hasSession bool) (int64, error) { 35 userID, err := s.identifyAndStoreUser(token, userId, hasSession) 36 if err != nil { 37 log.Printf("Error identifying and storing user: %v", err) ··· 41 } 42 43 func (s *SpotifyService) identifyAndStoreUser(token string, userId int64, hasSession bool) (int64, error) { 44 userProfile, err := s.fetchSpotifyProfile(token) 45 if err != nil { 46 log.Printf("Error fetching Spotify profile: %v", err) ··· 49 50 fmt.Printf("uid: %d hasSession: %t", userId, hasSession) 51 52 user, err := s.DB.GetUserBySpotifyID(userProfile.ID) 53 if err != nil { 54 // This error might mean DB connection issue, not just user not found. ··· 72 } 73 } 74 } else { 75 err = s.DB.UpdateUserToken(user.ID, token, "", tokenExpiryTime) 76 if err != nil { 77 + // for now log and continue 78 log.Printf("Error updating user token for user ID %d: %v", user.ID, err) 79 } else { 80 log.Printf("Updated token for existing user: %s (ID: %d)", user.Username, user.ID) 81 } 82 } 83 user.AccessToken = &token 84 user.TokenExpiry = &tokenExpiryTime 85 86 s.mu.Lock() 87 s.userTokens[user.ID] = token 88 s.mu.Unlock() ··· 97 Email string `json:"email"` 98 } 99 100 func (s *SpotifyService) LoadAllUsers() error { 101 users, err := s.DB.GetAllActiveUsers() 102 if err != nil { ··· 108 109 count := 0 110 for _, user := range users { 111 + // load users with valid tokens 112 if user.AccessToken != nil && user.TokenExpiry.After(time.Now()) { 113 s.userTokens[user.ID] = *user.AccessToken 114 count++ ··· 119 return nil 120 } 121 122 + func (s *SpotifyService) refreshTokenInner(user models.User) error { 123 + // implement token refresh logic here using Spotify's token refresh endpoint 124 + // this would make a request to Spotify's token endpoint with grant_type=refresh_token 125 + return errors.New("Not implemented yet") 126 + // if successful, update the database and in-memory cache 127 + } 128 + 129 func (s *SpotifyService) RefreshToken(userID string) error { 130 s.mu.Lock() 131 defer s.mu.Unlock() ··· 139 return fmt.Errorf("no refresh token for user %s", userID) 140 } 141 142 + return s.refreshTokenInner(*user) 143 } 144 145 + // attempt to refresh expired tokens 146 func (s *SpotifyService) RefreshExpiredTokens() { 147 users, err := s.DB.GetUsersWithExpiredTokens() 148 if err != nil { ··· 152 153 refreshed := 0 154 for _, user := range users { 155 + // skip users without refresh tokens 156 if user.RefreshToken == nil { 157 continue 158 } 159 160 + err := s.refreshTokenInner(*user) 161 + 162 + if err != nil { 163 + // just print out errors here for now 164 + log.Printf("Error from service/spotify/spotify.go when refreshing tokens: %s", err.Error()) 165 + } 166 167 refreshed++ 168 } 169 ··· 226 return 227 } 228 229 tracks, err := s.DB.GetRecentTracks(userID, 20) 230 if err != nil { 231 http.Error(w, "Error retrieving track history", http.StatusInternalServerError) ··· 246 return nil, fmt.Errorf("no access token for user %d", userID) 247 } 248 249 req, err := http.NewRequest("GET", "https://api.spotify.com/v1/me/player/currently-playing", nil) 250 if err != nil { 251 return nil, err ··· 259 } 260 defer resp.Body.Close() 261 262 + // nothing playing 263 if resp.StatusCode == 204 { 264 return nil, nil 265 } 266 267 + // oops, token expired 268 if resp.StatusCode == 401 { 269 // attempt to refresh token 270 if err := s.RefreshToken(strconv.FormatInt(userID, 10)); err != nil { ··· 275 } 276 } 277 278 if resp.StatusCode != 200 { 279 body, _ := io.ReadAll(resp.Body) 280 return nil, fmt.Errorf("spotify API error: %s", body) 281 } 282 283 var response struct { 284 Item struct { 285 Name string `json:"name"` ··· 311 return nil, err 312 } 313 314 var artists []models.Artist 315 for _, artist := range response.Item.Artists { 316 artists = append(artists, models.Artist{ ··· 319 }) 320 } 321 322 + // assemble Track 323 track := &models.Track{ 324 Name: response.Item.Name, 325 Artist: artists, ··· 341 defer ticker.Stop() 342 343 for range ticker.C { 344 + // copy userIDs to avoid holding the lock too long 345 s.mu.RLock() 346 userIDs := make([]int64, 0, len(s.userTokens)) 347 for userID := range s.userTokens { ··· 349 } 350 s.mu.RUnlock() 351 352 for _, userID := range userIDs { 353 track, err := s.FetchCurrentTrack(userID) 354 if err != nil { ··· 356 continue 357 } 358 359 if track == nil { 360 continue 361 } 362 363 s.mu.RLock() 364 currentTrack := s.userTracks[userID] 365 s.mu.RUnlock() ··· 371 } 372 } 373 374 + // if flagged true, we have a new track 375 isNewTrack := currentTrack == nil || 376 currentTrack.Name != track.Name || 377 // just check the first one for now ··· 412 } 413 414 if isNewTrack { 415 id, err := s.DB.SaveTrack(userID, track) 416 if err != nil { 417 log.Printf("Error saving track for user %d: %v", userID, err)
+10 -33
session/session.go
··· 31 mu sync.RWMutex 32 } 33 34 - // NewSessionManager creates a new session manager 35 func NewSessionManager() *SessionManager { 36 - // Initialize session table if it doesn't exist 37 database, err := db.New("./data/piper.db") 38 if err != nil { 39 log.Printf("Error connecting to database for sessions, falling back to in memory only: %v", err) ··· 56 log.Printf("Error creating sessions table: %v", err) 57 } 58 59 - // Create API key manager 60 apiKeyMgr := apikey.NewApiKeyManager(database) 61 62 return &SessionManager{ ··· 120 return session, true 121 } 122 123 - // If not in memory and we have a database, check there 124 if sm.db != nil { 125 session = &Session{ID: sessionID} 126 ··· 189 http.SetCookie(w, cookie) 190 } 191 192 - // HandleLogout handles user logout 193 func (sm *SessionManager) HandleLogout(w http.ResponseWriter, r *http.Request) { 194 cookie, err := r.Cookie("session") 195 if err == nil { ··· 201 http.Redirect(w, r, "/", http.StatusSeeOther) 202 } 203 204 - // GetAPIKeyManager returns the API key manager 205 func (sm *SessionManager) GetAPIKeyManager() *apikey.ApiKeyManager { 206 return sm.apiKeyMgr 207 } 208 209 - // CreateAPIKey creates a new API key for a user 210 func (sm *SessionManager) CreateAPIKey(userID int64, name string, validityDays int) (*apikey.ApiKey, error) { 211 return sm.apiKeyMgr.CreateApiKey(userID, name, validityDays) 212 } 213 214 - // WithAuth is a middleware that checks if a user is authenticated via cookies or API key 215 func WithAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 216 return func(w http.ResponseWriter, r *http.Request) { 217 - // First try API key authentication (for API requests) 218 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 219 if apiKeyErr == nil && apiKeyStr != "" { 220 - // Validate API key 221 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 222 if valid { 223 - // Add user ID to context 224 ctx := WithUserID(r.Context(), apiKey.UserID) 225 r = r.WithContext(ctx) 226 227 - // Set a flag in the context that this is an API request 228 ctx = WithAPIRequest(r.Context(), true) 229 r = r.WithContext(ctx) 230 ··· 233 } 234 } 235 236 - // Fall back to cookie authentication (for browser requests) 237 cookie, err := r.Cookie("session") 238 if err != nil { 239 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 240 return 241 } 242 243 - // Verify cookie session 244 session, exists := sm.GetSession(cookie.Value) 245 if !exists { 246 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 247 return 248 } 249 250 - // Add session information to request context 251 ctx := WithUserID(r.Context(), session.UserID) 252 r = r.WithContext(ctx) 253 ··· 255 } 256 } 257 258 func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 259 return func(w http.ResponseWriter, r *http.Request) { 260 ctx := r.Context() 261 - authenticated := false // Default to not authenticated 262 263 - // 1. Try API key authentication 264 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 265 if apiKeyErr == nil && apiKeyStr != "" { 266 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 267 if valid { 268 - // API Key valid: Add UserID, API flag, and set auth status 269 ctx = WithUserID(ctx, apiKey.UserID) 270 ctx = WithAPIRequest(ctx, true) 271 authenticated = true 272 - // Update request context and call handler 273 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 274 handler(w, r) 275 return 276 } 277 - // If API key was provided but invalid, we still proceed without auth 278 } 279 280 - // 2. If no valid API key, try cookie authentication 281 - if !authenticated { // Only check cookies if API key didn't authenticate 282 cookie, err := r.Cookie("session") 283 - if err == nil { // Cookie exists 284 session, exists := sm.GetSession(cookie.Value) 285 if exists { 286 - // Session valid: Add UserID and set auth status 287 ctx = WithUserID(ctx, session.UserID) 288 - // ctx = WithAPIRequest(ctx, false) // Not strictly needed, default is false 289 authenticated = true 290 } 291 - // If session cookie exists but is invalid/expired, we proceed without auth 292 } 293 } 294 295 - // 3. Set final auth status (could be true or false) and call handler 296 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 297 handler(w, r) 298 } 299 } 300 301 - // WithAPIAuth is a middleware specifically for API-only endpoints (no cookie fallback, returns 401 instead of redirect) 302 func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 303 return func(w http.ResponseWriter, r *http.Request) { 304 - // Try API key authentication 305 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 306 if apiKeyErr != nil || apiKeyStr == "" { 307 w.Header().Set("Content-Type", "application/json") ··· 310 return 311 } 312 313 - // Validate API key 314 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 315 if !valid { 316 w.Header().Set("Content-Type", "application/json") ··· 319 return 320 } 321 322 - // Add user ID to context 323 ctx := WithUserID(r.Context(), apiKey.UserID) 324 - // Mark as API request 325 ctx = WithAPIRequest(ctx, true) 326 r = r.WithContext(ctx) 327 ··· 329 } 330 } 331 332 - // Context keys 333 type contextKey int 334 335 const (
··· 31 mu sync.RWMutex 32 } 33 34 func NewSessionManager() *SessionManager { 35 database, err := db.New("./data/piper.db") 36 if err != nil { 37 log.Printf("Error connecting to database for sessions, falling back to in memory only: %v", err) ··· 54 log.Printf("Error creating sessions table: %v", err) 55 } 56 57 apiKeyMgr := apikey.NewApiKeyManager(database) 58 59 return &SessionManager{ ··· 117 return session, true 118 } 119 120 + // if not in memory and we have a database, check there 121 if sm.db != nil { 122 session = &Session{ID: sessionID} 123 ··· 186 http.SetCookie(w, cookie) 187 } 188 189 func (sm *SessionManager) HandleLogout(w http.ResponseWriter, r *http.Request) { 190 cookie, err := r.Cookie("session") 191 if err == nil { ··· 197 http.Redirect(w, r, "/", http.StatusSeeOther) 198 } 199 200 func (sm *SessionManager) GetAPIKeyManager() *apikey.ApiKeyManager { 201 return sm.apiKeyMgr 202 } 203 204 func (sm *SessionManager) CreateAPIKey(userID int64, name string, validityDays int) (*apikey.ApiKey, error) { 205 return sm.apiKeyMgr.CreateApiKey(userID, name, validityDays) 206 } 207 208 + // middleware that checks if a user is authenticated via cookies or API key 209 func WithAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 210 return func(w http.ResponseWriter, r *http.Request) { 211 + // first: check API keys 212 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 213 if apiKeyErr == nil && apiKeyStr != "" { 214 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 215 if valid { 216 ctx := WithUserID(r.Context(), apiKey.UserID) 217 r = r.WithContext(ctx) 218 219 + // set a flag for api requests 220 ctx = WithAPIRequest(r.Context(), true) 221 r = r.WithContext(ctx) 222 ··· 225 } 226 } 227 228 + // if not found, check cookies for session value 229 cookie, err := r.Cookie("session") 230 if err != nil { 231 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 232 return 233 } 234 235 session, exists := sm.GetSession(cookie.Value) 236 if !exists { 237 http.Redirect(w, r, "/login/spotify", http.StatusSeeOther) 238 return 239 } 240 241 ctx := WithUserID(r.Context(), session.UserID) 242 r = r.WithContext(ctx) 243 ··· 245 } 246 } 247 248 + // middleware that checks if a user is authenticated but doesn't error out if not 249 func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 250 return func(w http.ResponseWriter, r *http.Request) { 251 ctx := r.Context() 252 + authenticated := false 253 254 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 255 if apiKeyErr == nil && apiKeyStr != "" { 256 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 257 if valid { 258 ctx = WithUserID(ctx, apiKey.UserID) 259 ctx = WithAPIRequest(ctx, true) 260 authenticated = true 261 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 262 handler(w, r) 263 return 264 } 265 } 266 267 + if !authenticated { 268 cookie, err := r.Cookie("session") 269 + if err == nil { 270 session, exists := sm.GetSession(cookie.Value) 271 if exists { 272 ctx = WithUserID(ctx, session.UserID) 273 authenticated = true 274 } 275 } 276 } 277 278 r = r.WithContext(WithAuthStatus(ctx, authenticated)) 279 handler(w, r) 280 } 281 } 282 283 + // middleware that only accepts API keys 284 func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc { 285 return func(w http.ResponseWriter, r *http.Request) { 286 apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r) 287 if apiKeyErr != nil || apiKeyStr == "" { 288 w.Header().Set("Content-Type", "application/json") ··· 291 return 292 } 293 294 apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr) 295 if !valid { 296 w.Header().Set("Content-Type", "application/json") ··· 299 return 300 } 301 302 ctx := WithUserID(r.Context(), apiKey.UserID) 303 ctx = WithAPIRequest(ctx, true) 304 r = r.WithContext(ctx) 305 ··· 307 } 308 } 309 310 type contextKey int 311 312 const (