this repo has no description
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. Its numeric value is
// what gets written to the pulls.state column.
type PullState int

// Values of PullState. They are persisted as integers, so the numbers below
// must never be reassigned.
const (
	PullClosed PullState = 0
	PullOpen   PullState = 1
	PullMerged PullState = 2
)
25
26func (p PullState) String() string {
27 switch p {
28 case PullOpen:
29 return "open"
30 case PullMerged:
31 return "merged"
32 case PullClosed:
33 return "closed"
34 default:
35 return "closed"
36 }
37}
38
39func (p PullState) IsOpen() bool {
40 return p == PullOpen
41}
42func (p PullState) IsMerged() bool {
43 return p == PullMerged
44}
45func (p PullState) IsClosed() bool {
46 return p == PullClosed
47}
48
49type Pull struct {
50 // ids
51 ID int
52 PullId int
53
54 // at ids
55 RepoAt syntax.ATURI
56 OwnerDid string
57 Rkey string
58
59 // content
60 Title string
61 Body string
62 TargetBranch string
63 State PullState
64 Submissions []*PullSubmission
65
66 // stacking
67 StackId string // nullable string
68 ChangeId string // nullable string
69 ParentChangeId string // nullable string
70
71 // meta
72 Created time.Time
73 PullSource *PullSource
74
75 // optionally, populate this when querying for reverse mappings
76 Repo *Repo
77}
78
79type PullSource struct {
80 Branch string
81 RepoAt *syntax.ATURI
82
83 // optionally populate this for reverse mappings
84 Repo *Repo
85}
86
87type PullSubmission struct {
88 // ids
89 ID int
90 PullId int
91
92 // at ids
93 RepoAt syntax.ATURI
94
95 // content
96 RoundNumber int
97 Patch string
98 Comments []PullComment
99 SourceRev string // include the rev that was used to create this submission: only for branch PRs
100
101 // meta
102 Created time.Time
103}
104
// PullComment is a comment left on a single submission round of a pull.
type PullComment struct {
	// ids
	ID           int
	PullId       int
	SubmissionId int // ID of the PullSubmission this comment belongs to

	// at ids
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
122
123func (p *Pull) LatestPatch() string {
124 latestSubmission := p.Submissions[p.LastRoundNumber()]
125 return latestSubmission.Patch
126}
127
128func (p *Pull) PullAt() syntax.ATURI {
129 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
130}
131
132func (p *Pull) LastRoundNumber() int {
133 return len(p.Submissions) - 1
134}
135
136func (p *Pull) IsPatchBased() bool {
137 return p.PullSource == nil
138}
139
140func (p *Pull) IsBranchBased() bool {
141 if p.PullSource != nil {
142 if p.PullSource.RepoAt != nil {
143 return p.PullSource.RepoAt == &p.RepoAt
144 } else {
145 // no repo specified
146 return true
147 }
148 }
149 return false
150}
151
152func (p *Pull) IsForkBased() bool {
153 if p.PullSource != nil {
154 if p.PullSource.RepoAt != nil {
155 // make sure repos are different
156 return p.PullSource.RepoAt != &p.RepoAt
157 }
158 }
159 return false
160}
161
162func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) {
163 patch := s.Patch
164
165 // if format-patch; then extract each patch
166 var diffs []*gitdiff.File
167 if patchutil.IsFormatPatch(patch) {
168 patches, err := patchutil.ExtractPatches(patch)
169 if err != nil {
170 return nil, err
171 }
172 var ps [][]*gitdiff.File
173 for _, p := range patches {
174 ps = append(ps, p.Files)
175 }
176
177 diffs = patchutil.CombineDiff(ps...)
178 } else {
179 d, _, err := gitdiff.Parse(strings.NewReader(patch))
180 if err != nil {
181 return nil, err
182 }
183 diffs = d
184 }
185
186 return diffs, nil
187}
188
189func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
190 diffs, err := s.AsDiff(targetBranch)
191 if err != nil {
192 log.Println(err)
193 }
194
195 nd := types.NiceDiff{}
196 nd.Commit.Parent = targetBranch
197
198 for _, d := range diffs {
199 ndiff := types.Diff{}
200 ndiff.Name.New = d.NewName
201 ndiff.Name.Old = d.OldName
202 ndiff.IsBinary = d.IsBinary
203 ndiff.IsNew = d.IsNew
204 ndiff.IsDelete = d.IsDelete
205 ndiff.IsCopy = d.IsCopy
206 ndiff.IsRename = d.IsRename
207
208 for _, tf := range d.TextFragments {
209 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
210 for _, l := range tf.Lines {
211 switch l.Op {
212 case gitdiff.OpAdd:
213 nd.Stat.Insertions += 1
214 case gitdiff.OpDelete:
215 nd.Stat.Deletions += 1
216 }
217 }
218 }
219
220 nd.Diff = append(nd.Diff, ndiff)
221 }
222
223 nd.Stat.FilesChanged = len(diffs)
224
225 return nd
226}
227
228func (s PullSubmission) IsFormatPatch() bool {
229 return patchutil.IsFormatPatch(s.Patch)
230}
231
232func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
233 patches, err := patchutil.ExtractPatches(s.Patch)
234 if err != nil {
235 log.Println("error extracting patches from submission:", err)
236 return []patchutil.FormatPatch{}
237 }
238
239 return patches
240}
241
242func NewPull(tx *sql.Tx, pull *Pull) error {
243 _, err := tx.Exec(`
244 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
245 values (?, 1)
246 `, pull.RepoAt)
247 if err != nil {
248 return err
249 }
250
251 var nextId int
252 err = tx.QueryRow(`
253 update repo_pull_seqs
254 set next_pull_id = next_pull_id + 1
255 where repo_at = ?
256 returning next_pull_id - 1
257 `, pull.RepoAt).Scan(&nextId)
258 if err != nil {
259 return err
260 }
261
262 pull.PullId = nextId
263 pull.State = PullOpen
264
265 var sourceBranch, sourceRepoAt *string
266 if pull.PullSource != nil {
267 sourceBranch = &pull.PullSource.Branch
268 if pull.PullSource.RepoAt != nil {
269 x := pull.PullSource.RepoAt.String()
270 sourceRepoAt = &x
271 }
272 }
273
274 var stackId, changeId, parentChangeId *string
275 if pull.StackId != "" {
276 stackId = &pull.StackId
277 }
278 if pull.ChangeId != "" {
279 changeId = &pull.ChangeId
280 }
281 if pull.ParentChangeId != "" {
282 parentChangeId = &pull.ParentChangeId
283 }
284
285 _, err = tx.Exec(
286 `
287 insert into pulls (
288 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
289 )
290 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
291 pull.RepoAt,
292 pull.OwnerDid,
293 pull.PullId,
294 pull.Title,
295 pull.TargetBranch,
296 pull.Body,
297 pull.Rkey,
298 pull.State,
299 sourceBranch,
300 sourceRepoAt,
301 stackId,
302 changeId,
303 parentChangeId,
304 )
305 if err != nil {
306 return err
307 }
308
309 _, err = tx.Exec(`
310 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
311 values (?, ?, ?, ?, ?)
312 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
313 return err
314}
315
316func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
317 pull, err := GetPull(e, repoAt, pullId)
318 if err != nil {
319 return "", err
320 }
321 return pull.PullAt(), err
322}
323
324func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
325 var pullId int
326 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
327 return pullId - 1, err
328}
329
330func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
331 pulls := make(map[int]*Pull)
332
333 rows, err := e.Query(`
334 select
335 owner_did,
336 pull_id,
337 created,
338 title,
339 state,
340 target_branch,
341 body,
342 rkey,
343 source_branch,
344 source_repo_at
345 from
346 pulls
347 where
348 repo_at = ? and state = ?`, repoAt, state)
349 if err != nil {
350 return nil, err
351 }
352 defer rows.Close()
353
354 for rows.Next() {
355 var pull Pull
356 var createdAt string
357 var sourceBranch, sourceRepoAt sql.NullString
358 err := rows.Scan(
359 &pull.OwnerDid,
360 &pull.PullId,
361 &createdAt,
362 &pull.Title,
363 &pull.State,
364 &pull.TargetBranch,
365 &pull.Body,
366 &pull.Rkey,
367 &sourceBranch,
368 &sourceRepoAt,
369 )
370 if err != nil {
371 return nil, err
372 }
373
374 createdTime, err := time.Parse(time.RFC3339, createdAt)
375 if err != nil {
376 return nil, err
377 }
378 pull.Created = createdTime
379
380 if sourceBranch.Valid {
381 pull.PullSource = &PullSource{
382 Branch: sourceBranch.String,
383 }
384 if sourceRepoAt.Valid {
385 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
386 if err != nil {
387 return nil, err
388 }
389 pull.PullSource.RepoAt = &sourceRepoAtParsed
390 }
391 }
392
393 pulls[pull.PullId] = &pull
394 }
395
396 // get latest round no. for each pull
397 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
398 submissionsQuery := fmt.Sprintf(`
399 select
400 id, pull_id, round_number
401 from
402 pull_submissions
403 where
404 repo_at = ? and pull_id in (%s)
405 `, inClause)
406
407 args := make([]any, len(pulls)+1)
408 args[0] = repoAt.String()
409 idx := 1
410 for _, p := range pulls {
411 args[idx] = p.PullId
412 idx += 1
413 }
414 submissionsRows, err := e.Query(submissionsQuery, args...)
415 if err != nil {
416 return nil, err
417 }
418 defer submissionsRows.Close()
419
420 for submissionsRows.Next() {
421 var s PullSubmission
422 err := submissionsRows.Scan(
423 &s.ID,
424 &s.PullId,
425 &s.RoundNumber,
426 )
427 if err != nil {
428 return nil, err
429 }
430
431 if p, ok := pulls[s.PullId]; ok {
432 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
433 p.Submissions[s.RoundNumber] = &s
434 }
435 }
436 if err := rows.Err(); err != nil {
437 return nil, err
438 }
439
440 // get comment count on latest submission on each pull
441 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
442 commentsQuery := fmt.Sprintf(`
443 select
444 count(id), pull_id
445 from
446 pull_comments
447 where
448 submission_id in (%s)
449 group by
450 submission_id
451 `, inClause)
452
453 args = []any{}
454 for _, p := range pulls {
455 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
456 }
457 commentsRows, err := e.Query(commentsQuery, args...)
458 if err != nil {
459 return nil, err
460 }
461 defer commentsRows.Close()
462
463 for commentsRows.Next() {
464 var commentCount, pullId int
465 err := commentsRows.Scan(
466 &commentCount,
467 &pullId,
468 )
469 if err != nil {
470 return nil, err
471 }
472 if p, ok := pulls[pullId]; ok {
473 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
474 }
475 }
476 if err := rows.Err(); err != nil {
477 return nil, err
478 }
479
480 orderedByDate := []*Pull{}
481 for _, p := range pulls {
482 orderedByDate = append(orderedByDate, p)
483 }
484 sort.Slice(orderedByDate, func(i, j int) bool {
485 return orderedByDate[i].Created.After(orderedByDate[j].Created)
486 })
487
488 return orderedByDate, nil
489}
490
491func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
492 query := `
493 select
494 owner_did,
495 pull_id,
496 created,
497 title,
498 state,
499 target_branch,
500 repo_at,
501 body,
502 rkey,
503 source_branch,
504 source_repo_at,
505 stack_id,
506 change_id,
507 parent_change_id
508 from
509 pulls
510 where
511 repo_at = ? and pull_id = ?
512 `
513 row := e.QueryRow(query, repoAt, pullId)
514
515 var pull Pull
516 var createdAt string
517 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
518 err := row.Scan(
519 &pull.OwnerDid,
520 &pull.PullId,
521 &createdAt,
522 &pull.Title,
523 &pull.State,
524 &pull.TargetBranch,
525 &pull.RepoAt,
526 &pull.Body,
527 &pull.Rkey,
528 &sourceBranch,
529 &sourceRepoAt,
530 &stackId,
531 &changeId,
532 &parentChangeId,
533 )
534 if err != nil {
535 return nil, err
536 }
537
538 createdTime, err := time.Parse(time.RFC3339, createdAt)
539 if err != nil {
540 return nil, err
541 }
542 pull.Created = createdTime
543
544 // populate source
545 if sourceBranch.Valid {
546 pull.PullSource = &PullSource{
547 Branch: sourceBranch.String,
548 }
549 if sourceRepoAt.Valid {
550 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
551 if err != nil {
552 return nil, err
553 }
554 pull.PullSource.RepoAt = &sourceRepoAtParsed
555 }
556 }
557
558 if stackId.Valid {
559 pull.StackId = stackId.String
560 }
561 if changeId.Valid {
562 pull.ChangeId = changeId.String
563 }
564 if parentChangeId.Valid {
565 pull.ParentChangeId = parentChangeId.String
566 }
567
568 submissionsQuery := `
569 select
570 id, pull_id, repo_at, round_number, patch, created, source_rev
571 from
572 pull_submissions
573 where
574 repo_at = ? and pull_id = ?
575 `
576 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
577 if err != nil {
578 return nil, err
579 }
580 defer submissionsRows.Close()
581
582 submissionsMap := make(map[int]*PullSubmission)
583
584 for submissionsRows.Next() {
585 var submission PullSubmission
586 var submissionCreatedStr string
587 var submissionSourceRev sql.NullString
588 err := submissionsRows.Scan(
589 &submission.ID,
590 &submission.PullId,
591 &submission.RepoAt,
592 &submission.RoundNumber,
593 &submission.Patch,
594 &submissionCreatedStr,
595 &submissionSourceRev,
596 )
597 if err != nil {
598 return nil, err
599 }
600
601 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
602 if err != nil {
603 return nil, err
604 }
605 submission.Created = submissionCreatedTime
606
607 if submissionSourceRev.Valid {
608 submission.SourceRev = submissionSourceRev.String
609 }
610
611 submissionsMap[submission.ID] = &submission
612 }
613 if err = submissionsRows.Close(); err != nil {
614 return nil, err
615 }
616 if len(submissionsMap) == 0 {
617 return &pull, nil
618 }
619
620 var args []any
621 for k := range submissionsMap {
622 args = append(args, k)
623 }
624 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
625 commentsQuery := fmt.Sprintf(`
626 select
627 id,
628 pull_id,
629 submission_id,
630 repo_at,
631 owner_did,
632 comment_at,
633 body,
634 created
635 from
636 pull_comments
637 where
638 submission_id IN (%s)
639 order by
640 created asc
641 `, inClause)
642 commentsRows, err := e.Query(commentsQuery, args...)
643 if err != nil {
644 return nil, err
645 }
646 defer commentsRows.Close()
647
648 for commentsRows.Next() {
649 var comment PullComment
650 var commentCreatedStr string
651 err := commentsRows.Scan(
652 &comment.ID,
653 &comment.PullId,
654 &comment.SubmissionId,
655 &comment.RepoAt,
656 &comment.OwnerDid,
657 &comment.CommentAt,
658 &comment.Body,
659 &commentCreatedStr,
660 )
661 if err != nil {
662 return nil, err
663 }
664
665 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
666 if err != nil {
667 return nil, err
668 }
669 comment.Created = commentCreatedTime
670
671 // Add the comment to its submission
672 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
673 submission.Comments = append(submission.Comments, comment)
674 }
675
676 }
677 if err = commentsRows.Err(); err != nil {
678 return nil, err
679 }
680
681 var pullSourceRepo *Repo
682 if pull.PullSource != nil {
683 if pull.PullSource.RepoAt != nil {
684 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
685 if err != nil {
686 log.Printf("failed to get repo by at uri: %v", err)
687 } else {
688 pull.PullSource.Repo = pullSourceRepo
689 }
690 }
691 }
692
693 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
694 for _, submission := range submissionsMap {
695 pull.Submissions[submission.RoundNumber] = submission
696 }
697
698 return &pull, nil
699}
700
701// timeframe here is directly passed into the sql query filter, and any
702// timeframe in the past should be negative; e.g.: "-3 months"
703func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
704 var pulls []Pull
705
706 rows, err := e.Query(`
707 select
708 p.owner_did,
709 p.repo_at,
710 p.pull_id,
711 p.created,
712 p.title,
713 p.state,
714 r.did,
715 r.name,
716 r.knot,
717 r.rkey,
718 r.created
719 from
720 pulls p
721 join
722 repos r on p.repo_at = r.at_uri
723 where
724 p.owner_did = ? and p.created >= date ('now', ?)
725 order by
726 p.created desc`, did, timeframe)
727 if err != nil {
728 return nil, err
729 }
730 defer rows.Close()
731
732 for rows.Next() {
733 var pull Pull
734 var repo Repo
735 var pullCreatedAt, repoCreatedAt string
736 err := rows.Scan(
737 &pull.OwnerDid,
738 &pull.RepoAt,
739 &pull.PullId,
740 &pullCreatedAt,
741 &pull.Title,
742 &pull.State,
743 &repo.Did,
744 &repo.Name,
745 &repo.Knot,
746 &repo.Rkey,
747 &repoCreatedAt,
748 )
749 if err != nil {
750 return nil, err
751 }
752
753 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
754 if err != nil {
755 return nil, err
756 }
757 pull.Created = pullCreatedTime
758
759 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
760 if err != nil {
761 return nil, err
762 }
763 repo.Created = repoCreatedTime
764
765 pull.Repo = &repo
766
767 pulls = append(pulls, pull)
768 }
769
770 if err := rows.Err(); err != nil {
771 return nil, err
772 }
773
774 return pulls, nil
775}
776
777func NewPullComment(e Execer, comment *PullComment) (int64, error) {
778 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
779 res, err := e.Exec(
780 query,
781 comment.OwnerDid,
782 comment.RepoAt,
783 comment.SubmissionId,
784 comment.CommentAt,
785 comment.PullId,
786 comment.Body,
787 )
788 if err != nil {
789 return 0, err
790 }
791
792 i, err := res.LastInsertId()
793 if err != nil {
794 return 0, err
795 }
796
797 return i, nil
798}
799
800func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
801 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
802 return err
803}
804
805func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
806 err := SetPullState(e, repoAt, pullId, PullClosed)
807 return err
808}
809
810func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
811 err := SetPullState(e, repoAt, pullId, PullOpen)
812 return err
813}
814
815func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
816 err := SetPullState(e, repoAt, pullId, PullMerged)
817 return err
818}
819
820func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
821 newRoundNumber := len(pull.Submissions)
822 _, err := e.Exec(`
823 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
824 values (?, ?, ?, ?, ?)
825 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
826
827 return err
828}
829
// PullCount is the number of pulls of a repository in each state.
type PullCount struct {
	Open   int // pulls in state PullOpen
	Merged int // pulls in state PullMerged
	Closed int // pulls in state PullClosed
}
835
836func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
837 row := e.QueryRow(`
838 select
839 count(case when state = ? then 1 end) as open_count,
840 count(case when state = ? then 1 end) as merged_count,
841 count(case when state = ? then 1 end) as closed_count
842 from pulls
843 where repo_at = ?`,
844 PullOpen,
845 PullMerged,
846 PullClosed,
847 repoAt,
848 )
849
850 var count PullCount
851 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
852 return PullCount{0, 0, 0}, err
853 }
854
855 return count, nil
856}