this repo has no description
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16 "tangled.org/core/orm"
17)
18
19func NewPull(tx *sql.Tx, pull *models.Pull) error {
20 _, err := tx.Exec(`
21 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
22 values (?, 1)
23 `, pull.RepoAt)
24 if err != nil {
25 return err
26 }
27
28 var nextId int
29 err = tx.QueryRow(`
30 update repo_pull_seqs
31 set next_pull_id = next_pull_id + 1
32 where repo_at = ?
33 returning next_pull_id - 1
34 `, pull.RepoAt).Scan(&nextId)
35 if err != nil {
36 return err
37 }
38
39 pull.PullId = nextId
40 pull.State = models.PullOpen
41
42 var sourceBranch, sourceRepoAt *string
43 if pull.PullSource != nil {
44 sourceBranch = &pull.PullSource.Branch
45 if pull.PullSource.RepoAt != nil {
46 x := pull.PullSource.RepoAt.String()
47 sourceRepoAt = &x
48 }
49 }
50
51 var stackId, changeId, parentChangeId *string
52 if pull.StackId != "" {
53 stackId = &pull.StackId
54 }
55 if pull.ChangeId != "" {
56 changeId = &pull.ChangeId
57 }
58 if pull.ParentChangeId != "" {
59 parentChangeId = &pull.ParentChangeId
60 }
61
62 result, err := tx.Exec(
63 `
64 insert into pulls (
65 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
66 )
67 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
68 pull.RepoAt,
69 pull.OwnerDid,
70 pull.PullId,
71 pull.Title,
72 pull.TargetBranch,
73 pull.Body,
74 pull.Rkey,
75 pull.State,
76 sourceBranch,
77 sourceRepoAt,
78 stackId,
79 changeId,
80 parentChangeId,
81 )
82 if err != nil {
83 return err
84 }
85
86 // Set the database primary key ID
87 id, err := result.LastInsertId()
88 if err != nil {
89 return err
90 }
91 pull.ID = int(id)
92
93 _, err = tx.Exec(`
94 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
95 values (?, ?, ?, ?, ?)
96 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
97 if err != nil {
98 return err
99 }
100
101 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
102 return fmt.Errorf("put reference_links: %w", err)
103 }
104
105 return nil
106}
107
108func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
109 pull, err := GetPull(e, repoAt, pullId)
110 if err != nil {
111 return "", err
112 }
113 return pull.AtUri(), err
114}
115
116func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
117 var pullId int
118 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
119 return pullId - 1, err
120}
121
122func GetPullsWithLimit(e Execer, limit int, filters ...orm.Filter) ([]*models.Pull, error) {
123 pulls := make(map[syntax.ATURI]*models.Pull)
124
125 var conditions []string
126 var args []any
127 for _, filter := range filters {
128 conditions = append(conditions, filter.Condition())
129 args = append(args, filter.Arg()...)
130 }
131
132 whereClause := ""
133 if conditions != nil {
134 whereClause = " where " + strings.Join(conditions, " and ")
135 }
136 limitClause := ""
137 if limit != 0 {
138 limitClause = fmt.Sprintf(" limit %d ", limit)
139 }
140
141 query := fmt.Sprintf(`
142 select
143 id,
144 owner_did,
145 repo_at,
146 pull_id,
147 created,
148 title,
149 state,
150 target_branch,
151 body,
152 rkey,
153 source_branch,
154 source_repo_at,
155 stack_id,
156 change_id,
157 parent_change_id
158 from
159 pulls
160 %s
161 order by
162 created desc
163 %s
164 `, whereClause, limitClause)
165
166 rows, err := e.Query(query, args...)
167 if err != nil {
168 return nil, err
169 }
170 defer rows.Close()
171
172 for rows.Next() {
173 var pull models.Pull
174 var createdAt string
175 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
176 err := rows.Scan(
177 &pull.ID,
178 &pull.OwnerDid,
179 &pull.RepoAt,
180 &pull.PullId,
181 &createdAt,
182 &pull.Title,
183 &pull.State,
184 &pull.TargetBranch,
185 &pull.Body,
186 &pull.Rkey,
187 &sourceBranch,
188 &sourceRepoAt,
189 &stackId,
190 &changeId,
191 &parentChangeId,
192 )
193 if err != nil {
194 return nil, err
195 }
196
197 createdTime, err := time.Parse(time.RFC3339, createdAt)
198 if err != nil {
199 return nil, err
200 }
201 pull.Created = createdTime
202
203 if sourceBranch.Valid {
204 pull.PullSource = &models.PullSource{
205 Branch: sourceBranch.String,
206 }
207 if sourceRepoAt.Valid {
208 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
209 if err != nil {
210 return nil, err
211 }
212 pull.PullSource.RepoAt = &sourceRepoAtParsed
213 }
214 }
215
216 if stackId.Valid {
217 pull.StackId = stackId.String
218 }
219 if changeId.Valid {
220 pull.ChangeId = changeId.String
221 }
222 if parentChangeId.Valid {
223 pull.ParentChangeId = parentChangeId.String
224 }
225
226 pulls[pull.AtUri()] = &pull
227 }
228
229 var pullAts []syntax.ATURI
230 for _, p := range pulls {
231 pullAts = append(pullAts, p.AtUri())
232 }
233 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
234 if err != nil {
235 return nil, fmt.Errorf("failed to get submissions: %w", err)
236 }
237
238 for pullAt, submissions := range submissionsMap {
239 if p, ok := pulls[pullAt]; ok {
240 p.Submissions = submissions
241 }
242 }
243
244 // collect allLabels for each issue
245 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
246 if err != nil {
247 return nil, fmt.Errorf("failed to query labels: %w", err)
248 }
249 for pullAt, labels := range allLabels {
250 if p, ok := pulls[pullAt]; ok {
251 p.Labels = labels
252 }
253 }
254
255 // collect pull source for all pulls that need it
256 var sourceAts []syntax.ATURI
257 for _, p := range pulls {
258 if p.PullSource != nil && p.PullSource.RepoAt != nil {
259 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
260 }
261 }
262 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts))
263 if err != nil && !errors.Is(err, sql.ErrNoRows) {
264 return nil, fmt.Errorf("failed to get source repos: %w", err)
265 }
266 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
267 for _, r := range sourceRepos {
268 sourceRepoMap[r.RepoAt()] = &r
269 }
270 for _, p := range pulls {
271 if p.PullSource != nil && p.PullSource.RepoAt != nil {
272 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
273 p.PullSource.Repo = sourceRepo
274 }
275 }
276 }
277
278 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
279 if err != nil {
280 return nil, fmt.Errorf("failed to query reference_links: %w", err)
281 }
282 for pullAt, references := range allReferences {
283 if pull, ok := pulls[pullAt]; ok {
284 pull.References = references
285 }
286 }
287
288 orderedByPullId := []*models.Pull{}
289 for _, p := range pulls {
290 orderedByPullId = append(orderedByPullId, p)
291 }
292 sort.Slice(orderedByPullId, func(i, j int) bool {
293 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
294 })
295
296 return orderedByPullId, nil
297}
298
299func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
300 return GetPullsWithLimit(e, 0, filters...)
301}
302
303func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) {
304 var ids []int64
305
306 var filters []orm.Filter
307 filters = append(filters, orm.FilterEq("state", opts.State))
308 if opts.RepoAt != "" {
309 filters = append(filters, orm.FilterEq("repo_at", opts.RepoAt))
310 }
311
312 var conditions []string
313 var args []any
314
315 for _, filter := range filters {
316 conditions = append(conditions, filter.Condition())
317 args = append(args, filter.Arg()...)
318 }
319
320 whereClause := ""
321 if conditions != nil {
322 whereClause = " where " + strings.Join(conditions, " and ")
323 }
324 pageClause := ""
325 if opts.Page.Limit != 0 {
326 pageClause = fmt.Sprintf(
327 " limit %d offset %d ",
328 opts.Page.Limit,
329 opts.Page.Offset,
330 )
331 }
332
333 query := fmt.Sprintf(
334 `
335 select
336 id
337 from
338 pulls
339 %s
340 %s`,
341 whereClause,
342 pageClause,
343 )
344 args = append(args, opts.Page.Limit, opts.Page.Offset)
345 rows, err := e.Query(query, args...)
346 if err != nil {
347 return nil, err
348 }
349 defer rows.Close()
350
351 for rows.Next() {
352 var id int64
353 err := rows.Scan(&id)
354 if err != nil {
355 return nil, err
356 }
357
358 ids = append(ids, id)
359 }
360
361 return ids, nil
362}
363
364func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
365 pulls, err := GetPullsWithLimit(e, 1, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId))
366 if err != nil {
367 return nil, err
368 }
369 if len(pulls) == 0 {
370 return nil, sql.ErrNoRows
371 }
372
373 return pulls[0], nil
374}
375
376// mapping from pull -> pull submissions
377func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
378 var conditions []string
379 var args []any
380 for _, filter := range filters {
381 conditions = append(conditions, filter.Condition())
382 args = append(args, filter.Arg()...)
383 }
384
385 whereClause := ""
386 if conditions != nil {
387 whereClause = " where " + strings.Join(conditions, " and ")
388 }
389
390 query := fmt.Sprintf(`
391 select
392 id,
393 pull_at,
394 round_number,
395 patch,
396 combined,
397 created,
398 source_rev
399 from
400 pull_submissions
401 %s
402 order by
403 round_number asc
404 `, whereClause)
405
406 rows, err := e.Query(query, args...)
407 if err != nil {
408 return nil, err
409 }
410 defer rows.Close()
411
412 submissionMap := make(map[int]*models.PullSubmission)
413
414 for rows.Next() {
415 var submission models.PullSubmission
416 var submissionCreatedStr string
417 var submissionSourceRev, submissionCombined sql.NullString
418 err := rows.Scan(
419 &submission.ID,
420 &submission.PullAt,
421 &submission.RoundNumber,
422 &submission.Patch,
423 &submissionCombined,
424 &submissionCreatedStr,
425 &submissionSourceRev,
426 )
427 if err != nil {
428 return nil, err
429 }
430
431 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
432 submission.Created = t
433 }
434
435 if submissionSourceRev.Valid {
436 submission.SourceRev = submissionSourceRev.String
437 }
438
439 if submissionCombined.Valid {
440 submission.Combined = submissionCombined.String
441 }
442
443 submissionMap[submission.ID] = &submission
444 }
445
446 if err := rows.Err(); err != nil {
447 return nil, err
448 }
449
450 // Get comments for all submissions using GetComments
451 submissionIds := slices.Collect(maps.Keys(submissionMap))
452 comments, err := GetComments(e, orm.FilterIn("pull_submission_id", submissionIds))
453 if err != nil {
454 return nil, fmt.Errorf("failed to get pull comments: %w", err)
455 }
456 for _, comment := range comments {
457 if comment.PullSubmissionId != nil {
458 if submission, ok := submissionMap[*comment.PullSubmissionId]; ok {
459 submission.Comments = append(submission.Comments, comment)
460 }
461 }
462 }
463
464 // group the submissions by pull_at
465 m := make(map[syntax.ATURI][]*models.PullSubmission)
466 for _, s := range submissionMap {
467 m[s.PullAt] = append(m[s.PullAt], s)
468 }
469
470 // sort each one by round number
471 for _, s := range m {
472 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
473 return cmp.Compare(a.RoundNumber, b.RoundNumber)
474 })
475 }
476
477 return m, nil
478}
479
480// timeframe here is directly passed into the sql query filter, and any
481// timeframe in the past should be negative; e.g.: "-3 months"
482func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
483 var pulls []models.Pull
484
485 rows, err := e.Query(`
486 select
487 p.owner_did,
488 p.repo_at,
489 p.pull_id,
490 p.created,
491 p.title,
492 p.state,
493 r.did,
494 r.name,
495 r.knot,
496 r.rkey,
497 r.created
498 from
499 pulls p
500 join
501 repos r on p.repo_at = r.at_uri
502 where
503 p.owner_did = ? and p.created >= date ('now', ?)
504 order by
505 p.created desc`, did, timeframe)
506 if err != nil {
507 return nil, err
508 }
509 defer rows.Close()
510
511 for rows.Next() {
512 var pull models.Pull
513 var repo models.Repo
514 var pullCreatedAt, repoCreatedAt string
515 err := rows.Scan(
516 &pull.OwnerDid,
517 &pull.RepoAt,
518 &pull.PullId,
519 &pullCreatedAt,
520 &pull.Title,
521 &pull.State,
522 &repo.Did,
523 &repo.Name,
524 &repo.Knot,
525 &repo.Rkey,
526 &repoCreatedAt,
527 )
528 if err != nil {
529 return nil, err
530 }
531
532 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
533 if err != nil {
534 return nil, err
535 }
536 pull.Created = pullCreatedTime
537
538 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
539 if err != nil {
540 return nil, err
541 }
542 repo.Created = repoCreatedTime
543
544 pull.Repo = &repo
545
546 pulls = append(pulls, pull)
547 }
548
549 if err := rows.Err(); err != nil {
550 return nil, err
551 }
552
553 return pulls, nil
554}
555
556func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
557 _, err := e.Exec(
558 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
559 pullState,
560 repoAt,
561 pullId,
562 models.PullDeleted, // only update state of non-deleted pulls
563 models.PullMerged, // only update state of non-merged pulls
564 )
565 return err
566}
567
568func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
569 err := SetPullState(e, repoAt, pullId, models.PullClosed)
570 return err
571}
572
573func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
574 err := SetPullState(e, repoAt, pullId, models.PullOpen)
575 return err
576}
577
578func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
579 err := SetPullState(e, repoAt, pullId, models.PullMerged)
580 return err
581}
582
583func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
584 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
585 return err
586}
587
588func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
589 _, err := e.Exec(`
590 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
591 values (?, ?, ?, ?, ?)
592 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
593
594 return err
595}
596
597func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error {
598 var conditions []string
599 var args []any
600
601 args = append(args, parentChangeId)
602
603 for _, filter := range filters {
604 conditions = append(conditions, filter.Condition())
605 args = append(args, filter.Arg()...)
606 }
607
608 whereClause := ""
609 if conditions != nil {
610 whereClause = " where " + strings.Join(conditions, " and ")
611 }
612
613 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
614 _, err := e.Exec(query, args...)
615
616 return err
617}
618
619// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
620// otherwise submissions are immutable
621func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error {
622 var conditions []string
623 var args []any
624
625 args = append(args, sourceRev)
626 args = append(args, newPatch)
627
628 for _, filter := range filters {
629 conditions = append(conditions, filter.Condition())
630 args = append(args, filter.Arg()...)
631 }
632
633 whereClause := ""
634 if conditions != nil {
635 whereClause = " where " + strings.Join(conditions, " and ")
636 }
637
638 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
639 _, err := e.Exec(query, args...)
640
641 return err
642}
643
644func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
645 row := e.QueryRow(`
646 select
647 count(case when state = ? then 1 end) as open_count,
648 count(case when state = ? then 1 end) as merged_count,
649 count(case when state = ? then 1 end) as closed_count,
650 count(case when state = ? then 1 end) as deleted_count
651 from pulls
652 where repo_at = ?`,
653 models.PullOpen,
654 models.PullMerged,
655 models.PullClosed,
656 models.PullDeleted,
657 repoAt,
658 )
659
660 var count models.PullCount
661 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
662 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
663 }
664
665 return count, nil
666}
667
668// change-id parent-change-id
669//
670// 4 w ,-------- z (TOP)
671// 3 z <----',------- y
672// 2 y <-----',------ x
673// 1 x <------' nil (BOT)
674//
675// `w` is parent of none, so it is the top of the stack
676func GetStack(e Execer, stackId string) (models.Stack, error) {
677 unorderedPulls, err := GetPulls(
678 e,
679 orm.FilterEq("stack_id", stackId),
680 orm.FilterNotEq("state", models.PullDeleted),
681 )
682 if err != nil {
683 return nil, err
684 }
685 // map of parent-change-id to pull
686 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
687 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
688 for _, p := range unorderedPulls {
689 changeIdMap[p.ChangeId] = p
690 if p.ParentChangeId != "" {
691 parentMap[p.ParentChangeId] = p
692 }
693 }
694
695 // the top of the stack is the pull that is not a parent of any pull
696 var topPull *models.Pull
697 for _, maybeTop := range unorderedPulls {
698 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
699 topPull = maybeTop
700 break
701 }
702 }
703
704 pulls := []*models.Pull{}
705 for {
706 pulls = append(pulls, topPull)
707 if topPull.ParentChangeId != "" {
708 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
709 topPull = next
710 } else {
711 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
712 }
713 } else {
714 break
715 }
716 }
717
718 return pulls, nil
719}
720
721func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
722 pulls, err := GetPulls(
723 e,
724 orm.FilterEq("stack_id", stackId),
725 orm.FilterEq("state", models.PullDeleted),
726 )
727 if err != nil {
728 return nil, err
729 }
730
731 return pulls, nil
732}