···11+package db
22+33+import (
44+ "sync"
55+66+ "gorm.io/gorm"
77+ "gorm.io/gorm/clause"
88+)
99+1010+type DB struct {
1111+ cli *gorm.DB
1212+ mu sync.Mutex
1313+}
1414+1515+func NewDB(cli *gorm.DB) *DB {
1616+ return &DB{
1717+ cli: cli,
1818+ mu: sync.Mutex{},
1919+ }
2020+}
2121+2222+func (db *DB) Create(value any, clauses []clause.Expression) *gorm.DB {
2323+ db.mu.Lock()
2424+ defer db.mu.Unlock()
2525+ return db.cli.Clauses(clauses...).Create(value)
2626+}
2727+2828+func (db *DB) Exec(sql string, clauses []clause.Expression, values ...any) *gorm.DB {
2929+ db.mu.Lock()
3030+ defer db.mu.Unlock()
3131+ return db.cli.Clauses(clauses...).Exec(sql, values...)
3232+}
3333+3434+func (db *DB) Raw(sql string, clauses []clause.Expression, values ...any) *gorm.DB {
3535+ return db.cli.Clauses(clauses...).Raw(sql, values...)
3636+}
3737+3838+func (db *DB) AutoMigrate(models ...any) error {
3939+ return db.cli.AutoMigrate(models...)
4040+}
4141+4242+func (db *DB) Delete(value any, clauses []clause.Expression) *gorm.DB {
4343+ db.mu.Lock()
4444+ defer db.mu.Unlock()
4545+ return db.cli.Clauses(clauses...).Delete(value)
4646+}
4747+4848+func (db *DB) First(dest any, conds ...any) *gorm.DB {
4949+ return db.cli.First(dest, conds...)
5050+}
5151+5252+// TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure
5353+// out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad.
5454+// e.g. when we do apply writes we should also be using a transcation but we don't right now
5555+func (db *DB) BeginDangerously() *gorm.DB {
5656+ return db.cli.Begin()
5757+}
5858+5959+func (db *DB) Lock() {
6060+ db.mu.Lock()
6161+}
6262+6363+func (db *DB) Unlock() {
6464+ db.mu.Unlock()
6565+}
+2-2
server/common.go
···22222323func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) {
2424 var repo models.RepoActor
2525- if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", email).Scan(&repo).Error; err != nil {
2525+ if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil {
2626 return nil, err
2727 }
2828 return &repo, nil
···30303131func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) {
3232 var repo models.RepoActor
3333- if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", did).Scan(&repo).Error; err != nil {
3333+ if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil {
3434 return nil, err
3535 }
3636 return &repo, nil
+1-1
server/handle_actor_put_preferences.go
···2222 return err
2323 }
24242525- if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", b, repo.Repo.Did).Error; err != nil {
2525+ if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
2626 return err
2727 }
2828
+1-1
server/handle_identity_update_handle.go
···103103 },
104104 })
105105106106- if err := s.db.Exec("UPDATE actors SET handle = ? WHERE did = ?", req.Handle, repo.Repo.Did).Error; err != nil {
106106+ if err := s.db.Exec("UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil {
107107 s.logger.Error("error updating handle in db", "error", err)
108108 return helpers.ServerError(e, nil)
109109 }
···6464 }
65656666 var records []models.Record
6767- if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", repo.Repo.Did).Scan(&records).Error; err != nil {
6767+ if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
6868 s.logger.Error("error getting collections", "error", err)
6969 return helpers.ServerError(e, nil)
7070 }
+1-1
server/handle_repo_get_record.go
···3232 }
33333434 var record models.Record
3535- if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, params...).Scan(&record).Error; err != nil {
3535+ if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
3636 // TODO: handle error nicely
3737 return err
3838 }
+1-1
server/handle_repo_list_records.go
···6464 params = append(params, limit)
65656666 var records []models.Record
6767- if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", params...).Scan(&records).Error; err != nil {
6767+ if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
6868 s.logger.Error("error getting records", "error", err)
6969 return helpers.ServerError(e, nil)
7070 }
+1-1
server/handle_repo_list_repos.go
···2222// TODO: paginate this bitch
2323func (s *Server) handleListRepos(e echo.Context) error {
2424 var repos []models.Repo
2525- if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500").Scan(&repos).Error; err != nil {
2525+ if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
2626 return err
2727 }
2828
+3-3
server/handle_repo_upload_blob.go
···4040 CreatedAt: s.repoman.clock.Next().String(),
4141 }
42424343- if err := s.db.Create(&blob).Error; err != nil {
4343+ if err := s.db.Create(&blob, nil).Error; err != nil {
4444 s.logger.Error("error creating new blob in db", "error", err)
4545 return helpers.ServerError(e, nil)
4646 }
···7272 Data: data,
7373 }
74747575- if err := s.db.Create(&blobPart).Error; err != nil {
7575+ if err := s.db.Create(&blobPart, nil).Error; err != nil {
7676 s.logger.Error("error adding blob part to db", "error", err)
7777 return helpers.ServerError(e, nil)
7878 }
···8989 return helpers.ServerError(e, nil)
9090 }
91919292- if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", c.Bytes(), blob.ID).Error; err != nil {
9292+ if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil {
9393 // there should probably be somme handling here if this fails...
9494 s.logger.Error("error updating blob", "error", err)
9595 return helpers.ServerError(e, nil)
+1-1
server/handle_server_confirm_email.go
···41414242 now := time.Now().UTC()
43434444- if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", now, urepo.Repo.Did).Error; err != nil {
4444+ if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil {
4545 s.logger.Error("error updating user", "error", err)
4646 return helpers.ServerError(e, nil)
4747 }
+4-4
server/handle_server_create_account.go
···102102 }
103103104104 var ic models.InviteCode
105105- if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", request.InviteCode).Scan(&ic).Error; err != nil {
105105+ if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
106106 if err == gorm.ErrRecordNotFound {
107107 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
108108 }
···166166 Handle: request.Handle,
167167 }
168168169169- if err := s.db.Create(&urepo).Error; err != nil {
169169+ if err := s.db.Create(&urepo, nil).Error; err != nil {
170170 s.logger.Error("error inserting new repo", "error", err)
171171 return helpers.ServerError(e, nil)
172172 }
173173174174- if err := s.db.Create(&actor).Error; err != nil {
174174+ if err := s.db.Create(&actor, nil).Error; err != nil {
175175 s.logger.Error("error inserting new actor", "error", err)
176176 return helpers.ServerError(e, nil)
177177 }
···210210 })
211211 }
212212213213- if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", request.InviteCode).Scan(&ic).Error; err != nil {
213213+ if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
214214 s.logger.Error("error decrementing use count", "error", err)
215215 return helpers.ServerError(e, nil)
216216 }
···6565 var err error
6666 switch idtype {
6767 case "did":
6868- err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", req.Identifier).Scan(&repo).Error
6868+ err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
6969 case "handle":
7070- err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", req.Identifier).Scan(&repo).Error
7070+ err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
7171 case "email":
7272- err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", req.Identifier).Scan(&repo).Error
7272+ err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
7373 }
74747575 if err != nil {
+2-2
server/handle_server_delete_session.go
···1010 token := e.Get("token").(string)
11111212 var acctok models.Token
1313- if err := s.db.Raw("DELETE FROM tokens WHERE token = ? RETURNING *", token).Scan(&acctok).Error; err != nil {
1313+ if err := s.db.Raw("DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil {
1414 s.logger.Error("error deleting access token from db", "error", err)
1515 return helpers.ServerError(e, nil)
1616 }
17171818- if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", acctok.RefreshToken).Error; err != nil {
1818+ if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil {
1919 s.logger.Error("error deleting refresh token from db", "error", err)
2020 return helpers.ServerError(e, nil)
2121 }
+2-2
server/handle_server_refresh_session.go
···1919 token := e.Get("token").(string)
2020 repo := e.Get("repo").(*models.RepoActor)
21212222- if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", token).Error; err != nil {
2222+ if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil {
2323 s.logger.Error("error getting refresh token from db", "error", err)
2424 return helpers.ServerError(e, nil)
2525 }
26262727- if err := s.db.Exec("DELETE FROM tokens WHERE refresh_token = ?", token).Error; err != nil {
2727+ if err := s.db.Exec("DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil {
2828 s.logger.Error("error deleting access token from db", "error", err)
2929 return helpers.ServerError(e, nil)
3030 }
···2020 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
2121 eat := time.Now().Add(10 * time.Minute).UTC()
22222323- if err := s.db.Exec("UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", code, eat, urepo.Repo.Did).Error; err != nil {
2323+ if err := s.db.Exec("UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
2424 s.logger.Error("error updating user", "error", err)
2525 return helpers.ServerError(e, nil)
2626 }
+1-1
server/handle_server_request_email_update.go
···2020 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
2121 eat := time.Now().Add(10 * time.Minute).UTC()
22222323- if err := s.db.Exec("UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", code, eat, urepo.Repo.Did).Error; err != nil {
2323+ if err := s.db.Exec("UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
2424 s.logger.Error("error updating repo", "error", err)
2525 return helpers.ServerError(e, nil)
2626 }
+1-1
server/handle_server_request_password_reset.go
···3636 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
3737 eat := time.Now().Add(10 * time.Minute).UTC()
38383939- if err := s.db.Exec("UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", code, eat, urepo.Repo.Did).Error; err != nil {
3939+ if err := s.db.Exec("UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
4040 s.logger.Error("error updating repo", "error", err)
4141 return helpers.ServerError(e, nil)
4242 }
+1-1
server/handle_server_reset_password.go
···4646 return helpers.ServerError(e, nil)
4747 }
48484949- if err := s.db.Exec("UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", hash, urepo.Repo.Did).Error; err != nil {
4949+ if err := s.db.Exec("UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil {
5050 s.logger.Error("error updating repo", "error", err)
5151 return helpers.ServerError(e, nil)
5252 }
+1-1
server/handle_server_update_email.go
···4040 return helpers.InputError(e, to.StringPtr("ExpiredToken"))
4141 }
42424343- if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", req.Email, urepo.Repo.Did).Error; err != nil {
4343+ if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil {
4444 s.logger.Error("error updating repo", "error", err)
4545 return helpers.ServerError(e, nil)
4646 }
+3-3
server/handle_sync_get_blob.go
···2626 }
27272828 var blob models.Blob
2929- if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", did, c.Bytes()).Scan(&blob).Error; err != nil {
2929+ if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil {
3030 s.logger.Error("error looking up blob", "error", err)
3131 return helpers.ServerError(e, nil)
3232 }
···3434 buf := new(bytes.Buffer)
35353636 var parts []models.BlobPart
3737- if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", blob.ID).Scan(&parts).Error; err != nil {
3737+ if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil {
3838 s.logger.Error("error getting blob parts", "error", err)
3939 return helpers.ServerError(e, nil)
4040 }
···4444 buf.Write(p.Data)
4545 }
46464747- e.Response().Header().Set(echo.HeaderContentDisposition, "attachment; filename=" + c.String())
4747+ e.Response().Header().Set(echo.HeaderContentDisposition, "attachment; filename="+c.String())
48484949 return e.Stream(200, "application/octet-stream", buf)
5050}
+1-1
server/handle_sync_get_record.go
···1818 rkey := e.QueryParam("rkey")
19192020 var urepo models.Repo
2121- if err := s.db.Raw("SELECT * FROM repos WHERE did = ?", did).Scan(&urepo).Error; err != nil {
2121+ if err := s.db.Raw("SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil {
2222 s.logger.Error("error getting repo", "error", err)
2323 return helpers.ServerError(e, nil)
2424 }
+1-1
server/handle_sync_get_repo.go
···4141 }
42424343 var blocks []models.Block
4444- if err := s.db.Raw("SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", urepo.Repo.Did).Scan(&blocks).Error; err != nil {
4444+ if err := s.db.Raw("SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil {
4545 return err
4646 }
4747
+1-1
server/handle_sync_list_blobs.go
···3535 params = append(params, limit)
36363737 var blobs []models.Blob
3838- if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", params...).Scan(&blobs).Error; err != nil {
3838+ if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil {
3939 s.logger.Error("error getting records", "error", err)
4040 return helpers.ServerError(e, nil)
4141 }
+10-10
server/repo.go
···1818 "github.com/bluesky-social/indigo/repo"
1919 "github.com/bluesky-social/indigo/util"
2020 "github.com/haileyok/cocoon/blockstore"
2121+ "github.com/haileyok/cocoon/internal/db"
2122 "github.com/haileyok/cocoon/models"
2223 blocks "github.com/ipfs/go-block-format"
2324 "github.com/ipfs/go-cid"
2425 cbor "github.com/ipfs/go-ipld-cbor"
2526 "github.com/ipld/go-car"
2626- "gorm.io/gorm"
2727 "gorm.io/gorm/clause"
2828)
29293030type RepoMan struct {
3131- db *gorm.DB
3131+ db *db.DB
3232 s *Server
3333 clock *syntax.TIDClock
3434}
···162162 })
163163 case OpTypeDelete:
164164 var old models.Record
165165- if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
165165+ if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
166166 return nil, err
167167 }
168168 entries = append(entries, models.Record{
···284284 for _, entry := range entries {
285285 var cids []cid.Cid
286286 if entry.Cid != "" {
287287- if err := rm.s.db.Clauses(clause.OnConflict{
287287+ if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{
288288 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
289289 UpdateAll: true,
290290- }).Create(&entry).Error; err != nil {
290290+ }}).Error; err != nil {
291291 return nil, err
292292 }
293293···296296 return nil, err
297297 }
298298 } else {
299299- if err := rm.s.db.Delete(&entry).Error; err != nil {
299299+ if err := rm.s.db.Delete(&entry, nil).Error; err != nil {
300300 return nil, err
301301 }
302302 cids, err = rm.decrementBlobRefs(urepo, entry.Value)
···368368 }
369369370370 for _, c := range cids {
371371- if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", urepo.Did, c.Bytes()).Error; err != nil {
371371+ if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
372372 return nil, err
373373 }
374374 }
···387387 ID uint
388388 Count int
389389 }
390390- if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
390390+ if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
391391 return nil, err
392392 }
393393394394 if res.Count == 0 {
395395- if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", res.ID).Error; err != nil {
395395+ if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
396396 return nil, err
397397 }
398398- if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", res.ID).Error; err != nil {
398398+ if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
399399 return nil, err
400400 }
401401 }