1package server
2
3import (
4 "errors"
5 "fmt"
6 "strings"
7 "time"
8
9 "github.com/Azure/go-autorest/autorest/to"
10 "github.com/bluesky-social/indigo/atproto/syntax"
11 "github.com/haileyok/cocoon/internal/helpers"
12 "github.com/haileyok/cocoon/models"
13 "github.com/labstack/echo/v4"
14 "golang.org/x/crypto/bcrypt"
15 "gorm.io/gorm"
16)
17
// ComAtprotoServerCreateSessionRequest is the JSON body accepted by the
// createSession handler.
type ComAtprotoServerCreateSessionRequest struct {
	// Identifier names the account to log in to. The handler tries to parse
	// it as a DID, then as a handle, and otherwise treats it as an email.
	Identifier string `json:"identifier" validate:"required"`

	// Password is the plaintext account password; it is compared against the
	// stored bcrypt hash.
	Password string `json:"password" validate:"required"`

	// AuthFactorToken is the optional email auth-factor code. It is only
	// consulted when the account has the email auth factor enabled.
	AuthFactorToken *string `json:"authFactorToken,omitempty"`
}
23
// ComAtprotoServerCreateSessionResponse is the JSON body returned by the
// createSession handler after a successful login.
type ComAtprotoServerCreateSessionResponse struct {
	// AccessJwt and RefreshJwt carry the tokens of the newly created session.
	AccessJwt  string `json:"accessJwt"`
	RefreshJwt string `json:"refreshJwt"`

	// Handle, Did and Email identify the account that was logged in to.
	Handle string `json:"handle"`
	Did    string `json:"did"`
	Email  string `json:"email"`

	// EmailConfirmed is true when the account has a confirmation timestamp.
	EmailConfirmed bool `json:"emailConfirmed"`
	// EmailAuthFactor mirrors the account's email auth-factor setting.
	EmailAuthFactor bool `json:"emailAuthFactor"`

	// Active and Status are filled from the account's Active() and Status()
	// methods; Status is omitted from the JSON when nil.
	Active bool    `json:"active"`
	Status *string `json:"status,omitempty"`
}
35
36func (s *Server) handleCreateSession(e echo.Context) error {
37 ctx := e.Request().Context()
38 logger := s.logger.With("name", "handleServerCreateSession")
39
40 var req ComAtprotoServerCreateSessionRequest
41 if err := e.Bind(&req); err != nil {
42 logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
43 return helpers.ServerError(e, nil)
44 }
45
46 if err := e.Validate(req); err != nil {
47 var verr ValidationError
48 if errors.As(err, &verr) {
49 if verr.Field == "Identifier" {
50 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
51 }
52
53 if verr.Field == "Password" {
54 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
55 }
56 }
57 }
58
59 req.Identifier = strings.ToLower(req.Identifier)
60 var idtype string
61 if _, err := syntax.ParseDID(req.Identifier); err == nil {
62 idtype = "did"
63 } else if _, err := syntax.ParseHandle(req.Identifier); err == nil {
64 idtype = "handle"
65 } else {
66 idtype = "email"
67 }
68
69 var repo models.RepoActor
70 var err error
71 switch idtype {
72 case "did":
73 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
74 case "handle":
75 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
76 case "email":
77 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
78 }
79
80 if err != nil {
81 if err == gorm.ErrRecordNotFound {
82 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
83 }
84
85 logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
86 return helpers.ServerError(e, nil)
87 }
88
89 // if repo requires auth factor token and one hasn't been provided, return error prompting for one
90 if repo.EmailAuthFactor && (req.AuthFactorToken == nil || *req.AuthFactorToken == "") {
91 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
92 eat := time.Now().Add(10 * time.Minute).UTC()
93
94 if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, repo.Repo.Did).Error; err != nil {
95 s.logger.Error("error updating repo", "error", err)
96 return helpers.ServerError(e, nil)
97 }
98
99 if err := s.sendEmailUpdate(repo.Email, repo.Handle, code); err != nil {
100 s.logger.Error("error sending email", "error", err)
101 return helpers.ServerError(e, nil)
102 }
103
104 return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired"))
105 }
106
107 // if auth factor is required, now check that the one provided is valid
108 if repo.EmailAuthFactor {
109 if repo.EmailUpdateCode == nil || repo.EmailUpdateCodeExpiresAt == nil {
110 return helpers.InvalidTokenError(e)
111 }
112
113 if *repo.EmailUpdateCode != *req.AuthFactorToken {
114 return helpers.InvalidTokenError(e)
115 }
116
117 if time.Now().UTC().After(*repo.EmailUpdateCodeExpiresAt) {
118 return helpers.ExpiredTokenError(e)
119 }
120 }
121
122 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
123 if err != bcrypt.ErrMismatchedHashAndPassword {
124 logger.Error("erorr comparing hash and password", "error", err)
125 }
126 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
127 }
128
129 sess, err := s.createSession(ctx, &repo.Repo)
130 if err != nil {
131 logger.Error("error creating session", "error", err)
132 return helpers.ServerError(e, nil)
133 }
134
135 return e.JSON(200, ComAtprotoServerCreateSessionResponse{
136 AccessJwt: sess.AccessToken,
137 RefreshJwt: sess.RefreshToken,
138 Handle: repo.Handle,
139 Did: repo.Repo.Did,
140 Email: repo.Email,
141 EmailConfirmed: repo.EmailConfirmedAt != nil,
142 EmailAuthFactor: repo.EmailAuthFactor,
143 Active: repo.Active(),
144 Status: repo.Status(),
145 })
146}