// This repository has no description.
1package server 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "strings" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/golang-jwt/jwt/v4" 13 "github.com/haileyok/cocoon/internal/helpers" 14 "github.com/haileyok/cocoon/models" 15 "github.com/haileyok/cocoon/oauth/dpop" 16 "github.com/haileyok/cocoon/oauth/provider" 17 "github.com/labstack/echo/v4" 18 "gitlab.com/yawning/secp256k1-voi" 19 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 20 "gorm.io/gorm" 21) 22 23func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 24 return func(e echo.Context) error { 25 username, password, ok := e.Request().BasicAuth() 26 if !ok || username != "admin" || password != s.config.AdminPassword { 27 return helpers.InputError(e, to.StringPtr("Unauthorized")) 28 } 29 30 if err := next(e); err != nil { 31 e.Error(err) 32 } 33 34 return nil 35 } 36} 37 38func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 return func(e echo.Context) error { 40 ctx := e.Request().Context() 41 42 authheader := e.Request().Header.Get("authorization") 43 if authheader == "" { 44 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 45 } 46 47 pts := strings.Split(authheader, " ") 48 if len(pts) != 2 { 49 return helpers.ServerError(e, nil) 50 } 51 52 // move on to oauth session middleware if this is a dpop token 53 if pts[0] == "DPoP" { 54 return next(e) 55 } 56 57 tokenstr := pts[1] 58 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 59 claims, ok := token.Claims.(jwt.MapClaims) 60 if !ok { 61 return helpers.InvalidTokenError(e) 62 } 63 64 var did string 65 var repo *models.RepoActor 66 67 // service auth tokens 68 lxm, hasLxm := claims["lxm"] 69 if hasLxm { 70 pts := strings.Split(e.Request().URL.String(), "/") 71 if lxm != pts[len(pts)-1] { 72 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 73 
return helpers.InputError(e, nil) 74 } 75 76 maybeDid, ok := claims["iss"].(string) 77 if !ok { 78 s.logger.Error("no iss in service auth token", "error", err) 79 return helpers.InputError(e, nil) 80 } 81 did = maybeDid 82 83 maybeRepo, err := s.getRepoActorByDid(ctx, did) 84 if err != nil { 85 s.logger.Error("error fetching repo", "error", err) 86 return helpers.ServerError(e, nil) 87 } 88 repo = maybeRepo 89 } 90 91 if token.Header["alg"] != "ES256K" { 92 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 93 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 94 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 95 } 96 return s.privateKey.Public(), nil 97 }) 98 if err != nil { 99 s.logger.Error("error parsing jwt", "error", err) 100 return helpers.ExpiredTokenError(e) 101 } 102 103 if !token.Valid { 104 return helpers.InvalidTokenError(e) 105 } 106 } else { 107 kpts := strings.Split(tokenstr, ".") 108 signingInput := kpts[0] + "." + kpts[1] 109 hash := sha256.Sum256([]byte(signingInput)) 110 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 111 if err != nil { 112 s.logger.Error("error decoding signature bytes", "error", err) 113 return helpers.ServerError(e, nil) 114 } 115 116 if len(sigBytes) != 64 { 117 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 118 return helpers.ServerError(e, nil) 119 } 120 121 rBytes := sigBytes[:32] 122 sBytes := sigBytes[32:] 123 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 124 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 125 126 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 127 if err != nil { 128 s.logger.Error("can't load private key", "error", err) 129 return err 130 } 131 132 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 133 if !ok { 134 s.logger.Error("error getting public key from sk") 135 return helpers.ServerError(e, nil) 136 } 137 138 verified := pubKey.VerifyRaw(hash[:], rr, ss) 139 if 
!verified { 140 s.logger.Error("error verifying", "error", err) 141 return helpers.ServerError(e, nil) 142 } 143 } 144 145 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 146 scope, _ := claims["scope"].(string) 147 148 if isRefresh && scope != "com.atproto.refresh" { 149 return helpers.InvalidTokenError(e) 150 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 151 return helpers.InvalidTokenError(e) 152 } 153 154 table := "tokens" 155 if isRefresh { 156 table = "refresh_tokens" 157 } 158 159 if isRefresh { 160 type Result struct { 161 Found bool 162 } 163 var result Result 164 if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 165 if err == gorm.ErrRecordNotFound { 166 return helpers.InvalidTokenError(e) 167 } 168 169 s.logger.Error("error getting token from db", "error", err) 170 return helpers.ServerError(e, nil) 171 } 172 173 if !result.Found { 174 return helpers.InvalidTokenError(e) 175 } 176 } 177 178 exp, ok := claims["exp"].(float64) 179 if !ok { 180 s.logger.Error("error getting iat from token") 181 return helpers.ServerError(e, nil) 182 } 183 184 if exp < float64(time.Now().UTC().Unix()) { 185 return helpers.ExpiredTokenError(e) 186 } 187 188 if repo == nil { 189 maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) 190 if err != nil { 191 s.logger.Error("error fetching repo", "error", err) 192 return helpers.ServerError(e, nil) 193 } 194 repo = maybeRepo 195 did = repo.Repo.Did 196 } 197 198 e.Set("repo", repo) 199 e.Set("did", did) 200 e.Set("token", tokenstr) 201 202 if err := next(e); err != nil { 203 return helpers.InvalidTokenError(e) 204 } 205 206 return nil 207 } 208} 209 210func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 211 return func(e echo.Context) error { 212 ctx := e.Request().Context() 213 214 authheader := e.Request().Header.Get("authorization") 215 if 
authheader == "" { 216 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 217 } 218 219 pts := strings.Split(authheader, " ") 220 if len(pts) != 2 { 221 return helpers.ServerError(e, nil) 222 } 223 224 if pts[0] != "DPoP" { 225 return next(e) 226 } 227 228 accessToken := pts[1] 229 230 nonce := s.oauthProvider.NextNonce() 231 if nonce != "" { 232 e.Response().Header().Set("DPoP-Nonce", nonce) 233 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 234 } 235 236 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 237 if err != nil { 238 if errors.Is(err, dpop.ErrUseDpopNonce) { 239 e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`) 240 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 241 return e.JSON(401, map[string]string{ 242 "error": "use_dpop_nonce", 243 }) 244 } 245 s.logger.Error("invalid dpop proof", "error", err) 246 return helpers.InputError(e, nil) 247 } 248 249 var oauthToken provider.OauthToken 250 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 251 s.logger.Error("error finding access token in db", "error", err) 252 return helpers.InputError(e, nil) 253 } 254 255 if oauthToken.Token == "" { 256 return helpers.InvalidTokenError(e) 257 } 258 259 if *oauthToken.Parameters.DpopJkt != proof.JKT { 260 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 261 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 262 } 263 264 if time.Now().After(oauthToken.ExpiresAt) { 265 e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`) 266 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 267 return e.JSON(401, map[string]string{ 268 "error": 
"invalid_token", 269 "error_description": "Token expired", 270 }) 271 } 272 273 repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) 274 if err != nil { 275 s.logger.Error("could not find actor in db", "error", err) 276 return helpers.ServerError(e, nil) 277 } 278 279 e.Set("repo", repo) 280 e.Set("did", repo.Repo.Did) 281 e.Set("token", accessToken) 282 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 283 284 return next(e) 285 } 286}