this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "sort" 8 "strings" 9 "time" 10 11 "github.com/bluekeyes/go-gitdiff/gitdiff" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/api/tangled" 14 "tangled.sh/tangled.sh/core/patchutil" 15 "tangled.sh/tangled.sh/core/types" 16) 17 18type PullState int 19 20const ( 21 PullClosed PullState = iota 22 PullOpen 23 PullMerged 24) 25 26func (p PullState) String() string { 27 switch p { 28 case PullOpen: 29 return "open" 30 case PullMerged: 31 return "merged" 32 case PullClosed: 33 return "closed" 34 default: 35 return "closed" 36 } 37} 38 39func (p PullState) IsOpen() bool { 40 return p == PullOpen 41} 42func (p PullState) IsMerged() bool { 43 return p == PullMerged 44} 45func (p PullState) IsClosed() bool { 46 return p == PullClosed 47} 48 49type Pull struct { 50 // ids 51 ID int 52 PullId int 53 54 // at ids 55 RepoAt syntax.ATURI 56 OwnerDid string 57 Rkey string 58 59 // content 60 Title string 61 Body string 62 TargetBranch string 63 State PullState 64 Submissions []*PullSubmission 65 66 // stacking 67 StackId string // nullable string 68 ChangeId string // nullable string 69 ParentChangeId string // nullable string 70 71 // meta 72 Created time.Time 73 PullSource *PullSource 74 75 // optionally, populate this when querying for reverse mappings 76 Repo *Repo 77} 78 79type PullSource struct { 80 Branch string 81 RepoAt *syntax.ATURI 82 83 // optionally populate this for reverse mappings 84 Repo *Repo 85} 86 87type PullSubmission struct { 88 // ids 89 ID int 90 PullId int 91 92 // at ids 93 RepoAt syntax.ATURI 94 95 // content 96 RoundNumber int 97 Patch string 98 Comments []PullComment 99 SourceRev string // include the rev that was used to create this submission: only for branch PRs 100 101 // meta 102 Created time.Time 103} 104 105type PullComment struct { 106 // ids 107 ID int 108 PullId int 109 SubmissionId int 110 111 // at ids 112 RepoAt string 113 OwnerDid string 114 CommentAt string 115 116 // content 117 Body string 118 119 // meta 120 Created time.Time 121} 122 123func (p *Pull) LatestPatch() string { 124 latestSubmission := p.Submissions[p.LastRoundNumber()] 125 return latestSubmission.Patch 126} 127 128func (p *Pull) PullAt() syntax.ATURI { 129 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 130} 131 132func (p *Pull) LastRoundNumber() int { 133 return len(p.Submissions) - 1 134} 135 136func (p *Pull) IsPatchBased() bool { 137 return p.PullSource == nil 138} 139 140func (p *Pull) IsBranchBased() bool { 141 if p.PullSource != nil { 142 if p.PullSource.RepoAt != nil { 143 return p.PullSource.RepoAt == &p.RepoAt 144 } else { 145 // no repo specified 146 return true 147 } 148 } 149 return false 150} 151 152func (p *Pull) IsForkBased() bool { 153 if p.PullSource != nil { 154 if p.PullSource.RepoAt != nil { 155 // make sure repos are different 156 return p.PullSource.RepoAt != &p.RepoAt 157 } 158 } 159 return false 160} 161 162func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) { 163 patch := s.Patch 164 165 // if format-patch; then extract each patch 166 var diffs []*gitdiff.File 167 if patchutil.IsFormatPatch(patch) { 168 patches, err := patchutil.ExtractPatches(patch) 169 if err != nil { 170 return nil, err 171 } 172 var ps [][]*gitdiff.File 173 for _, p := range patches { 174 ps = append(ps, p.Files) 175 } 176 177 diffs = patchutil.CombineDiff(ps...) 178 } else { 179 d, _, err := gitdiff.Parse(strings.NewReader(patch)) 180 if err != nil { 181 return nil, err 182 } 183 diffs = d 184 } 185 186 return diffs, nil 187} 188 189func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff { 190 diffs, err := s.AsDiff(targetBranch) 191 if err != nil { 192 log.Println(err) 193 } 194 195 nd := types.NiceDiff{} 196 nd.Commit.Parent = targetBranch 197 198 for _, d := range diffs { 199 ndiff := types.Diff{} 200 ndiff.Name.New = d.NewName 201 ndiff.Name.Old = d.OldName 202 ndiff.IsBinary = d.IsBinary 203 ndiff.IsNew = d.IsNew 204 ndiff.IsDelete = d.IsDelete 205 ndiff.IsCopy = d.IsCopy 206 ndiff.IsRename = d.IsRename 207 208 for _, tf := range d.TextFragments { 209 ndiff.TextFragments = append(ndiff.TextFragments, *tf) 210 for _, l := range tf.Lines { 211 switch l.Op { 212 case gitdiff.OpAdd: 213 nd.Stat.Insertions += 1 214 case gitdiff.OpDelete: 215 nd.Stat.Deletions += 1 216 } 217 } 218 } 219 220 nd.Diff = append(nd.Diff, ndiff) 221 } 222 223 nd.Stat.FilesChanged = len(diffs) 224 225 return nd 226} 227 228func (s PullSubmission) IsFormatPatch() bool { 229 return patchutil.IsFormatPatch(s.Patch) 230} 231 232func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch { 233 patches, err := patchutil.ExtractPatches(s.Patch) 234 if err != nil { 235 log.Println("error extracting patches from submission:", err) 236 return []patchutil.FormatPatch{} 237 } 238 239 return patches 240} 241 242func NewPull(tx *sql.Tx, pull *Pull) error { 243 _, err := tx.Exec(` 244 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 245 values (?, 1) 246 `, pull.RepoAt) 247 if err != nil { 248 return err 249 } 250 251 var nextId int 252 err = tx.QueryRow(` 253 update repo_pull_seqs 254 set next_pull_id = next_pull_id + 1 255 where repo_at = ? 256 returning next_pull_id - 1 257 `, pull.RepoAt).Scan(&nextId) 258 if err != nil { 259 return err 260 } 261 262 pull.PullId = nextId 263 pull.State = PullOpen 264 265 var sourceBranch, sourceRepoAt *string 266 if pull.PullSource != nil { 267 sourceBranch = &pull.PullSource.Branch 268 if pull.PullSource.RepoAt != nil { 269 x := pull.PullSource.RepoAt.String() 270 sourceRepoAt = &x 271 } 272 } 273 274 var stackId, changeId, parentChangeId *string 275 if pull.StackId != "" { 276 stackId = &pull.StackId 277 } 278 if pull.ChangeId != "" { 279 changeId = &pull.ChangeId 280 } 281 if pull.ParentChangeId != "" { 282 parentChangeId = &pull.ParentChangeId 283 } 284 285 _, err = tx.Exec( 286 ` 287 insert into pulls ( 288 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 289 ) 290 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 291 pull.RepoAt, 292 pull.OwnerDid, 293 pull.PullId, 294 pull.Title, 295 pull.TargetBranch, 296 pull.Body, 297 pull.Rkey, 298 pull.State, 299 sourceBranch, 300 sourceRepoAt, 301 stackId, 302 changeId, 303 parentChangeId, 304 ) 305 if err != nil { 306 return err 307 } 308 309 _, err = tx.Exec(` 310 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 311 values (?, ?, ?, ?, ?) 312 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 313 return err 314} 315 316func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 317 pull, err := GetPull(e, repoAt, pullId) 318 if err != nil { 319 return "", err 320 } 321 return pull.PullAt(), err 322} 323 324func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 325 var pullId int 326 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 327 return pullId - 1, err 328} 329 330func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) { 331 pulls := make(map[int]*Pull) 332 333 rows, err := e.Query(` 334 select 335 owner_did, 336 pull_id, 337 created, 338 title, 339 state, 340 target_branch, 341 body, 342 rkey, 343 source_branch, 344 source_repo_at 345 from 346 pulls 347 where 348 repo_at = ? and state = ?`, repoAt, state) 349 if err != nil { 350 return nil, err 351 } 352 defer rows.Close() 353 354 for rows.Next() { 355 var pull Pull 356 var createdAt string 357 var sourceBranch, sourceRepoAt sql.NullString 358 err := rows.Scan( 359 &pull.OwnerDid, 360 &pull.PullId, 361 &createdAt, 362 &pull.Title, 363 &pull.State, 364 &pull.TargetBranch, 365 &pull.Body, 366 &pull.Rkey, 367 &sourceBranch, 368 &sourceRepoAt, 369 ) 370 if err != nil { 371 return nil, err 372 } 373 374 createdTime, err := time.Parse(time.RFC3339, createdAt) 375 if err != nil { 376 return nil, err 377 } 378 pull.Created = createdTime 379 380 if sourceBranch.Valid { 381 pull.PullSource = &PullSource{ 382 Branch: sourceBranch.String, 383 } 384 if sourceRepoAt.Valid { 385 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 386 if err != nil { 387 return nil, err 388 } 389 pull.PullSource.RepoAt = &sourceRepoAtParsed 390 } 391 } 392 393 pulls[pull.PullId] = &pull 394 } 395 396 // get latest round no. for each pull 397 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 398 submissionsQuery := fmt.Sprintf(` 399 select 400 id, pull_id, round_number 401 from 402 pull_submissions 403 where 404 repo_at = ? and pull_id in (%s) 405 `, inClause) 406 407 args := make([]any, len(pulls)+1) 408 args[0] = repoAt.String() 409 idx := 1 410 for _, p := range pulls { 411 args[idx] = p.PullId 412 idx += 1 413 } 414 submissionsRows, err := e.Query(submissionsQuery, args...) 415 if err != nil { 416 return nil, err 417 } 418 defer submissionsRows.Close() 419 420 for submissionsRows.Next() { 421 var s PullSubmission 422 err := submissionsRows.Scan( 423 &s.ID, 424 &s.PullId, 425 &s.RoundNumber, 426 ) 427 if err != nil { 428 return nil, err 429 } 430 431 if p, ok := pulls[s.PullId]; ok { 432 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 433 p.Submissions[s.RoundNumber] = &s 434 } 435 } 436 if err := rows.Err(); err != nil { 437 return nil, err 438 } 439 440 // get comment count on latest submission on each pull 441 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 442 commentsQuery := fmt.Sprintf(` 443 select 444 count(id), pull_id 445 from 446 pull_comments 447 where 448 submission_id in (%s) 449 group by 450 submission_id 451 `, inClause) 452 453 args = []any{} 454 for _, p := range pulls { 455 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 456 } 457 commentsRows, err := e.Query(commentsQuery, args...) 458 if err != nil { 459 return nil, err 460 } 461 defer commentsRows.Close() 462 463 for commentsRows.Next() { 464 var commentCount, pullId int 465 err := commentsRows.Scan( 466 &commentCount, 467 &pullId, 468 ) 469 if err != nil { 470 return nil, err 471 } 472 if p, ok := pulls[pullId]; ok { 473 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 474 } 475 } 476 if err := rows.Err(); err != nil { 477 return nil, err 478 } 479 480 orderedByDate := []*Pull{} 481 for _, p := range pulls { 482 orderedByDate = append(orderedByDate, p) 483 } 484 sort.Slice(orderedByDate, func(i, j int) bool { 485 return orderedByDate[i].Created.After(orderedByDate[j].Created) 486 }) 487 488 return orderedByDate, nil 489} 490 491func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 492 query := ` 493 select 494 owner_did, 495 pull_id, 496 created, 497 title, 498 state, 499 target_branch, 500 repo_at, 501 body, 502 rkey, 503 source_branch, 504 source_repo_at, 505 stack_id, 506 change_id, 507 parent_change_id 508 from 509 pulls 510 where 511 repo_at = ? and pull_id = ? 512 ` 513 row := e.QueryRow(query, repoAt, pullId) 514 515 var pull Pull 516 var createdAt string 517 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 518 err := row.Scan( 519 &pull.OwnerDid, 520 &pull.PullId, 521 &createdAt, 522 &pull.Title, 523 &pull.State, 524 &pull.TargetBranch, 525 &pull.RepoAt, 526 &pull.Body, 527 &pull.Rkey, 528 &sourceBranch, 529 &sourceRepoAt, 530 &stackId, 531 &changeId, 532 &parentChangeId, 533 ) 534 if err != nil { 535 return nil, err 536 } 537 538 createdTime, err := time.Parse(time.RFC3339, createdAt) 539 if err != nil { 540 return nil, err 541 } 542 pull.Created = createdTime 543 544 // populate source 545 if sourceBranch.Valid { 546 pull.PullSource = &PullSource{ 547 Branch: sourceBranch.String, 548 } 549 if sourceRepoAt.Valid { 550 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 551 if err != nil { 552 return nil, err 553 } 554 pull.PullSource.RepoAt = &sourceRepoAtParsed 555 } 556 } 557 558 if stackId.Valid { 559 pull.StackId = stackId.String 560 } 561 if changeId.Valid { 562 pull.ChangeId = changeId.String 563 } 564 if parentChangeId.Valid { 565 pull.ParentChangeId = parentChangeId.String 566 } 567 568 submissionsQuery := ` 569 select 570 id, pull_id, repo_at, round_number, patch, created, source_rev 571 from 572 pull_submissions 573 where 574 repo_at = ? and pull_id = ? 575 ` 576 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 577 if err != nil { 578 return nil, err 579 } 580 defer submissionsRows.Close() 581 582 submissionsMap := make(map[int]*PullSubmission) 583 584 for submissionsRows.Next() { 585 var submission PullSubmission 586 var submissionCreatedStr string 587 var submissionSourceRev sql.NullString 588 err := submissionsRows.Scan( 589 &submission.ID, 590 &submission.PullId, 591 &submission.RepoAt, 592 &submission.RoundNumber, 593 &submission.Patch, 594 &submissionCreatedStr, 595 &submissionSourceRev, 596 ) 597 if err != nil { 598 return nil, err 599 } 600 601 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 602 if err != nil { 603 return nil, err 604 } 605 submission.Created = submissionCreatedTime 606 607 if submissionSourceRev.Valid { 608 submission.SourceRev = submissionSourceRev.String 609 } 610 611 submissionsMap[submission.ID] = &submission 612 } 613 if err = submissionsRows.Close(); err != nil { 614 return nil, err 615 } 616 if len(submissionsMap) == 0 { 617 return &pull, nil 618 } 619 620 var args []any 621 for k := range submissionsMap { 622 args = append(args, k) 623 } 624 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 625 commentsQuery := fmt.Sprintf(` 626 select 627 id, 628 pull_id, 629 submission_id, 630 repo_at, 631 owner_did, 632 comment_at, 633 body, 634 created 635 from 636 pull_comments 637 where 638 submission_id IN (%s) 639 order by 640 created asc 641 `, inClause) 642 commentsRows, err := e.Query(commentsQuery, args...) 643 if err != nil { 644 return nil, err 645 } 646 defer commentsRows.Close() 647 648 for commentsRows.Next() { 649 var comment PullComment 650 var commentCreatedStr string 651 err := commentsRows.Scan( 652 &comment.ID, 653 &comment.PullId, 654 &comment.SubmissionId, 655 &comment.RepoAt, 656 &comment.OwnerDid, 657 &comment.CommentAt, 658 &comment.Body, 659 &commentCreatedStr, 660 ) 661 if err != nil { 662 return nil, err 663 } 664 665 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 666 if err != nil { 667 return nil, err 668 } 669 comment.Created = commentCreatedTime 670 671 // Add the comment to its submission 672 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 673 submission.Comments = append(submission.Comments, comment) 674 } 675 676 } 677 if err = commentsRows.Err(); err != nil { 678 return nil, err 679 } 680 681 var pullSourceRepo *Repo 682 if pull.PullSource != nil { 683 if pull.PullSource.RepoAt != nil { 684 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 685 if err != nil { 686 log.Printf("failed to get repo by at uri: %v", err) 687 } else { 688 pull.PullSource.Repo = pullSourceRepo 689 } 690 } 691 } 692 693 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 694 for _, submission := range submissionsMap { 695 pull.Submissions[submission.RoundNumber] = submission 696 } 697 698 return &pull, nil 699} 700 701// timeframe here is directly passed into the sql query filter, and any 702// timeframe in the past should be negative; e.g.: "-3 months" 703func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 704 var pulls []Pull 705 706 rows, err := e.Query(` 707 select 708 p.owner_did, 709 p.repo_at, 710 p.pull_id, 711 p.created, 712 p.title, 713 p.state, 714 r.did, 715 r.name, 716 r.knot, 717 r.rkey, 718 r.created 719 from 720 pulls p 721 join 722 repos r on p.repo_at = r.at_uri 723 where 724 p.owner_did = ? and p.created >= date ('now', ?) 725 order by 726 p.created desc`, did, timeframe) 727 if err != nil { 728 return nil, err 729 } 730 defer rows.Close() 731 732 for rows.Next() { 733 var pull Pull 734 var repo Repo 735 var pullCreatedAt, repoCreatedAt string 736 err := rows.Scan( 737 &pull.OwnerDid, 738 &pull.RepoAt, 739 &pull.PullId, 740 &pullCreatedAt, 741 &pull.Title, 742 &pull.State, 743 &repo.Did, 744 &repo.Name, 745 &repo.Knot, 746 &repo.Rkey, 747 &repoCreatedAt, 748 ) 749 if err != nil { 750 return nil, err 751 } 752 753 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 754 if err != nil { 755 return nil, err 756 } 757 pull.Created = pullCreatedTime 758 759 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 760 if err != nil { 761 return nil, err 762 } 763 repo.Created = repoCreatedTime 764 765 pull.Repo = &repo 766 767 pulls = append(pulls, pull) 768 } 769 770 if err := rows.Err(); err != nil { 771 return nil, err 772 } 773 774 return pulls, nil 775} 776 777func NewPullComment(e Execer, comment *PullComment) (int64, error) { 778 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 779 res, err := e.Exec( 780 query, 781 comment.OwnerDid, 782 comment.RepoAt, 783 comment.SubmissionId, 784 comment.CommentAt, 785 comment.PullId, 786 comment.Body, 787 ) 788 if err != nil { 789 return 0, err 790 } 791 792 i, err := res.LastInsertId() 793 if err != nil { 794 return 0, err 795 } 796 797 return i, nil 798} 799 800func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 801 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId) 802 return err 803} 804 805func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 806 err := SetPullState(e, repoAt, pullId, PullClosed) 807 return err 808} 809 810func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 811 err := SetPullState(e, repoAt, pullId, PullOpen) 812 return err 813} 814 815func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 816 err := SetPullState(e, repoAt, pullId, PullMerged) 817 return err 818} 819 820func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 821 newRoundNumber := len(pull.Submissions) 822 _, err := e.Exec(` 823 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 824 values (?, ?, ?, ?, ?) 825 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 826 827 return err 828} 829 830type PullCount struct { 831 Open int 832 Merged int 833 Closed int 834} 835 836func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 837 row := e.QueryRow(` 838 select 839 count(case when state = ? then 1 end) as open_count, 840 count(case when state = ? then 1 end) as merged_count, 841 count(case when state = ? then 1 end) as closed_count 842 from pulls 843 where repo_at = ?`, 844 PullOpen, 845 PullMerged, 846 PullClosed, 847 repoAt, 848 ) 849 850 var count PullCount 851 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil { 852 return PullCount{0, 0, 0}, err 853 } 854 855 return count, nil 856}