1package server
2
3import (
4 "errors"
5 "strings"
6
7 "github.com/bluesky-social/indigo/atproto/syntax"
8 "github.com/gorilla/sessions"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/models"
11 "github.com/labstack/echo-contrib/session"
12 "github.com/labstack/echo/v4"
13 "golang.org/x/crypto/bcrypt"
14 "gorm.io/gorm"
15)
16
// OauthSigninInput is the form payload bound by handleAccountSigninPost.
type OauthSigninInput struct {
	// Username is the sign-in identifier; it is lower-cased and then treated
	// as a DID, a handle, or (failing both) an email address.
	Username string `form:"username"`
	// Password is the plaintext password, checked against the repo's bcrypt hash.
	Password string `form:"password"`
	// QueryParams, when non-empty, is appended to /oauth/authorize for the
	// post-sign-in redirect. Presumably it echoes the query string the sign-in
	// page was rendered with (see handleAccountSigninGet) — confirm in the template.
	QueryParams string `form:"query_params"`
}
22
23func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
24 ctx := e.Request().Context()
25
26 sess, err := session.Get("session", e)
27 if err != nil {
28 return nil, nil, err
29 }
30
31 did, ok := sess.Values["did"].(string)
32 if !ok {
33 return nil, sess, errors.New("did was not set in session")
34 }
35
36 repo, err := s.getRepoActorByDid(ctx, did)
37 if err != nil {
38 return nil, sess, err
39 }
40
41 return repo, sess, nil
42}
43
44func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any {
45 defer sess.Save(e.Request(), e.Response())
46 return map[string]any{
47 "errors": sess.Flashes("error"),
48 "successes": sess.Flashes("success"),
49 }
50}
51
52func (s *Server) handleAccountSigninGet(e echo.Context) error {
53 _, sess, err := s.getSessionRepoOrErr(e)
54 if err == nil {
55 return e.Redirect(303, "/account")
56 }
57
58 return e.Render(200, "signin.html", map[string]any{
59 "flashes": getFlashesFromSession(e, sess),
60 "QueryParams": e.QueryParams().Encode(),
61 })
62}
63
64func (s *Server) handleAccountSigninPost(e echo.Context) error {
65 ctx := e.Request().Context()
66
67 var req OauthSigninInput
68 if err := e.Bind(&req); err != nil {
69 s.logger.Error("error binding sign in req", "error", err)
70 return helpers.ServerError(e, nil)
71 }
72
73 sess, _ := session.Get("session", e)
74
75 req.Username = strings.ToLower(req.Username)
76 var idtype string
77 if _, err := syntax.ParseDID(req.Username); err == nil {
78 idtype = "did"
79 } else if _, err := syntax.ParseHandle(req.Username); err == nil {
80 idtype = "handle"
81 } else {
82 idtype = "email"
83 }
84
85 // TODO: we should make this a helper since we do it for the base create_session as well
86 var repo models.RepoActor
87 var err error
88 switch idtype {
89 case "did":
90 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
91 case "handle":
92 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
93 case "email":
94 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
95 }
96 if err != nil {
97 if err == gorm.ErrRecordNotFound {
98 sess.AddFlash("Handle or password is incorrect", "error")
99 } else {
100 sess.AddFlash("Something went wrong!", "error")
101 }
102 sess.Save(e.Request(), e.Response())
103 return e.Redirect(303, "/account/signin")
104 }
105
106 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
107 if err != bcrypt.ErrMismatchedHashAndPassword {
108 sess.AddFlash("Handle or password is incorrect", "error")
109 } else {
110 sess.AddFlash("Something went wrong!", "error")
111 }
112 sess.Save(e.Request(), e.Response())
113 return e.Redirect(303, "/account/signin")
114 }
115
116 sess.Options = &sessions.Options{
117 Path: "/",
118 MaxAge: int(AccountSessionMaxAge.Seconds()),
119 HttpOnly: true,
120 }
121
122 sess.Values = map[any]any{}
123 sess.Values["did"] = repo.Repo.Did
124
125 if err := sess.Save(e.Request(), e.Response()); err != nil {
126 return err
127 }
128
129 if req.QueryParams != "" {
130 return e.Redirect(303, "/oauth/authorize?"+req.QueryParams)
131 } else {
132 return e.Redirect(303, "/account")
133 }
134}