1package server
2
3import (
4 "errors"
5 "strings"
6
7 "github.com/Azure/go-autorest/autorest/to"
8 "github.com/bluesky-social/indigo/atproto/syntax"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/models"
11 "github.com/labstack/echo/v4"
12 "golang.org/x/crypto/bcrypt"
13 "gorm.io/gorm"
14)
15
// ComAtprotoServerCreateSessionRequest is the JSON body bound by
// handleCreateSession. Identifier and Password are both required by
// validation. handleCreateSession lowercases Identifier and interprets it as a
// DID, otherwise a handle, otherwise an email address. AuthFactorToken is
// accepted in the body but is not read by handleCreateSession.
type ComAtprotoServerCreateSessionRequest struct {
	Identifier      string  `json:"identifier" validate:"required"`
	Password        string  `json:"password" validate:"required"`
	AuthFactorToken *string `json:"authFactorToken,omitempty"`
}
21
// ComAtprotoServerCreateSessionResponse is the JSON body handleCreateSession
// returns after a successful login. AccessJwt and RefreshJwt carry the tokens
// of the newly created session. EmailConfirmed reports whether the account's
// email has a confirmation timestamp. EmailAuthFactor is currently always
// false. Status is omitted from the JSON output when nil.
type ComAtprotoServerCreateSessionResponse struct {
	AccessJwt       string  `json:"accessJwt"`
	RefreshJwt      string  `json:"refreshJwt"`
	Handle          string  `json:"handle"`
	Did             string  `json:"did"`
	Email           string  `json:"email"`
	EmailConfirmed  bool    `json:"emailConfirmed"`
	EmailAuthFactor bool    `json:"emailAuthFactor"`
	Active          bool    `json:"active"`
	Status          *string `json:"status,omitempty"`
}
33
34func (s *Server) handleCreateSession(e echo.Context) error {
35 ctx := e.Request().Context()
36
37 var req ComAtprotoServerCreateSessionRequest
38 if err := e.Bind(&req); err != nil {
39 s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
40 return helpers.ServerError(e, nil)
41 }
42
43 if err := e.Validate(req); err != nil {
44 var verr ValidationError
45 if errors.As(err, &verr) {
46 if verr.Field == "Identifier" {
47 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
48 }
49
50 if verr.Field == "Password" {
51 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
52 }
53 }
54 }
55
56 req.Identifier = strings.ToLower(req.Identifier)
57 var idtype string
58 if _, err := syntax.ParseDID(req.Identifier); err == nil {
59 idtype = "did"
60 } else if _, err := syntax.ParseHandle(req.Identifier); err == nil {
61 idtype = "handle"
62 } else {
63 idtype = "email"
64 }
65
66 var repo models.RepoActor
67 var err error
68 switch idtype {
69 case "did":
70 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
71 case "handle":
72 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
73 case "email":
74 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
75 }
76
77 if err != nil {
78 if err == gorm.ErrRecordNotFound {
79 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
80 }
81
82 s.logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
83 return helpers.ServerError(e, nil)
84 }
85
86 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
87 if err != bcrypt.ErrMismatchedHashAndPassword {
88 s.logger.Error("erorr comparing hash and password", "error", err)
89 }
90 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
91 }
92
93 sess, err := s.createSession(ctx, &repo.Repo)
94 if err != nil {
95 s.logger.Error("error creating session", "error", err)
96 return helpers.ServerError(e, nil)
97 }
98
99 return e.JSON(200, ComAtprotoServerCreateSessionResponse{
100 AccessJwt: sess.AccessToken,
101 RefreshJwt: sess.RefreshToken,
102 Handle: repo.Handle,
103 Did: repo.Repo.Did,
104 Email: repo.Email,
105 EmailConfirmed: repo.EmailConfirmedAt != nil,
106 EmailAuthFactor: false,
107 Active: repo.Active(),
108 Status: repo.Status(),
109 })
110}