···11package db
2233import (
44+ "context"
45 "sync"
5667 "gorm.io/gorm"
···1920 }
2021}
21222222-func (db *DB) Create(value any, clauses []clause.Expression) *gorm.DB {
2323+func (db *DB) Create(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB {
2324 db.mu.Lock()
2425 defer db.mu.Unlock()
2525- return db.cli.Clauses(clauses...).Create(value)
2626+ return db.cli.WithContext(ctx).Clauses(clauses...).Create(value)
2627}
27282828-func (db *DB) Save(value any, clauses []clause.Expression) *gorm.DB {
2929+func (db *DB) Save(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB {
2930 db.mu.Lock()
3031 defer db.mu.Unlock()
3131- return db.cli.Clauses(clauses...).Save(value)
3232+ return db.cli.WithContext(ctx).Clauses(clauses...).Save(value)
3233}
33343434-func (db *DB) Exec(sql string, clauses []clause.Expression, values ...any) *gorm.DB {
3535+func (db *DB) Exec(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB {
3536 db.mu.Lock()
3637 defer db.mu.Unlock()
3737- return db.cli.Clauses(clauses...).Exec(sql, values...)
3838+ return db.cli.WithContext(ctx).Clauses(clauses...).Exec(sql, values...)
3839}
39404040-func (db *DB) Raw(sql string, clauses []clause.Expression, values ...any) *gorm.DB {
4141- return db.cli.Clauses(clauses...).Raw(sql, values...)
4141+func (db *DB) Raw(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB {
4242+ return db.cli.WithContext(ctx).Clauses(clauses...).Raw(sql, values...)
4243}
43444445func (db *DB) AutoMigrate(models ...any) error {
4546 return db.cli.AutoMigrate(models...)
4647}
47484848-func (db *DB) Delete(value any, clauses []clause.Expression) *gorm.DB {
4949+func (db *DB) Delete(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB {
4950 db.mu.Lock()
5051 defer db.mu.Unlock()
5151- return db.cli.Clauses(clauses...).Delete(value)
5252+ return db.cli.WithContext(ctx).Clauses(clauses...).Delete(value)
5253}
53545454-func (db *DB) First(dest any, conds ...any) *gorm.DB {
5555- return db.cli.First(dest, conds...)
5555+func (db *DB) First(ctx context.Context, dest any, conds ...any) *gorm.DB {
5656+ return db.cli.WithContext(ctx).First(dest, conds...)
5657}
57585859// TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure
5960// out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad.
6061// e.g. when we do apply writes we should also be using a transcation but we don't right now
6161-func (db *DB) BeginDangerously() *gorm.DB {
6262- return db.cli.Begin()
6262+func (db *DB) BeginDangerously(ctx context.Context) *gorm.DB {
6363+ return db.cli.WithContext(ctx).Begin()
6364}
64656566func (db *DB) Lock() {
+10-8
server/common.go
···11package server
2233import (
44+ "context"
55+46 "github.com/haileyok/cocoon/models"
57)
6877-func (s *Server) getActorByHandle(handle string) (*models.Actor, error) {
99+func (s *Server) getActorByHandle(ctx context.Context, handle string) (*models.Actor, error) {
810 var actor models.Actor
99- if err := s.db.First(&actor, models.Actor{Handle: handle}).Error; err != nil {
1111+ if err := s.db.First(ctx, &actor, models.Actor{Handle: handle}).Error; err != nil {
1012 return nil, err
1113 }
1214 return &actor, nil
1315}
14161515-func (s *Server) getRepoByEmail(email string) (*models.Repo, error) {
1717+func (s *Server) getRepoByEmail(ctx context.Context, email string) (*models.Repo, error) {
1618 var repo models.Repo
1717- if err := s.db.First(&repo, models.Repo{Email: email}).Error; err != nil {
1919+ if err := s.db.First(ctx, &repo, models.Repo{Email: email}).Error; err != nil {
1820 return nil, err
1921 }
2022 return &repo, nil
2123}
22242323-func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) {
2525+func (s *Server) getRepoActorByEmail(ctx context.Context, email string) (*models.RepoActor, error) {
2426 var repo models.RepoActor
2525- if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil {
2727+ if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil {
2628 return nil, err
2729 }
2830 return &repo, nil
2931}
30323131-func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) {
3333+func (s *Server) getRepoActorByDid(ctx context.Context, did string) (*models.RepoActor, error) {
3234 var repo models.RepoActor
3333- if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil {
3535+ if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil {
3436 return nil, err
3537 }
3638 return &repo, nil
+2-1
server/handle_account.go
···12121313func (s *Server) handleAccount(e echo.Context) error {
1414 ctx := e.Request().Context()
1515+1516 repo, sess, err := s.getSessionRepoOrErr(e)
1617 if err != nil {
1718 return e.Redirect(303, "/account/signin")
···2021 oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime)
21222223 var tokens []provider.OauthToken
2323- if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
2424+ if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
2425 s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err)
2526 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error")
2627 sess.Save(e.Request(), e.Response())
+5-3
server/handle_account_revoke.go
···55 "github.com/labstack/echo/v4"
66)
7788-type AccountRevokeRequest struct {
88+type AccountRevokeInput struct {
99 Token string `form:"token"`
1010}
11111212func (s *Server) handleAccountRevoke(e echo.Context) error {
1313- var req AccountRevokeRequest
1313+ ctx := e.Request().Context()
1414+1515+ var req AccountRevokeInput
1416 if err := e.Bind(&req); err != nil {
1517 s.logger.Error("could not bind account revoke request", "error", err)
1618 return helpers.ServerError(e, nil)
···2123 return e.Redirect(303, "/account/signin")
2224 }
23252424- if err := s.db.Exec("DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil {
2626+ if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil {
2527 s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err)
2628 sess.AddFlash("Unable to revoke session. See server logs for more details.", "error")
2729 sess.Save(e.Request(), e.Response())
+10-6
server/handle_account_signin.go
···1414 "gorm.io/gorm"
1515)
16161717-type OauthSigninRequest struct {
1717+type OauthSigninInput struct {
1818 Username string `form:"username"`
1919 Password string `form:"password"`
2020 QueryParams string `form:"query_params"`
2121}
22222323func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
2424+ ctx := e.Request().Context()
2525+2426 sess, err := session.Get("session", e)
2527 if err != nil {
2628 return nil, nil, err
···3133 return nil, sess, errors.New("did was not set in session")
3234 }
33353434- repo, err := s.getRepoActorByDid(did)
3636+ repo, err := s.getRepoActorByDid(ctx, did)
3537 if err != nil {
3638 return nil, sess, err
3739 }
···6062}
61636264func (s *Server) handleAccountSigninPost(e echo.Context) error {
6363- var req OauthSigninRequest
6565+ ctx := e.Request().Context()
6666+6767+ var req OauthSigninInput
6468 if err := e.Bind(&req); err != nil {
6569 s.logger.Error("error binding sign in req", "error", err)
6670 return helpers.ServerError(e, nil)
···8387 var err error
8488 switch idtype {
8589 case "did":
8686- err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
9090+ err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
8791 case "handle":
8888- err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
9292+ err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
8993 case "email":
9090- err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
9494+ err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
9195 }
9296 if err != nil {
9397 if err == gorm.ErrRecordNotFound {
+3-1
server/handle_actor_put_preferences.go
···1010// This is kinda lame. Not great to implement app.bsky in the pds, but alas
11111212func (s *Server) handleActorPutPreferences(e echo.Context) error {
1313+ ctx := e.Request().Context()
1414+1315 repo := e.Get("repo").(*models.RepoActor)
14161517 var prefs map[string]any
···2224 return err
2325 }
24262525- if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
2727+ if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
2628 return err
2729 }
2830
···1313)
14141515func (s *Server) handleOauthAuthorizeGet(e echo.Context) error {
1616+ ctx := e.Request().Context()
1717+1618 reqUri := e.QueryParam("request_uri")
1719 if reqUri == "" {
1820 // render page for logged out dev
···3840 }
39414042 var req provider.OauthAuthorizationRequest
4141- if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil {
4343+ if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil {
4244 return helpers.ServerError(e, to.StringPtr(err.Error()))
4345 }
4446···7274}
73757476func (s *Server) handleOauthAuthorizePost(e echo.Context) error {
7777+ ctx := e.Request().Context()
7878+7579 repo, _, err := s.getSessionRepoOrErr(e)
7680 if err != nil {
7781 return e.Redirect(303, "/account/signin")
···8993 }
90949195 var authReq provider.OauthAuthorizationRequest
9292- if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil {
9696+ if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil {
9397 return helpers.ServerError(e, to.StringPtr(err.Error()))
9498 }
9599···113117114118 code := oauth.GenerateCode()
115119116116- if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
120120+ if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
117121 s.logger.Error("error updating authorization request", "error", err)
118122 return helpers.ServerError(e, nil)
119123 }
+3-1
server/handle_oauth_par.go
···1919}
20202121func (s *Server) handleOauthPar(e echo.Context) error {
2222+ ctx := e.Request().Context()
2323+2224 var parRequest provider.ParRequest
2325 if err := e.Bind(&parRequest); err != nil {
2426 s.logger.Error("error binding for par request", "error", err)
···8688 ExpiresAt: eat,
8789 }
88908989- if err := s.db.Create(authRequest, nil).Error; err != nil {
9191+ if err := s.db.Create(ctx, authRequest, nil).Error; err != nil {
9092 s.logger.Error("error creating auth request in db", "error", err)
9193 return helpers.ServerError(e, nil)
9294 }
+7-5
server/handle_oauth_token.go
···3838}
39394040func (s *Server) handleOauthToken(e echo.Context) error {
4141+ ctx := e.Request().Context()
4242+4143 var req OauthTokenRequest
4244 if err := e.Bind(&req); err != nil {
4345 s.logger.Error("error binding token request", "error", err)
···84868587 var authReq provider.OauthAuthorizationRequest
8688 // get the lil guy and delete him
8787- if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
8989+ if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
8890 s.logger.Error("error finding authorization request", "error", err)
8991 return helpers.ServerError(e, nil)
9092 }
···128130 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
129131 }
130132131131- repo, err := s.getRepoActorByDid(*authReq.Sub)
133133+ repo, err := s.getRepoActorByDid(ctx, *authReq.Sub)
132134 if err != nil {
133135 helpers.InputError(e, to.StringPtr("unable to find actor"))
134136 }
···159161 return err
160162 }
161163162162- if err := s.db.Create(&provider.OauthToken{
164164+ if err := s.db.Create(ctx, &provider.OauthToken{
163165 ClientId: authReq.ClientId,
164166 ClientAuth: *clientAuth,
165167 Parameters: authReq.Parameters,
···199201 }
200202201203 var oauthToken provider.OauthToken
202202- if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
204204+ if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
203205 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
204206 return helpers.ServerError(e, nil)
205207 }
···257259 return err
258260 }
259261260260- if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
262262+ if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
261263 s.logger.Error("error updating token", "error", err)
262264 return helpers.ServerError(e, nil)
263265 }
+4-2
server/handle_repo_describe_repo.go
···2020}
21212222func (s *Server) handleDescribeRepo(e echo.Context) error {
2323+ ctx := e.Request().Context()
2424+2325 did := e.QueryParam("repo")
2424- repo, err := s.getRepoActorByDid(did)
2626+ repo, err := s.getRepoActorByDid(ctx, did)
2527 if err != nil {
2628 if err == gorm.ErrRecordNotFound {
2729 return helpers.InputError(e, to.StringPtr("RepoNotFound"))
···6466 }
65676668 var records []models.Record
6767- if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
6969+ if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
6870 s.logger.Error("error getting collections", "error", err)
6971 return helpers.ServerError(e, nil)
7072 }
+3-1
server/handle_repo_get_record.go
···1414}
15151616func (s *Server) handleRepoGetRecord(e echo.Context) error {
1717+ ctx := e.Request().Context()
1818+1719 repo := e.QueryParam("repo")
1820 collection := e.QueryParam("collection")
1921 rkey := e.QueryParam("rkey")
···3234 }
33353436 var record models.Record
3535- if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
3737+ if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
3638 // TODO: handle error nicely
3739 return err
3840 }
+4-2
server/handle_repo_list_missing_blobs.go
···2222}
23232424func (s *Server) handleListMissingBlobs(e echo.Context) error {
2525+ ctx := e.Request().Context()
2626+2527 urepo := e.Get("repo").(*models.RepoActor)
26282729 limitStr := e.QueryParam("limit")
···3537 }
36383739 var records []models.Record
3838- if err := s.db.Raw("SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil {
4040+ if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil {
3941 s.logger.Error("failed to get records for listMissingBlobs", "error", err)
4042 return helpers.ServerError(e, nil)
4143 }
···6971 }
70727173 var count int64
7272- if err := s.db.Raw("SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil {
7474+ if err := s.db.Raw(ctx, "SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil {
7375 continue
7476 }
7577
+4-2
server/handle_repo_list_records.go
···4646}
47474848func (s *Server) handleListRecords(e echo.Context) error {
4949+ ctx := e.Request().Context()
5050+4951 var req ComAtprotoRepoListRecordsRequest
5052 if err := e.Bind(&req); err != nil {
5153 s.logger.Error("could not bind list records request", "error", err)
···78807981 did := req.Repo
8082 if _, err := syntax.ParseDID(did); err != nil {
8181- actor, err := s.getActorByHandle(req.Repo)
8383+ actor, err := s.getActorByHandle(ctx, req.Repo)
8284 if err != nil {
8385 return helpers.InputError(e, to.StringPtr("RepoNotFound"))
8486 }
···9395 params = append(params, limit)
94969597 var records []models.Record
9696- if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
9898+ if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
9799 s.logger.Error("error getting records", "error", err)
98100 return helpers.ServerError(e, nil)
99101 }
+3-1
server/handle_repo_list_repos.go
···21212222// TODO: paginate this bitch
2323func (s *Server) handleListRepos(e echo.Context) error {
2424+ ctx := e.Request().Context()
2525+2426 var repos []models.Repo
2525- if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
2727+ if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
2628 return err
2729 }
2830
+5-3
server/handle_repo_upload_blob.go
···3232}
33333434func (s *Server) handleRepoUploadBlob(e echo.Context) error {
3535+ ctx := e.Request().Context()
3636+3537 urepo := e.Get("repo").(*models.RepoActor)
36383739 mime := e.Request().Header.Get("content-type")
···5153 Storage: storage,
5254 }
53555454- if err := s.db.Create(&blob, nil).Error; err != nil {
5656+ if err := s.db.Create(ctx, &blob, nil).Error; err != nil {
5557 s.logger.Error("error creating new blob in db", "error", err)
5658 return helpers.ServerError(e, nil)
5759 }
···8486 Data: data,
8587 }
86888787- if err := s.db.Create(&blobPart, nil).Error; err != nil {
8989+ if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil {
8890 s.logger.Error("error adding blob part to db", "error", err)
8991 return helpers.ServerError(e, nil)
9092 }
···131133 }
132134 }
133135134134- if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil {
136136+ if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil {
135137 // there should probably be somme handling here if this fails...
136138 s.logger.Error("error updating blob", "error", err)
137139 return helpers.ServerError(e, nil)
+3-1
server/handle_server_activate_account.go
···1818}
19192020func (s *Server) handleServerActivateAccount(e echo.Context) error {
2121+ ctx := e.Request().Context()
2222+2123 var req ComAtprotoServerDeactivateAccountRequest
2224 if err := e.Bind(&req); err != nil {
2325 s.logger.Error("error binding", "error", err)
···26282729 urepo := e.Get("repo").(*models.RepoActor)
28302929- if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil {
3131+ if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil {
3032 s.logger.Error("error updating account status to deactivated", "error", err)
3133 return helpers.ServerError(e, nil)
3234 }
+5-3
server/handle_server_check_account_status.go
···2020}
21212222func (s *Server) handleServerCheckAccountStatus(e echo.Context) error {
2323+ ctx := e.Request().Context()
2424+2325 urepo := e.Get("repo").(*models.RepoActor)
24262527 resp := ComAtprotoServerCheckAccountStatusResponse{
···4143 }
42444345 var blockCtResp CountResp
4444- if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
4646+ if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
4547 s.logger.Error("error getting block count", "error", err)
4648 return helpers.ServerError(e, nil)
4749 }
4850 resp.RepoBlocks = blockCtResp.Ct
49515052 var recCtResp CountResp
5151- if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
5353+ if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
5254 s.logger.Error("error getting record count", "error", err)
5355 return helpers.ServerError(e, nil)
5456 }
5557 resp.IndexedRecords = recCtResp.Ct
56585759 var blobCtResp CountResp
5858- if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
6060+ if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
5961 s.logger.Error("error getting record count", "error", err)
6062 return helpers.ServerError(e, nil)
6163 }
+3-1
server/handle_server_confirm_email.go
···1515}
16161717func (s *Server) handleServerConfirmEmail(e echo.Context) error {
1818+ ctx := e.Request().Context()
1919+1820 urepo := e.Get("repo").(*models.RepoActor)
19212022 var req ComAtprotoServerConfirmEmailRequest
···41434244 now := time.Now().UTC()
43454444- if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil {
4646+ if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil {
4547 s.logger.Error("error updating user", "error", err)
4648 return helpers.ServerError(e, nil)
4749 }
+16-14
server/handle_server_create_account.go
···3636}
37373838func (s *Server) handleCreateAccount(e echo.Context) error {
3939+ ctx := e.Request().Context()
4040+3941 var request ComAtprotoServerCreateAccountRequest
40424143 if err := e.Bind(&request); err != nil {
···6870 }
6971 }
7072 }
7171-7373+7274 var signupDid string
7375 if request.Did != nil {
7474- signupDid = *request.Did;
7575-7676+ signupDid = *request.Did
7777+7678 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1))
7779 if token == "" {
7880 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did"))
···9092 }
91939294 // see if the handle is already taken
9393- actor, err := s.getActorByHandle(request.Handle)
9595+ actor, err := s.getActorByHandle(ctx, request.Handle)
9496 if err != nil && err != gorm.ErrRecordNotFound {
9597 s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err)
9698 return helpers.ServerError(e, nil)
···109111 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
110112 }
111113112112- if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
114114+ if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
113115 if err == gorm.ErrRecordNotFound {
114116 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
115117 }
···123125 }
124126125127 // see if the email is already taken
126126- existingRepo, err := s.getRepoByEmail(request.Email)
128128+ existingRepo, err := s.getRepoByEmail(ctx, request.Email)
127129 if err != nil && err != gorm.ErrRecordNotFound {
128130 s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err)
129131 return helpers.ServerError(e, nil)
···137139 var k *atcrypto.PrivateKeyK256
138140139141 if signupDid != "" {
140140- reservedKey, err := s.getReservedKey(signupDid)
142142+ reservedKey, err := s.getReservedKey(ctx, signupDid)
141143 if err != nil {
142144 s.logger.Error("error looking up reserved key", "error", err)
143145 }
···148150 k = nil
149151 } else {
150152 defer func() {
151151- if delErr := s.deleteReservedKey(reservedKey.KeyDid, reservedKey.Did); delErr != nil {
153153+ if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil {
152154 s.logger.Error("error deleting reserved key", "error", delErr)
153155 }
154156 }()
···199201 Handle: request.Handle,
200202 }
201203202202- if err := s.db.Create(&urepo, nil).Error; err != nil {
204204+ if err := s.db.Create(ctx, &urepo, nil).Error; err != nil {
203205 s.logger.Error("error inserting new repo", "error", err)
204206 return helpers.ServerError(e, nil)
205207 }
206206-207207- if err := s.db.Create(&actor, nil).Error; err != nil {
208208+209209+ if err := s.db.Create(ctx, &actor, nil).Error; err != nil {
208210 s.logger.Error("error inserting new actor", "error", err)
209211 return helpers.ServerError(e, nil)
210212 }
211213 } else {
212212- if err := s.db.Save(&actor, nil).Error; err != nil {
214214+ if err := s.db.Save(ctx, &actor, nil).Error; err != nil {
213215 s.logger.Error("error inserting new actor", "error", err)
214216 return helpers.ServerError(e, nil)
215217 }
···241243 }
242244243245 if s.config.RequireInvite {
244244- if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
246246+ if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
245247 s.logger.Error("error decrementing use count", "error", err)
246248 return helpers.ServerError(e, nil)
247249 }
248250 }
249251250250- sess, err := s.createSession(&urepo)
252252+ sess, err := s.createSession(ctx, &urepo)
251253 if err != nil {
252254 s.logger.Error("error creating new session", "error", err)
253255 return helpers.ServerError(e, nil)
···37373838func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
3939 return func(e echo.Context) error {
4040+ ctx := e.Request().Context()
4141+4042 authheader := e.Request().Header.Get("authorization")
4143 if authheader == "" {
4244 return e.JSON(401, map[string]string{"error": "Unauthorized"})
···7880 }
7981 did = maybeDid
80828181- maybeRepo, err := s.getRepoActorByDid(did)
8383+ maybeRepo, err := s.getRepoActorByDid(ctx, did)
8284 if err != nil {
8385 s.logger.Error("error fetching repo", "error", err)
8486 return helpers.ServerError(e, nil)
···159161 Found bool
160162 }
161163 var result Result
162162- if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
164164+ if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
163165 if err == gorm.ErrRecordNotFound {
164166 return helpers.InvalidTokenError(e)
165167 }
···184186 }
185187186188 if repo == nil {
187187- maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
189189+ maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string))
188190 if err != nil {
189191 s.logger.Error("error fetching repo", "error", err)
190192 return helpers.ServerError(e, nil)
···207209208210func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
209211 return func(e echo.Context) error {
212212+ ctx := e.Request().Context()
213213+210214 authheader := e.Request().Header.Get("authorization")
211215 if authheader == "" {
212216 return e.JSON(401, map[string]string{"error": "Unauthorized"})
···243247 }
244248245249 var oauthToken provider.OauthToken
246246- if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
250250+ if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
247251 s.logger.Error("error finding access token in db", "error", err)
248252 return helpers.InputError(e, nil)
249253 }
···266270 })
267271 }
268272269269- repo, err := s.getRepoActorByDid(oauthToken.Sub)
273273+ repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub)
270274 if err != nil {
271275 s.logger.Error("could not find actor in db", "error", err)
272276 return helpers.ServerError(e, nil)
+11-11
server/repo.go
···181181 case OpTypeDelete:
182182 // try to find the old record in the database
183183 var old models.Record
184184- if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
184184+ if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
185185 return nil, err
186186 }
187187···323323 var cids []cid.Cid
324324 // whenever there is cid present, we know it's a create (dumb)
325325 if entry.Cid != "" {
326326- if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{
326326+ if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{
327327 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
328328 UpdateAll: true,
329329 }}).Error; err != nil {
···331331 }
332332333333 // increment the given blob refs, yay
334334- cids, err = rm.incrementBlobRefs(urepo, entry.Value)
334334+ cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value)
335335 if err != nil {
336336 return nil, err
337337 }
···339339 // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey
340340 // is did + collection + rkey. i still really want to separate that out, or use a different type to make
341341 // this less confusing/easy to read. alas, its 2 am and yea no
342342- if err := rm.s.db.Delete(&entry, nil).Error; err != nil {
342342+ if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil {
343343 return nil, err
344344 }
345345346346 // TODO:
347347- cids, err = rm.decrementBlobRefs(urepo, entry.Value)
347347+ cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value)
348348 if err != nil {
349349 return nil, err
350350 }
···411411 return c, bs.GetReadLog(), nil
412412}
413413414414-func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
414414+func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
415415 cids, err := getBlobCidsFromCbor(cbor)
416416 if err != nil {
417417 return nil, err
418418 }
419419420420 for _, c := range cids {
421421- if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
421421+ if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
422422 return nil, err
423423 }
424424 }
···426426 return cids, nil
427427}
428428429429-func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
429429+func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
430430 cids, err := getBlobCidsFromCbor(cbor)
431431 if err != nil {
432432 return nil, err
···437437 ID uint
438438 Count int
439439 }
440440- if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
440440+ if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
441441 return nil, err
442442 }
443443444444 // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what
445445 // storage it is in, and clean up s3!!!!
446446 if res.Count == 0 {
447447- if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
447447+ if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
448448 return nil, err
449449 }
450450- if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
450450+ if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
451451 return nil, err
452452 }
453453 }
+1-1
server/server.go
···729729}
730730731731func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
732732- if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
732732+ if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
733733 return err
734734 }
735735