This repository has no description.
1package server 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "strings" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/golang-jwt/jwt/v4" 13 "github.com/haileyok/cocoon/internal/helpers" 14 "github.com/haileyok/cocoon/models" 15 "github.com/haileyok/cocoon/oauth/dpop" 16 "github.com/haileyok/cocoon/oauth/provider" 17 "github.com/labstack/echo/v4" 18 "gitlab.com/yawning/secp256k1-voi" 19 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 20 "gorm.io/gorm" 21) 22 23func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 24 return func(e echo.Context) error { 25 username, password, ok := e.Request().BasicAuth() 26 if !ok || username != "admin" || password != s.config.AdminPassword { 27 return helpers.InputError(e, to.StringPtr("Unauthorized")) 28 } 29 30 if err := next(e); err != nil { 31 e.Error(err) 32 } 33 34 return nil 35 } 36} 37 38func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 return func(e echo.Context) error { 40 ctx := e.Request().Context() 41 42 authheader := e.Request().Header.Get("authorization") 43 if authheader == "" { 44 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 45 } 46 47 pts := strings.Split(authheader, " ") 48 if len(pts) != 2 { 49 return helpers.ServerError(e, nil) 50 } 51 52 // move on to oauth session middleware if this is a dpop token 53 if pts[0] == "DPoP" { 54 return next(e) 55 } 56 57 tokenstr := pts[1] 58 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 59 claims, ok := token.Claims.(jwt.MapClaims) 60 if !ok { 61 return helpers.InvalidTokenError(e) 62 } 63 64 var did string 65 var repo *models.RepoActor 66 67 // service auth tokens 68 lxm, hasLxm := claims["lxm"] 69 if hasLxm { 70 pts := strings.Split(e.Request().URL.String(), "/") 71 if lxm != pts[len(pts)-1] { 72 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 73 
return helpers.InputError(e, nil) 74 } 75 76 maybeDid, ok := claims["iss"].(string) 77 if !ok { 78 s.logger.Error("no iss in service auth token", "error", err) 79 return helpers.InputError(e, nil) 80 } 81 did = maybeDid 82 83 maybeRepo, err := s.getRepoActorByDid(ctx, did) 84 if err != nil { 85 s.logger.Error("error fetching repo", "error", err) 86 return helpers.ServerError(e, nil) 87 } 88 repo = maybeRepo 89 } 90 91 if token.Header["alg"] != "ES256K" { 92 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 93 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 94 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 95 } 96 return s.privateKey.Public(), nil 97 }) 98 if err != nil { 99 s.logger.Error("error parsing jwt", "error", err) 100 return helpers.ExpiredTokenError(e) 101 } 102 103 if !token.Valid { 104 return helpers.InvalidTokenError(e) 105 } 106 } else { 107 kpts := strings.Split(tokenstr, ".") 108 signingInput := kpts[0] + "." + kpts[1] 109 hash := sha256.Sum256([]byte(signingInput)) 110 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 111 if err != nil { 112 s.logger.Error("error decoding signature bytes", "error", err) 113 return helpers.ServerError(e, nil) 114 } 115 116 if len(sigBytes) != 64 { 117 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 118 return helpers.ServerError(e, nil) 119 } 120 121 rBytes := sigBytes[:32] 122 sBytes := sigBytes[32:] 123 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 124 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 125 126 if repo == nil { 127 sub, ok := claims["sub"].(string) 128 if !ok { 129 s.logger.Error("no sub claim in ES256K token and repo not set") 130 return helpers.InvalidTokenError(e) 131 } 132 maybeRepo, err := s.getRepoActorByDid(ctx, sub) 133 if err != nil { 134 s.logger.Error("error fetching repo for ES256K verification", "error", err) 135 return helpers.ServerError(e, nil) 136 } 137 repo = 
maybeRepo 138 did = sub 139 } 140 141 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 142 if err != nil { 143 s.logger.Error("can't load private key", "error", err) 144 return err 145 } 146 147 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 148 if !ok { 149 s.logger.Error("error getting public key from sk") 150 return helpers.ServerError(e, nil) 151 } 152 153 verified := pubKey.VerifyRaw(hash[:], rr, ss) 154 if !verified { 155 s.logger.Error("error verifying", "error", err) 156 return helpers.ServerError(e, nil) 157 } 158 } 159 160 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 161 scope, _ := claims["scope"].(string) 162 163 if isRefresh && scope != "com.atproto.refresh" { 164 return helpers.InvalidTokenError(e) 165 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 166 return helpers.InvalidTokenError(e) 167 } 168 169 table := "tokens" 170 if isRefresh { 171 table = "refresh_tokens" 172 } 173 174 if isRefresh { 175 type Result struct { 176 Found bool 177 } 178 var result Result 179 if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) 
AS found", nil, tokenstr).Scan(&result).Error; err != nil { 180 if err == gorm.ErrRecordNotFound { 181 return helpers.InvalidTokenError(e) 182 } 183 184 s.logger.Error("error getting token from db", "error", err) 185 return helpers.ServerError(e, nil) 186 } 187 188 if !result.Found { 189 return helpers.InvalidTokenError(e) 190 } 191 } 192 193 exp, ok := claims["exp"].(float64) 194 if !ok { 195 s.logger.Error("error getting iat from token") 196 return helpers.ServerError(e, nil) 197 } 198 199 if exp < float64(time.Now().UTC().Unix()) { 200 return helpers.ExpiredTokenError(e) 201 } 202 203 if repo == nil { 204 maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) 205 if err != nil { 206 s.logger.Error("error fetching repo", "error", err) 207 return helpers.ServerError(e, nil) 208 } 209 repo = maybeRepo 210 did = repo.Repo.Did 211 } 212 213 e.Set("repo", repo) 214 e.Set("did", did) 215 e.Set("token", tokenstr) 216 217 if err := next(e); err != nil { 218 return helpers.InvalidTokenError(e) 219 } 220 221 return nil 222 } 223} 224 225func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 226 return func(e echo.Context) error { 227 ctx := e.Request().Context() 228 229 authheader := e.Request().Header.Get("authorization") 230 if authheader == "" { 231 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 232 } 233 234 pts := strings.Split(authheader, " ") 235 if len(pts) != 2 { 236 return helpers.ServerError(e, nil) 237 } 238 239 if pts[0] != "DPoP" { 240 return next(e) 241 } 242 243 accessToken := pts[1] 244 245 nonce := s.oauthProvider.NextNonce() 246 if nonce != "" { 247 e.Response().Header().Set("DPoP-Nonce", nonce) 248 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 249 } 250 251 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 252 if err != nil { 253 if 
errors.Is(err, dpop.ErrUseDpopNonce) { 254 e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`) 255 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 256 return e.JSON(401, map[string]string{ 257 "error": "use_dpop_nonce", 258 }) 259 } 260 s.logger.Error("invalid dpop proof", "error", err) 261 return helpers.InputError(e, nil) 262 } 263 264 var oauthToken provider.OauthToken 265 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 266 s.logger.Error("error finding access token in db", "error", err) 267 return helpers.InputError(e, nil) 268 } 269 270 if oauthToken.Token == "" { 271 return helpers.InvalidTokenError(e) 272 } 273 274 if *oauthToken.Parameters.DpopJkt != proof.JKT { 275 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 276 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 277 } 278 279 if time.Now().After(oauthToken.ExpiresAt) { 280 e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`) 281 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 282 return e.JSON(401, map[string]string{ 283 "error": "invalid_token", 284 "error_description": "Token expired", 285 }) 286 } 287 288 repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) 289 if err != nil { 290 s.logger.Error("could not find actor in db", "error", err) 291 return helpers.ServerError(e, nil) 292 } 293 294 e.Set("repo", repo) 295 e.Set("did", repo.Repo.Did) 296 e.Set("token", accessToken) 297 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 298 299 return next(e) 300 } 301}