1package server
2
import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/labstack/echo/v4"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
18
// ComAtprotoServerCreateSessionRequest is the JSON body accepted by
// handleCreateSession (com.atproto.server.createSession).
//
// Identifier names the account to log in as. It may be a DID, a handle, or an
// email address; handleCreateSession lowercases it and picks the lookup based
// on its syntax. Password is the plaintext password that is compared against
// the stored bcrypt hash. Both are marked `validate:"required"`.
//
// AuthFactorToken is the optional emailed auth code. handleCreateSession only
// consults it for accounts that have EmailAuthFactor enabled.
type ComAtprotoServerCreateSessionRequest struct {
	Identifier      string  `json:"identifier" validate:"required"`
	Password        string  `json:"password" validate:"required"`
	AuthFactorToken *string `json:"authFactorToken,omitempty"`
}
24
// ComAtprotoServerCreateSessionResponse is the JSON body returned by
// handleCreateSession after a successful login.
//
// AccessJwt and RefreshJwt carry the tokens of the newly created session.
// Handle, Did and Email describe the logged-in account. EmailConfirmed is true
// when the account has an email-confirmation timestamp. EmailAuthFactor reports
// whether email auth factor is enabled. Active and Status come from the
// account's Active() and Status() methods; Status is omitted from the JSON when
// it is nil.
//
// Field order is also the JSON key order produced by encoding/json.
type ComAtprotoServerCreateSessionResponse struct {
	AccessJwt       string  `json:"accessJwt"`
	RefreshJwt      string  `json:"refreshJwt"`
	Handle          string  `json:"handle"`
	Did             string  `json:"did"`
	Email           string  `json:"email"`
	EmailConfirmed  bool    `json:"emailConfirmed"`
	EmailAuthFactor bool    `json:"emailAuthFactor"`
	Active          bool    `json:"active"`
	Status          *string `json:"status,omitempty"`
}
36
37func (s *Server) handleCreateSession(e echo.Context) error {
38 ctx := e.Request().Context()
39 logger := s.logger.With("name", "handleServerCreateSession")
40
41 var req ComAtprotoServerCreateSessionRequest
42 if err := e.Bind(&req); err != nil {
43 logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
44 return helpers.ServerError(e, nil)
45 }
46
47 if err := e.Validate(req); err != nil {
48 var verr ValidationError
49 if errors.As(err, &verr) {
50 if verr.Field == "Identifier" {
51 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
52 }
53
54 if verr.Field == "Password" {
55 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
56 }
57 }
58 }
59
60 req.Identifier = strings.ToLower(req.Identifier)
61 var idtype string
62 if _, err := syntax.ParseDID(req.Identifier); err == nil {
63 idtype = "did"
64 } else if _, err := syntax.ParseHandle(req.Identifier); err == nil {
65 idtype = "handle"
66 } else {
67 idtype = "email"
68 }
69
70 var repo models.RepoActor
71 var err error
72 switch idtype {
73 case "did":
74 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
75 case "handle":
76 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
77 case "email":
78 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
79 }
80
81 if err != nil {
82 if err == gorm.ErrRecordNotFound {
83 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
84 }
85
86 logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
87 return helpers.ServerError(e, nil)
88 }
89
90 // if repo requires auth factor token and one hasn't been provided, return error prompting for one
91 if repo.EmailAuthFactor && (req.AuthFactorToken == nil || *req.AuthFactorToken == "") {
92 err = s.createAndSendAuthCode(ctx, repo)
93 if err != nil {
94 s.logger.Error("sending auth code", "error", err)
95 return helpers.ServerError(e, nil)
96 }
97
98 return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired"))
99 }
100
101 // if auth factor is required, now check that the one provided is valid
102 if repo.EmailAuthFactor {
103 if repo.AuthCode == nil || repo.AuthCodeExpiresAt == nil {
104 err = s.createAndSendAuthCode(ctx, repo)
105 if err != nil {
106 s.logger.Error("sending auth code", "error", err)
107 return helpers.ServerError(e, nil)
108 }
109
110 return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired"))
111 }
112
113 if *repo.AuthCode != *req.AuthFactorToken {
114 return helpers.InvalidTokenError(e)
115 }
116
117 if time.Now().UTC().After(*repo.AuthCodeExpiresAt) {
118 return helpers.ExpiredTokenError(e)
119 }
120 }
121
122 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
123 if err != bcrypt.ErrMismatchedHashAndPassword {
124 logger.Error("erorr comparing hash and password", "error", err)
125 }
126 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
127 }
128
129 sess, err := s.createSession(ctx, &repo.Repo)
130 if err != nil {
131 logger.Error("error creating session", "error", err)
132 return helpers.ServerError(e, nil)
133 }
134
135 return e.JSON(200, ComAtprotoServerCreateSessionResponse{
136 AccessJwt: sess.AccessToken,
137 RefreshJwt: sess.RefreshToken,
138 Handle: repo.Handle,
139 Did: repo.Repo.Did,
140 Email: repo.Email,
141 EmailConfirmed: repo.EmailConfirmedAt != nil,
142 EmailAuthFactor: repo.EmailAuthFactor,
143 Active: repo.Active(),
144 Status: repo.Status(),
145 })
146}
147
148func (s *Server) createAndSendAuthCode(ctx context.Context, repo models.RepoActor) error {
149 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
150 eat := time.Now().Add(10 * time.Minute).UTC()
151
152 if err := s.db.Exec(ctx, "UPDATE repos SET auth_code = ?, auth_code_expires_at = ? WHERE did = ?", nil, code, eat, repo.Repo.Did).Error; err != nil {
153 return fmt.Errorf("updating repo: %w", err)
154 }
155
156 if err := s.sendAuthCode(repo.Email, repo.Handle, code); err != nil {
157 return fmt.Errorf("sending email: %w", err)
158 }
159
160 return nil
161}