1package server
2
3import (
4 "errors"
5 "strings"
6
7 "github.com/bluesky-social/indigo/atproto/syntax"
8 "github.com/gorilla/sessions"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/models"
11 "github.com/labstack/echo-contrib/session"
12 "github.com/labstack/echo/v4"
13 "golang.org/x/crypto/bcrypt"
14 "gorm.io/gorm"
15)
16
// OauthSigninInput is the form payload bound by handleAccountSigninPost when
// a user submits the account sign-in form.
type OauthSigninInput struct {
	// Username identifies the account; it may be a DID, a handle, or an
	// email address.
	Username string `form:"username"`

	// Password is the plaintext password compared against the stored bcrypt
	// hash.
	Password string `form:"password"`

	// QueryParams is an already-encoded query string; when non-empty, a
	// successful sign-in redirects to /oauth/authorize with it appended.
	QueryParams string `form:"query_params"`
}
22
23func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
24 ctx := e.Request().Context()
25
26 sess, err := session.Get("session", e)
27 if err != nil {
28 return nil, nil, err
29 }
30
31 did, ok := sess.Values["did"].(string)
32 if !ok {
33 return nil, sess, errors.New("did was not set in session")
34 }
35
36 repo, err := s.getRepoActorByDid(ctx, did)
37 if err != nil {
38 return nil, sess, err
39 }
40
41 return repo, sess, nil
42}
43
44func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any {
45 defer sess.Save(e.Request(), e.Response())
46 return map[string]any{
47 "errors": sess.Flashes("error"),
48 "successes": sess.Flashes("success"),
49 }
50}
51
52func (s *Server) handleAccountSigninGet(e echo.Context) error {
53 _, sess, err := s.getSessionRepoOrErr(e)
54 if err == nil {
55 return e.Redirect(303, "/account")
56 }
57
58 return e.Render(200, "signin.html", map[string]any{
59 "flashes": getFlashesFromSession(e, sess),
60 "QueryParams": e.QueryParams().Encode(),
61 })
62}
63
64func (s *Server) handleAccountSigninPost(e echo.Context) error {
65 ctx := e.Request().Context()
66 logger := s.logger.With("name", "handleAccountSigninPost")
67
68 var req OauthSigninInput
69 if err := e.Bind(&req); err != nil {
70 logger.Error("error binding sign in req", "error", err)
71 return helpers.ServerError(e, nil)
72 }
73
74 sess, _ := session.Get("session", e)
75
76 req.Username = strings.ToLower(req.Username)
77 var idtype string
78 if _, err := syntax.ParseDID(req.Username); err == nil {
79 idtype = "did"
80 } else if _, err := syntax.ParseHandle(req.Username); err == nil {
81 idtype = "handle"
82 } else {
83 idtype = "email"
84 }
85
86 // TODO: we should make this a helper since we do it for the base create_session as well
87 var repo models.RepoActor
88 var err error
89 switch idtype {
90 case "did":
91 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
92 case "handle":
93 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
94 case "email":
95 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
96 }
97 if err != nil {
98 if err == gorm.ErrRecordNotFound {
99 sess.AddFlash("Handle or password is incorrect", "error")
100 } else {
101 sess.AddFlash("Something went wrong!", "error")
102 }
103 sess.Save(e.Request(), e.Response())
104 return e.Redirect(303, "/account/signin")
105 }
106
107 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
108 if err != bcrypt.ErrMismatchedHashAndPassword {
109 sess.AddFlash("Handle or password is incorrect", "error")
110 } else {
111 sess.AddFlash("Something went wrong!", "error")
112 }
113 sess.Save(e.Request(), e.Response())
114 return e.Redirect(303, "/account/signin")
115 }
116
117 sess.Options = &sessions.Options{
118 Path: "/",
119 MaxAge: int(AccountSessionMaxAge.Seconds()),
120 HttpOnly: true,
121 }
122
123 sess.Values = map[any]any{}
124 sess.Values["did"] = repo.Repo.Did
125
126 if err := sess.Save(e.Request(), e.Response()); err != nil {
127 return err
128 }
129
130 if req.QueryParams != "" {
131 return e.Redirect(303, "/oauth/authorize?"+req.QueryParams)
132 } else {
133 return e.Redirect(303, "/account")
134 }
135}