this repo has no description
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/bluekeyes/go-gitdiff/gitdiff"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. Values of this type
// are passed straight to the database as the `state` column, so the numeric
// value of each constant must stay stable.
type PullState int

const (
	PullClosed PullState = iota // 0
	PullOpen                    // 1
	PullMerged                  // 2
)

// String returns the lowercase name of the state: "open", "merged" or
// "closed". Any value that is not one of the declared constants is reported
// as "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	// PullClosed and every unknown value
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is exactly PullClosed. Unlike String,
// unknown values are not treated as closed here.
func (p PullState) IsClosed() bool { return p == PullClosed }
48
// Pull is a pull request opened against a repository, together with the
// submission rounds that have been loaded for it.
type Pull struct {
	// ids
	ID     int
	PullId int // per-repo pull number; NewPull assigns it from repo_pull_seqs

	// at ids
	RepoAt   syntax.ATURI
	OwnerDid string
	Rkey     string

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	Submissions  []*PullSubmission // indexed by round number; see LastRoundNumber

	// stacking
	StackId        string // nullable string; GetPull leaves it "" when the column is NULL
	ChangeId       string // nullable string; GetPull leaves it "" when the column is NULL
	ParentChangeId string // nullable string; GetPull leaves it "" when the column is NULL

	// meta
	Created    time.Time
	PullSource *PullSource // nil when no source branch is recorded (see IsPatchBased)

	// optionally, populate this when querying for reverse mappings
	Repo *Repo
}
78
// PullSource records the branch, and optionally the repository, that a
// pull's changes come from. It is stored in the source_branch and
// source_repo_at columns of the pulls table.
type PullSource struct {
	Branch string
	RepoAt *syntax.ATURI // nil when source_repo_at is NULL

	// optionally populate this for reverse mappings
	Repo *Repo
}
86
// PullSubmission is one round of a pull request: the patch submitted in
// that round and the comments attached to it.
type PullSubmission struct {
	// ids
	ID     int
	PullId int

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int // NewPull stores the first submission as round 0
	Patch       string
	Comments    []PullComment
	SourceRev   string // include the rev that was used to create this submission: only for branch PRs

	// meta
	Created time.Time
}
104
// PullComment is a single comment left on one submission round of a pull
// request (a row of the pull_comments table).
type PullComment struct {
	// ids
	ID           int
	PullId       int
	SubmissionId int // id of the pull_submissions row the comment belongs to

	// at ids
	RepoAt    string
	OwnerDid  string
	CommentAt string

	// content
	Body string

	// meta
	Created time.Time
}
122
123func (p *Pull) LatestPatch() string {
124 latestSubmission := p.Submissions[p.LastRoundNumber()]
125 return latestSubmission.Patch
126}
127
128func (p *Pull) PullAt() syntax.ATURI {
129 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
130}
131
132func (p *Pull) LastRoundNumber() int {
133 return len(p.Submissions) - 1
134}
135
136func (p *Pull) IsPatchBased() bool {
137 return p.PullSource == nil
138}
139
140func (p *Pull) IsBranchBased() bool {
141 if p.PullSource != nil {
142 if p.PullSource.RepoAt != nil {
143 return p.PullSource.RepoAt == &p.RepoAt
144 } else {
145 // no repo specified
146 return true
147 }
148 }
149 return false
150}
151
152func (p *Pull) IsForkBased() bool {
153 if p.PullSource != nil {
154 if p.PullSource.RepoAt != nil {
155 // make sure repos are different
156 return p.PullSource.RepoAt != &p.RepoAt
157 }
158 }
159 return false
160}
161
162func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) {
163 patch := s.Patch
164
165 // if format-patch; then extract each patch
166 var diffs []*gitdiff.File
167 if patchutil.IsFormatPatch(patch) {
168 patches, err := patchutil.ExtractPatches(patch)
169 if err != nil {
170 return nil, err
171 }
172 var ps [][]*gitdiff.File
173 for _, p := range patches {
174 ps = append(ps, p.Files)
175 }
176
177 diffs = patchutil.CombineDiff(ps...)
178 } else {
179 d, _, err := gitdiff.Parse(strings.NewReader(patch))
180 if err != nil {
181 return nil, err
182 }
183 diffs = d
184 }
185
186 return diffs, nil
187}
188
189func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
190 diffs, err := s.AsDiff(targetBranch)
191 if err != nil {
192 log.Println(err)
193 }
194
195 nd := types.NiceDiff{}
196 nd.Commit.Parent = targetBranch
197
198 for _, d := range diffs {
199 ndiff := types.Diff{}
200 ndiff.Name.New = d.NewName
201 ndiff.Name.Old = d.OldName
202 ndiff.IsBinary = d.IsBinary
203 ndiff.IsNew = d.IsNew
204 ndiff.IsDelete = d.IsDelete
205 ndiff.IsCopy = d.IsCopy
206 ndiff.IsRename = d.IsRename
207
208 for _, tf := range d.TextFragments {
209 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
210 for _, l := range tf.Lines {
211 switch l.Op {
212 case gitdiff.OpAdd:
213 nd.Stat.Insertions += 1
214 case gitdiff.OpDelete:
215 nd.Stat.Deletions += 1
216 }
217 }
218 }
219
220 nd.Diff = append(nd.Diff, ndiff)
221 }
222
223 nd.Stat.FilesChanged = len(diffs)
224
225 return nd
226}
227
228func (s PullSubmission) IsFormatPatch() bool {
229 return patchutil.IsFormatPatch(s.Patch)
230}
231
232func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
233 patches, err := patchutil.ExtractPatches(s.Patch)
234 if err != nil {
235 log.Println("error extracting patches from submission:", err)
236 return []patchutil.FormatPatch{}
237 }
238
239 return patches
240}
241
242func NewPull(tx *sql.Tx, pull *Pull) error {
243 _, err := tx.Exec(`
244 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
245 values (?, 1)
246 `, pull.RepoAt)
247 if err != nil {
248 return err
249 }
250
251 var nextId int
252 err = tx.QueryRow(`
253 update repo_pull_seqs
254 set next_pull_id = next_pull_id + 1
255 where repo_at = ?
256 returning next_pull_id - 1
257 `, pull.RepoAt).Scan(&nextId)
258 if err != nil {
259 return err
260 }
261
262 pull.PullId = nextId
263 pull.State = PullOpen
264
265 var sourceBranch, sourceRepoAt *string
266 if pull.PullSource != nil {
267 sourceBranch = &pull.PullSource.Branch
268 if pull.PullSource.RepoAt != nil {
269 x := pull.PullSource.RepoAt.String()
270 sourceRepoAt = &x
271 }
272 }
273
274 _, err = tx.Exec(
275 `
276 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at)
277 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
278 pull.RepoAt,
279 pull.OwnerDid,
280 pull.PullId,
281 pull.Title,
282 pull.TargetBranch,
283 pull.Body,
284 pull.Rkey,
285 pull.State,
286 sourceBranch,
287 sourceRepoAt,
288 )
289 if err != nil {
290 return err
291 }
292
293 _, err = tx.Exec(`
294 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
295 values (?, ?, ?, ?, ?)
296 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
297 return err
298}
299
300func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
301 pull, err := GetPull(e, repoAt, pullId)
302 if err != nil {
303 return "", err
304 }
305 return pull.PullAt(), err
306}
307
308func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
309 var pullId int
310 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
311 return pullId - 1, err
312}
313
314func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) {
315 pulls := make(map[int]*Pull)
316
317 rows, err := e.Query(`
318 select
319 owner_did,
320 pull_id,
321 created,
322 title,
323 state,
324 target_branch,
325 body,
326 rkey,
327 source_branch,
328 source_repo_at
329 from
330 pulls
331 where
332 repo_at = ? and state = ?`, repoAt, state)
333 if err != nil {
334 return nil, err
335 }
336 defer rows.Close()
337
338 for rows.Next() {
339 var pull Pull
340 var createdAt string
341 var sourceBranch, sourceRepoAt sql.NullString
342 err := rows.Scan(
343 &pull.OwnerDid,
344 &pull.PullId,
345 &createdAt,
346 &pull.Title,
347 &pull.State,
348 &pull.TargetBranch,
349 &pull.Body,
350 &pull.Rkey,
351 &sourceBranch,
352 &sourceRepoAt,
353 )
354 if err != nil {
355 return nil, err
356 }
357
358 createdTime, err := time.Parse(time.RFC3339, createdAt)
359 if err != nil {
360 return nil, err
361 }
362 pull.Created = createdTime
363
364 if sourceBranch.Valid {
365 pull.PullSource = &PullSource{
366 Branch: sourceBranch.String,
367 }
368 if sourceRepoAt.Valid {
369 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
370 if err != nil {
371 return nil, err
372 }
373 pull.PullSource.RepoAt = &sourceRepoAtParsed
374 }
375 }
376
377 pulls[pull.PullId] = &pull
378 }
379
380 // get latest round no. for each pull
381 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
382 submissionsQuery := fmt.Sprintf(`
383 select
384 id, pull_id, round_number
385 from
386 pull_submissions
387 where
388 repo_at = ? and pull_id in (%s)
389 `, inClause)
390
391 args := make([]any, len(pulls)+1)
392 args[0] = repoAt.String()
393 idx := 1
394 for _, p := range pulls {
395 args[idx] = p.PullId
396 idx += 1
397 }
398 submissionsRows, err := e.Query(submissionsQuery, args...)
399 if err != nil {
400 return nil, err
401 }
402 defer submissionsRows.Close()
403
404 for submissionsRows.Next() {
405 var s PullSubmission
406 err := submissionsRows.Scan(
407 &s.ID,
408 &s.PullId,
409 &s.RoundNumber,
410 )
411 if err != nil {
412 return nil, err
413 }
414
415 if p, ok := pulls[s.PullId]; ok {
416 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
417 p.Submissions[s.RoundNumber] = &s
418 }
419 }
420 if err := rows.Err(); err != nil {
421 return nil, err
422 }
423
424 // get comment count on latest submission on each pull
425 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
426 commentsQuery := fmt.Sprintf(`
427 select
428 count(id), pull_id
429 from
430 pull_comments
431 where
432 submission_id in (%s)
433 group by
434 submission_id
435 `, inClause)
436
437 args = []any{}
438 for _, p := range pulls {
439 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
440 }
441 commentsRows, err := e.Query(commentsQuery, args...)
442 if err != nil {
443 return nil, err
444 }
445 defer commentsRows.Close()
446
447 for commentsRows.Next() {
448 var commentCount, pullId int
449 err := commentsRows.Scan(
450 &commentCount,
451 &pullId,
452 )
453 if err != nil {
454 return nil, err
455 }
456 if p, ok := pulls[pullId]; ok {
457 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
458 }
459 }
460 if err := rows.Err(); err != nil {
461 return nil, err
462 }
463
464 orderedByDate := []*Pull{}
465 for _, p := range pulls {
466 orderedByDate = append(orderedByDate, p)
467 }
468 sort.Slice(orderedByDate, func(i, j int) bool {
469 return orderedByDate[i].Created.After(orderedByDate[j].Created)
470 })
471
472 return orderedByDate, nil
473}
474
475func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
476 query := `
477 select
478 owner_did,
479 pull_id,
480 created,
481 title,
482 state,
483 target_branch,
484 repo_at,
485 body,
486 rkey,
487 source_branch,
488 source_repo_at,
489 stack_id,
490 change_id,
491 parent_change_id
492 from
493 pulls
494 where
495 repo_at = ? and pull_id = ?
496 `
497 row := e.QueryRow(query, repoAt, pullId)
498
499 var pull Pull
500 var createdAt string
501 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
502 err := row.Scan(
503 &pull.OwnerDid,
504 &pull.PullId,
505 &createdAt,
506 &pull.Title,
507 &pull.State,
508 &pull.TargetBranch,
509 &pull.RepoAt,
510 &pull.Body,
511 &pull.Rkey,
512 &sourceBranch,
513 &sourceRepoAt,
514 &stackId,
515 &changeId,
516 &parentChangeId,
517 )
518 if err != nil {
519 return nil, err
520 }
521
522 createdTime, err := time.Parse(time.RFC3339, createdAt)
523 if err != nil {
524 return nil, err
525 }
526 pull.Created = createdTime
527
528 // populate source
529 if sourceBranch.Valid {
530 pull.PullSource = &PullSource{
531 Branch: sourceBranch.String,
532 }
533 if sourceRepoAt.Valid {
534 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
535 if err != nil {
536 return nil, err
537 }
538 pull.PullSource.RepoAt = &sourceRepoAtParsed
539 }
540 }
541
542 if stackId.Valid {
543 pull.StackId = stackId.String
544 }
545 if changeId.Valid {
546 pull.ChangeId = changeId.String
547 }
548 if parentChangeId.Valid {
549 pull.ParentChangeId = parentChangeId.String
550 }
551
552 submissionsQuery := `
553 select
554 id, pull_id, repo_at, round_number, patch, created, source_rev
555 from
556 pull_submissions
557 where
558 repo_at = ? and pull_id = ?
559 `
560 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
561 if err != nil {
562 return nil, err
563 }
564 defer submissionsRows.Close()
565
566 submissionsMap := make(map[int]*PullSubmission)
567
568 for submissionsRows.Next() {
569 var submission PullSubmission
570 var submissionCreatedStr string
571 var submissionSourceRev sql.NullString
572 err := submissionsRows.Scan(
573 &submission.ID,
574 &submission.PullId,
575 &submission.RepoAt,
576 &submission.RoundNumber,
577 &submission.Patch,
578 &submissionCreatedStr,
579 &submissionSourceRev,
580 )
581 if err != nil {
582 return nil, err
583 }
584
585 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
586 if err != nil {
587 return nil, err
588 }
589 submission.Created = submissionCreatedTime
590
591 if submissionSourceRev.Valid {
592 submission.SourceRev = submissionSourceRev.String
593 }
594
595 submissionsMap[submission.ID] = &submission
596 }
597 if err = submissionsRows.Close(); err != nil {
598 return nil, err
599 }
600 if len(submissionsMap) == 0 {
601 return &pull, nil
602 }
603
604 var args []any
605 for k := range submissionsMap {
606 args = append(args, k)
607 }
608 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
609 commentsQuery := fmt.Sprintf(`
610 select
611 id,
612 pull_id,
613 submission_id,
614 repo_at,
615 owner_did,
616 comment_at,
617 body,
618 created
619 from
620 pull_comments
621 where
622 submission_id IN (%s)
623 order by
624 created asc
625 `, inClause)
626 commentsRows, err := e.Query(commentsQuery, args...)
627 if err != nil {
628 return nil, err
629 }
630 defer commentsRows.Close()
631
632 for commentsRows.Next() {
633 var comment PullComment
634 var commentCreatedStr string
635 err := commentsRows.Scan(
636 &comment.ID,
637 &comment.PullId,
638 &comment.SubmissionId,
639 &comment.RepoAt,
640 &comment.OwnerDid,
641 &comment.CommentAt,
642 &comment.Body,
643 &commentCreatedStr,
644 )
645 if err != nil {
646 return nil, err
647 }
648
649 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
650 if err != nil {
651 return nil, err
652 }
653 comment.Created = commentCreatedTime
654
655 // Add the comment to its submission
656 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
657 submission.Comments = append(submission.Comments, comment)
658 }
659
660 }
661 if err = commentsRows.Err(); err != nil {
662 return nil, err
663 }
664
665 var pullSourceRepo *Repo
666 if pull.PullSource != nil {
667 if pull.PullSource.RepoAt != nil {
668 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
669 if err != nil {
670 log.Printf("failed to get repo by at uri: %v", err)
671 } else {
672 pull.PullSource.Repo = pullSourceRepo
673 }
674 }
675 }
676
677 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
678 for _, submission := range submissionsMap {
679 pull.Submissions[submission.RoundNumber] = submission
680 }
681
682 return &pull, nil
683}
684
685// timeframe here is directly passed into the sql query filter, and any
686// timeframe in the past should be negative; e.g.: "-3 months"
687func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
688 var pulls []Pull
689
690 rows, err := e.Query(`
691 select
692 p.owner_did,
693 p.repo_at,
694 p.pull_id,
695 p.created,
696 p.title,
697 p.state,
698 r.did,
699 r.name,
700 r.knot,
701 r.rkey,
702 r.created
703 from
704 pulls p
705 join
706 repos r on p.repo_at = r.at_uri
707 where
708 p.owner_did = ? and p.created >= date ('now', ?)
709 order by
710 p.created desc`, did, timeframe)
711 if err != nil {
712 return nil, err
713 }
714 defer rows.Close()
715
716 for rows.Next() {
717 var pull Pull
718 var repo Repo
719 var pullCreatedAt, repoCreatedAt string
720 err := rows.Scan(
721 &pull.OwnerDid,
722 &pull.RepoAt,
723 &pull.PullId,
724 &pullCreatedAt,
725 &pull.Title,
726 &pull.State,
727 &repo.Did,
728 &repo.Name,
729 &repo.Knot,
730 &repo.Rkey,
731 &repoCreatedAt,
732 )
733 if err != nil {
734 return nil, err
735 }
736
737 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
738 if err != nil {
739 return nil, err
740 }
741 pull.Created = pullCreatedTime
742
743 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
744 if err != nil {
745 return nil, err
746 }
747 repo.Created = repoCreatedTime
748
749 pull.Repo = &repo
750
751 pulls = append(pulls, pull)
752 }
753
754 if err := rows.Err(); err != nil {
755 return nil, err
756 }
757
758 return pulls, nil
759}
760
761func NewPullComment(e Execer, comment *PullComment) (int64, error) {
762 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
763 res, err := e.Exec(
764 query,
765 comment.OwnerDid,
766 comment.RepoAt,
767 comment.SubmissionId,
768 comment.CommentAt,
769 comment.PullId,
770 comment.Body,
771 )
772 if err != nil {
773 return 0, err
774 }
775
776 i, err := res.LastInsertId()
777 if err != nil {
778 return 0, err
779 }
780
781 return i, nil
782}
783
784func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
785 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
786 return err
787}
788
789func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
790 err := SetPullState(e, repoAt, pullId, PullClosed)
791 return err
792}
793
794func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
795 err := SetPullState(e, repoAt, pullId, PullOpen)
796 return err
797}
798
799func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
800 err := SetPullState(e, repoAt, pullId, PullMerged)
801 return err
802}
803
804func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
805 newRoundNumber := len(pull.Submissions)
806 _, err := e.Exec(`
807 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
808 values (?, ?, ?, ?, ?)
809 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
810
811 return err
812}
813
// PullCount holds the number of pulls of one repository in each state, as
// computed by GetPullCount.
type PullCount struct {
	Open   int
	Merged int
	Closed int
}
819
820func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
821 row := e.QueryRow(`
822 select
823 count(case when state = ? then 1 end) as open_count,
824 count(case when state = ? then 1 end) as merged_count,
825 count(case when state = ? then 1 end) as closed_count
826 from pulls
827 where repo_at = ?`,
828 PullOpen,
829 PullMerged,
830 PullClosed,
831 repoAt,
832 )
833
834 var count PullCount
835 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
836 return PullCount{0, 0, 0}, err
837 }
838
839 return count, nil
840}