1package server
2
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/oauth/dpop"
	"github.com/haileyok/cocoon/oauth/provider"
	"github.com/labstack/echo/v4"
	"gitlab.com/yawning/secp256k1-voi"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
	"gorm.io/gorm"
)
22
23func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
24 return func(e echo.Context) error {
25 username, password, ok := e.Request().BasicAuth()
26 if !ok || username != "admin" || password != s.config.AdminPassword {
27 return helpers.InputError(e, to.StringPtr("Unauthorized"))
28 }
29
30 if err := next(e); err != nil {
31 e.Error(err)
32 }
33
34 return nil
35 }
36}
37
38func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
39 return func(e echo.Context) error {
40 ctx := e.Request().Context()
41
42 authheader := e.Request().Header.Get("authorization")
43 if authheader == "" {
44 return e.JSON(401, map[string]string{"error": "Unauthorized"})
45 }
46
47 pts := strings.Split(authheader, " ")
48 if len(pts) != 2 {
49 return helpers.ServerError(e, nil)
50 }
51
52 // move on to oauth session middleware if this is a dpop token
53 if pts[0] == "DPoP" {
54 return next(e)
55 }
56
57 tokenstr := pts[1]
58 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
59 claims, ok := token.Claims.(jwt.MapClaims)
60 if !ok {
61 return helpers.InvalidTokenError(e)
62 }
63
64 var did string
65 var repo *models.RepoActor
66
67 // service auth tokens
68 lxm, hasLxm := claims["lxm"]
69 if hasLxm {
70 pts := strings.Split(e.Request().URL.String(), "/")
71 if lxm != pts[len(pts)-1] {
72 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
73 return helpers.InputError(e, nil)
74 }
75
76 maybeDid, ok := claims["iss"].(string)
77 if !ok {
78 s.logger.Error("no iss in service auth token", "error", err)
79 return helpers.InputError(e, nil)
80 }
81 did = maybeDid
82
83 maybeRepo, err := s.getRepoActorByDid(ctx, did)
84 if err != nil {
85 s.logger.Error("error fetching repo", "error", err)
86 return helpers.ServerError(e, nil)
87 }
88 repo = maybeRepo
89 }
90
91 if token.Header["alg"] != "ES256K" {
92 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
93 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
94 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
95 }
96 return s.privateKey.Public(), nil
97 })
98 if err != nil {
99 s.logger.Error("error parsing jwt", "error", err)
100 return helpers.ExpiredTokenError(e)
101 }
102
103 if !token.Valid {
104 return helpers.InvalidTokenError(e)
105 }
106 } else {
107 kpts := strings.Split(tokenstr, ".")
108 signingInput := kpts[0] + "." + kpts[1]
109 hash := sha256.Sum256([]byte(signingInput))
110 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
111 if err != nil {
112 s.logger.Error("error decoding signature bytes", "error", err)
113 return helpers.ServerError(e, nil)
114 }
115
116 if len(sigBytes) != 64 {
117 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
118 return helpers.ServerError(e, nil)
119 }
120
121 rBytes := sigBytes[:32]
122 sBytes := sigBytes[32:]
123 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
124 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
125
126 if repo == nil {
127 sub, ok := claims["sub"].(string)
128 if !ok {
129 s.logger.Error("no sub claim in ES256K token and repo not set")
130 return helpers.InvalidTokenError(e)
131 }
132 maybeRepo, err := s.getRepoActorByDid(ctx, sub)
133 if err != nil {
134 s.logger.Error("error fetching repo for ES256K verification", "error", err)
135 return helpers.ServerError(e, nil)
136 }
137 repo = maybeRepo
138 did = sub
139 }
140
141 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
142 if err != nil {
143 s.logger.Error("can't load private key", "error", err)
144 return err
145 }
146
147 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
148 if !ok {
149 s.logger.Error("error getting public key from sk")
150 return helpers.ServerError(e, nil)
151 }
152
153 verified := pubKey.VerifyRaw(hash[:], rr, ss)
154 if !verified {
155 s.logger.Error("error verifying", "error", err)
156 return helpers.ServerError(e, nil)
157 }
158 }
159
160 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
161 scope, _ := claims["scope"].(string)
162
163 if isRefresh && scope != "com.atproto.refresh" {
164 return helpers.InvalidTokenError(e)
165 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
166 return helpers.InvalidTokenError(e)
167 }
168
169 table := "tokens"
170 if isRefresh {
171 table = "refresh_tokens"
172 }
173
174 if isRefresh {
175 type Result struct {
176 Found bool
177 }
178 var result Result
179 if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
180 if err == gorm.ErrRecordNotFound {
181 return helpers.InvalidTokenError(e)
182 }
183
184 s.logger.Error("error getting token from db", "error", err)
185 return helpers.ServerError(e, nil)
186 }
187
188 if !result.Found {
189 return helpers.InvalidTokenError(e)
190 }
191 }
192
193 exp, ok := claims["exp"].(float64)
194 if !ok {
195 s.logger.Error("error getting iat from token")
196 return helpers.ServerError(e, nil)
197 }
198
199 if exp < float64(time.Now().UTC().Unix()) {
200 return helpers.ExpiredTokenError(e)
201 }
202
203 if repo == nil {
204 maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string))
205 if err != nil {
206 s.logger.Error("error fetching repo", "error", err)
207 return helpers.ServerError(e, nil)
208 }
209 repo = maybeRepo
210 did = repo.Repo.Did
211 }
212
213 e.Set("repo", repo)
214 e.Set("did", did)
215 e.Set("token", tokenstr)
216
217 if err := next(e); err != nil {
218 return helpers.InvalidTokenError(e)
219 }
220
221 return nil
222 }
223}
224
225func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
226 return func(e echo.Context) error {
227 ctx := e.Request().Context()
228
229 authheader := e.Request().Header.Get("authorization")
230 if authheader == "" {
231 return e.JSON(401, map[string]string{"error": "Unauthorized"})
232 }
233
234 pts := strings.Split(authheader, " ")
235 if len(pts) != 2 {
236 return helpers.ServerError(e, nil)
237 }
238
239 if pts[0] != "DPoP" {
240 return next(e)
241 }
242
243 accessToken := pts[1]
244
245 nonce := s.oauthProvider.NextNonce()
246 if nonce != "" {
247 e.Response().Header().Set("DPoP-Nonce", nonce)
248 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
249 }
250
251 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
252 if err != nil {
253 if errors.Is(err, dpop.ErrUseDpopNonce) {
254 e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`)
255 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
256 return e.JSON(401, map[string]string{
257 "error": "use_dpop_nonce",
258 })
259 }
260 s.logger.Error("invalid dpop proof", "error", err)
261 return helpers.InputError(e, nil)
262 }
263
264 var oauthToken provider.OauthToken
265 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
266 s.logger.Error("error finding access token in db", "error", err)
267 return helpers.InputError(e, nil)
268 }
269
270 if oauthToken.Token == "" {
271 return helpers.InvalidTokenError(e)
272 }
273
274 if *oauthToken.Parameters.DpopJkt != proof.JKT {
275 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
276 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
277 }
278
279 if time.Now().After(oauthToken.ExpiresAt) {
280 e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`)
281 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
282 return e.JSON(401, map[string]string{
283 "error": "invalid_token",
284 "error_description": "Token expired",
285 })
286 }
287
288 repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub)
289 if err != nil {
290 s.logger.Error("could not find actor in db", "error", err)
291 return helpers.ServerError(e, nil)
292 }
293
294 e.Set("repo", repo)
295 e.Set("did", repo.Repo.Did)
296 e.Set("token", accessToken)
297 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
298
299 return next(e)
300 }
301}