1package server
2
3import (
4 "bytes"
5 "crypto/sha256"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "slices"
10 "time"
11
12 "github.com/Azure/go-autorest/autorest/to"
13 "github.com/golang-jwt/jwt/v4"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/oauth"
16 "github.com/haileyok/cocoon/oauth/constants"
17 "github.com/haileyok/cocoon/oauth/dpop"
18 "github.com/haileyok/cocoon/oauth/provider"
19 "github.com/labstack/echo/v4"
20)
21
22type OauthTokenRequest struct {
23 provider.AuthenticateClientRequestBase
24 GrantType string `form:"grant_type" json:"grant_type"`
25 Code *string `form:"code" json:"code,omitempty"`
26 CodeVerifier *string `form:"code_verifier" json:"code_verifier,omitempty"`
27 RedirectURI *string `form:"redirect_uri" json:"redirect_uri,omitempty"`
28 RefreshToken *string `form:"refresh_token" json:"refresh_token,omitempty"`
29}
30
31type OauthTokenResponse struct {
32 AccessToken string `json:"access_token"`
33 TokenType string `json:"token_type"`
34 RefreshToken string `json:"refresh_token"`
35 Scope string `json:"scope"`
36 ExpiresIn int64 `json:"expires_in"`
37 Sub string `json:"sub"`
38}
39
40func (s *Server) handleOauthToken(e echo.Context) error {
41 ctx := e.Request().Context()
42
43 var req OauthTokenRequest
44 if err := e.Bind(&req); err != nil {
45 s.logger.Error("error binding token request", "error", err)
46 return helpers.ServerError(e, nil)
47 }
48
49 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil)
50 if err != nil {
51 if errors.Is(err, dpop.ErrUseDpopNonce) {
52 nonce := s.oauthProvider.NextNonce()
53 if nonce != "" {
54 e.Response().Header().Set("DPoP-Nonce", nonce)
55 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
56 }
57 return e.JSON(400, map[string]string{
58 "error": "use_dpop_nonce",
59 })
60 }
61 s.logger.Error("error getting dpop proof", "error", err)
62 return helpers.InputError(e, nil)
63 }
64
65 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{
66 AllowMissingDpopProof: true,
67 })
68 if err != nil {
69 s.logger.Error("error authenticating client", "client_id", req.ClientID, "error", err)
70 return helpers.InputError(e, to.StringPtr(err.Error()))
71 }
72
73 // TODO: this should come from an oauth provier config
74 if !slices.Contains([]string{"authorization_code", "refresh_token"}, req.GrantType) {
75 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the server`, req.GrantType)))
76 }
77
78 if !slices.Contains(client.Metadata.GrantTypes, req.GrantType) {
79 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`"%s" grant type is not supported by the client`, req.GrantType)))
80 }
81
82 if req.GrantType == "authorization_code" {
83 if req.Code == nil {
84 return helpers.InputError(e, to.StringPtr(`"code" is required"`))
85 }
86
87 var authReq provider.OauthAuthorizationRequest
88 // get the lil guy and delete him
89 if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
90 s.logger.Error("error finding authorization request", "error", err)
91 return helpers.ServerError(e, nil)
92 }
93
94 if req.RedirectURI == nil || *req.RedirectURI != authReq.Parameters.RedirectURI {
95 return helpers.InputError(e, to.StringPtr(`"redirect_uri" mismatch`))
96 }
97
98 if authReq.Parameters.CodeChallenge != nil {
99 if req.CodeVerifier == nil {
100 return helpers.InputError(e, to.StringPtr(`"code_verifier" is required`))
101 }
102
103 if len(*req.CodeVerifier) < 43 {
104 return helpers.InputError(e, to.StringPtr(`"code_verifier" is too short`))
105 }
106
107 switch *&authReq.Parameters.CodeChallengeMethod {
108 case "", "plain":
109 if authReq.Parameters.CodeChallenge != req.CodeVerifier {
110 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
111 }
112 case "S256":
113 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge)
114 if err != nil {
115 s.logger.Error("error decoding code challenge", "error", err)
116 return helpers.ServerError(e, nil)
117 }
118
119 h := sha256.New()
120 h.Write([]byte(*req.CodeVerifier))
121 compdChal := h.Sum(nil)
122
123 if !bytes.Equal(inputChal, compdChal) {
124 return helpers.InputError(e, to.StringPtr("invalid code_verifier"))
125 }
126 default:
127 return helpers.InputError(e, to.StringPtr("unsupported code_challenge_method "+*&authReq.Parameters.CodeChallengeMethod))
128 }
129 } else if req.CodeVerifier != nil {
130 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
131 }
132
133 repo, err := s.getRepoActorByDid(ctx, *authReq.Sub)
134 if err != nil {
135 helpers.InputError(e, to.StringPtr("unable to find actor"))
136 }
137
138 now := time.Now()
139 eat := now.Add(constants.TokenMaxAge)
140 id := oauth.GenerateTokenId()
141
142 refreshToken := oauth.GenerateRefreshToken()
143
144 accessClaims := jwt.MapClaims{
145 "scope": authReq.Parameters.Scope,
146 "aud": s.config.Did,
147 "sub": repo.Repo.Did,
148 "iat": now.Unix(),
149 "exp": eat.Unix(),
150 "jti": id,
151 "client_id": authReq.ClientId,
152 }
153
154 if authReq.Parameters.DpopJkt != nil {
155 accessClaims["cnf"] = *authReq.Parameters.DpopJkt
156 }
157
158 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
159 accessString, err := accessToken.SignedString(s.privateKey)
160 if err != nil {
161 return err
162 }
163
164 if err := s.db.Create(ctx, &provider.OauthToken{
165 ClientId: authReq.ClientId,
166 ClientAuth: *clientAuth,
167 Parameters: authReq.Parameters,
168 ExpiresAt: eat,
169 DeviceId: "",
170 Sub: repo.Repo.Did,
171 Code: *authReq.Code,
172 Token: accessString,
173 RefreshToken: refreshToken,
174 Ip: authReq.Ip,
175 }, nil).Error; err != nil {
176 s.logger.Error("error creating token in db", "error", err)
177 return helpers.ServerError(e, nil)
178 }
179
180 // prob not needed
181 tokenType := "Bearer"
182 if authReq.Parameters.DpopJkt != nil {
183 tokenType = "DPoP"
184 }
185
186 e.Response().Header().Set("content-type", "application/json")
187
188 return e.JSON(200, OauthTokenResponse{
189 AccessToken: accessString,
190 RefreshToken: refreshToken,
191 TokenType: tokenType,
192 Scope: authReq.Parameters.Scope,
193 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
194 Sub: repo.Repo.Did,
195 })
196 }
197
198 if req.GrantType == "refresh_token" {
199 if req.RefreshToken == nil {
200 return helpers.InputError(e, to.StringPtr(`"refresh_token" is required`))
201 }
202
203 var oauthToken provider.OauthToken
204 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
205 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
206 return helpers.ServerError(e, nil)
207 }
208
209 if client.Metadata.ClientID != oauthToken.ClientId {
210 return helpers.InputError(e, to.StringPtr(`"client_id" mismatch`))
211 }
212
213 if clientAuth.Method != oauthToken.ClientAuth.Method {
214 return helpers.InputError(e, to.StringPtr(`"client authentication method mismatch`))
215 }
216
217 if *oauthToken.Parameters.DpopJkt != proof.JKT {
218 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
219 }
220
221 ageRes := oauth.GetSessionAgeFromToken(oauthToken)
222
223 if ageRes.SessionExpired {
224 return helpers.InputError(e, to.StringPtr("Session expired"))
225 }
226
227 if ageRes.RefreshExpired {
228 return helpers.InputError(e, to.StringPtr("Refresh token expired"))
229 }
230
231 if client.Metadata.DpopBoundAccessTokens && oauthToken.Parameters.DpopJkt == nil {
232 // why? ref impl
233 return helpers.InputError(e, to.StringPtr("dpop jkt is required for dpop bound access tokens"))
234 }
235
236 nextTokenId := oauth.GenerateTokenId()
237 nextRefreshToken := oauth.GenerateRefreshToken()
238
239 now := time.Now()
240 eat := now.Add(constants.TokenMaxAge)
241
242 accessClaims := jwt.MapClaims{
243 "scope": oauthToken.Parameters.Scope,
244 "aud": s.config.Did,
245 "sub": oauthToken.Sub,
246 "iat": now.Unix(),
247 "exp": eat.Unix(),
248 "jti": nextTokenId,
249 "client_id": oauthToken.ClientId,
250 }
251
252 if oauthToken.Parameters.DpopJkt != nil {
253 accessClaims["cnf"] = *&oauthToken.Parameters.DpopJkt
254 }
255
256 accessToken := jwt.NewWithClaims(jwt.SigningMethodES256, accessClaims)
257 accessString, err := accessToken.SignedString(s.privateKey)
258 if err != nil {
259 return err
260 }
261
262 if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
263 s.logger.Error("error updating token", "error", err)
264 return helpers.ServerError(e, nil)
265 }
266
267 // prob not needed
268 tokenType := "Bearer"
269 if oauthToken.Parameters.DpopJkt != nil {
270 tokenType = "DPoP"
271 }
272
273 return e.JSON(200, OauthTokenResponse{
274 AccessToken: accessString,
275 RefreshToken: nextRefreshToken,
276 TokenType: tokenType,
277 Scope: oauthToken.Parameters.Scope,
278 ExpiresIn: int64(eat.Sub(time.Now()).Seconds()),
279 Sub: oauthToken.Sub,
280 })
281 }
282
283 return helpers.InputError(e, to.StringPtr(fmt.Sprintf(`grant type "%s" is not supported`, req.GrantType)))
284}