this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "sort" 8 "strings" 9 "time" 10 11 "github.com/bluekeyes/go-gitdiff/gitdiff" 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/api/tangled" 14 "tangled.sh/tangled.sh/core/patchutil" 15 "tangled.sh/tangled.sh/core/types" 16) 17 18type PullState int 19 20const ( 21 PullClosed PullState = iota 22 PullOpen 23 PullMerged 24) 25 26func (p PullState) String() string { 27 switch p { 28 case PullOpen: 29 return "open" 30 case PullMerged: 31 return "merged" 32 case PullClosed: 33 return "closed" 34 default: 35 return "closed" 36 } 37} 38 39func (p PullState) IsOpen() bool { 40 return p == PullOpen 41} 42func (p PullState) IsMerged() bool { 43 return p == PullMerged 44} 45func (p PullState) IsClosed() bool { 46 return p == PullClosed 47} 48 49type Pull struct { 50 // ids 51 ID int 52 PullId int 53 54 // at ids 55 RepoAt syntax.ATURI 56 OwnerDid string 57 Rkey string 58 59 // content 60 Title string 61 Body string 62 TargetBranch string 63 State PullState 64 Submissions []*PullSubmission 65 66 // stacking 67 StackId string // nullable string 68 ChangeId string // nullable string 69 ParentChangeId string // nullable string 70 71 // meta 72 Created time.Time 73 PullSource *PullSource 74 75 // optionally, populate this when querying for reverse mappings 76 Repo *Repo 77} 78 79type PullSource struct { 80 Branch string 81 RepoAt *syntax.ATURI 82 83 // optionally populate this for reverse mappings 84 Repo *Repo 85} 86 87type PullSubmission struct { 88 // ids 89 ID int 90 PullId int 91 92 // at ids 93 RepoAt syntax.ATURI 94 95 // content 96 RoundNumber int 97 Patch string 98 Comments []PullComment 99 SourceRev string // include the rev that was used to create this submission: only for branch PRs 100 101 // meta 102 Created time.Time 103} 104 105type PullComment struct { 106 // ids 107 ID int 108 PullId int 109 SubmissionId int 110 111 // at ids 112 RepoAt string 113 OwnerDid string 114 CommentAt string 115 116 // content 117 Body string 118 119 // meta 120 Created time.Time 121} 122 123func (p *Pull) LatestPatch() string { 124 latestSubmission := p.Submissions[p.LastRoundNumber()] 125 return latestSubmission.Patch 126} 127 128func (p *Pull) PullAt() syntax.ATURI { 129 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 130} 131 132func (p *Pull) LastRoundNumber() int { 133 return len(p.Submissions) - 1 134} 135 136func (p *Pull) IsPatchBased() bool { 137 return p.PullSource == nil 138} 139 140func (p *Pull) IsBranchBased() bool { 141 if p.PullSource != nil { 142 if p.PullSource.RepoAt != nil { 143 return p.PullSource.RepoAt == &p.RepoAt 144 } else { 145 // no repo specified 146 return true 147 } 148 } 149 return false 150} 151 152func (p *Pull) IsForkBased() bool { 153 if p.PullSource != nil { 154 if p.PullSource.RepoAt != nil { 155 // make sure repos are different 156 return p.PullSource.RepoAt != &p.RepoAt 157 } 158 } 159 return false 160} 161 162func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) { 163 patch := s.Patch 164 165 // if format-patch; then extract each patch 166 var diffs []*gitdiff.File 167 if patchutil.IsFormatPatch(patch) { 168 patches, err := patchutil.ExtractPatches(patch) 169 if err != nil { 170 return nil, err 171 } 172 var ps [][]*gitdiff.File 173 for _, p := range patches { 174 ps = append(ps, p.Files) 175 } 176 177 diffs = patchutil.CombineDiff(ps...) 178 } else { 179 d, _, err := gitdiff.Parse(strings.NewReader(patch)) 180 if err != nil { 181 return nil, err 182 } 183 diffs = d 184 } 185 186 return diffs, nil 187} 188 189func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff { 190 diffs, err := s.AsDiff(targetBranch) 191 if err != nil { 192 log.Println(err) 193 } 194 195 nd := types.NiceDiff{} 196 nd.Commit.Parent = targetBranch 197 198 for _, d := range diffs { 199 ndiff := types.Diff{} 200 ndiff.Name.New = d.NewName 201 ndiff.Name.Old = d.OldName 202 ndiff.IsBinary = d.IsBinary 203 ndiff.IsNew = d.IsNew 204 ndiff.IsDelete = d.IsDelete 205 ndiff.IsCopy = d.IsCopy 206 ndiff.IsRename = d.IsRename 207 208 for _, tf := range d.TextFragments { 209 ndiff.TextFragments = append(ndiff.TextFragments, *tf) 210 for _, l := range tf.Lines { 211 switch l.Op { 212 case gitdiff.OpAdd: 213 nd.Stat.Insertions += 1 214 case gitdiff.OpDelete: 215 nd.Stat.Deletions += 1 216 } 217 } 218 } 219 220 nd.Diff = append(nd.Diff, ndiff) 221 } 222 223 nd.Stat.FilesChanged = len(diffs) 224 225 return nd 226} 227 228func (s PullSubmission) IsFormatPatch() bool { 229 return patchutil.IsFormatPatch(s.Patch) 230} 231 232func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch { 233 patches, err := patchutil.ExtractPatches(s.Patch) 234 if err != nil { 235 log.Println("error extracting patches from submission:", err) 236 return []patchutil.FormatPatch{} 237 } 238 239 return patches 240} 241 242func NewPull(tx *sql.Tx, pull *Pull) error { 243 _, err := tx.Exec(` 244 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 245 values (?, 1) 246 `, pull.RepoAt) 247 if err != nil { 248 return err 249 } 250 251 var nextId int 252 err = tx.QueryRow(` 253 update repo_pull_seqs 254 set next_pull_id = next_pull_id + 1 255 where repo_at = ? 256 returning next_pull_id - 1 257 `, pull.RepoAt).Scan(&nextId) 258 if err != nil { 259 return err 260 } 261 262 pull.PullId = nextId 263 pull.State = PullOpen 264 265 var sourceBranch, sourceRepoAt *string 266 if pull.PullSource != nil { 267 sourceBranch = &pull.PullSource.Branch 268 if pull.PullSource.RepoAt != nil { 269 x := pull.PullSource.RepoAt.String() 270 sourceRepoAt = &x 271 } 272 } 273 274 _, err = tx.Exec( 275 ` 276 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at) 277 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 278 pull.RepoAt, 279 pull.OwnerDid, 280 pull.PullId, 281 pull.Title, 282 pull.TargetBranch, 283 pull.Body, 284 pull.Rkey, 285 pull.State, 286 sourceBranch, 287 sourceRepoAt, 288 ) 289 if err != nil { 290 return err 291 } 292 293 _, err = tx.Exec(` 294 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 295 values (?, ?, ?, ?, ?) 296 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 297 return err 298} 299 300func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 301 pull, err := GetPull(e, repoAt, pullId) 302 if err != nil { 303 return "", err 304 } 305 return pull.PullAt(), err 306} 307 308func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 309 var pullId int 310 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 311 return pullId - 1, err 312} 313 314func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]*Pull, error) { 315 pulls := make(map[int]*Pull) 316 317 rows, err := e.Query(` 318 select 319 owner_did, 320 pull_id, 321 created, 322 title, 323 state, 324 target_branch, 325 body, 326 rkey, 327 source_branch, 328 source_repo_at 329 from 330 pulls 331 where 332 repo_at = ? and state = ?`, repoAt, state) 333 if err != nil { 334 return nil, err 335 } 336 defer rows.Close() 337 338 for rows.Next() { 339 var pull Pull 340 var createdAt string 341 var sourceBranch, sourceRepoAt sql.NullString 342 err := rows.Scan( 343 &pull.OwnerDid, 344 &pull.PullId, 345 &createdAt, 346 &pull.Title, 347 &pull.State, 348 &pull.TargetBranch, 349 &pull.Body, 350 &pull.Rkey, 351 &sourceBranch, 352 &sourceRepoAt, 353 ) 354 if err != nil { 355 return nil, err 356 } 357 358 createdTime, err := time.Parse(time.RFC3339, createdAt) 359 if err != nil { 360 return nil, err 361 } 362 pull.Created = createdTime 363 364 if sourceBranch.Valid { 365 pull.PullSource = &PullSource{ 366 Branch: sourceBranch.String, 367 } 368 if sourceRepoAt.Valid { 369 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 370 if err != nil { 371 return nil, err 372 } 373 pull.PullSource.RepoAt = &sourceRepoAtParsed 374 } 375 } 376 377 pulls[pull.PullId] = &pull 378 } 379 380 // get latest round no. for each pull 381 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 382 submissionsQuery := fmt.Sprintf(` 383 select 384 id, pull_id, round_number 385 from 386 pull_submissions 387 where 388 repo_at = ? and pull_id in (%s) 389 `, inClause) 390 391 args := make([]any, len(pulls)+1) 392 args[0] = repoAt.String() 393 idx := 1 394 for _, p := range pulls { 395 args[idx] = p.PullId 396 idx += 1 397 } 398 submissionsRows, err := e.Query(submissionsQuery, args...) 399 if err != nil { 400 return nil, err 401 } 402 defer submissionsRows.Close() 403 404 for submissionsRows.Next() { 405 var s PullSubmission 406 err := submissionsRows.Scan( 407 &s.ID, 408 &s.PullId, 409 &s.RoundNumber, 410 ) 411 if err != nil { 412 return nil, err 413 } 414 415 if p, ok := pulls[s.PullId]; ok { 416 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 417 p.Submissions[s.RoundNumber] = &s 418 } 419 } 420 if err := rows.Err(); err != nil { 421 return nil, err 422 } 423 424 // get comment count on latest submission on each pull 425 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 426 commentsQuery := fmt.Sprintf(` 427 select 428 count(id), pull_id 429 from 430 pull_comments 431 where 432 submission_id in (%s) 433 group by 434 submission_id 435 `, inClause) 436 437 args = []any{} 438 for _, p := range pulls { 439 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 440 } 441 commentsRows, err := e.Query(commentsQuery, args...) 442 if err != nil { 443 return nil, err 444 } 445 defer commentsRows.Close() 446 447 for commentsRows.Next() { 448 var commentCount, pullId int 449 err := commentsRows.Scan( 450 &commentCount, 451 &pullId, 452 ) 453 if err != nil { 454 return nil, err 455 } 456 if p, ok := pulls[pullId]; ok { 457 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 458 } 459 } 460 if err := rows.Err(); err != nil { 461 return nil, err 462 } 463 464 orderedByDate := []*Pull{} 465 for _, p := range pulls { 466 orderedByDate = append(orderedByDate, p) 467 } 468 sort.Slice(orderedByDate, func(i, j int) bool { 469 return orderedByDate[i].Created.After(orderedByDate[j].Created) 470 }) 471 472 return orderedByDate, nil 473} 474 475func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 476 query := ` 477 select 478 owner_did, 479 pull_id, 480 created, 481 title, 482 state, 483 target_branch, 484 repo_at, 485 body, 486 rkey, 487 source_branch, 488 source_repo_at, 489 stack_id, 490 change_id, 491 parent_change_id 492 from 493 pulls 494 where 495 repo_at = ? and pull_id = ? 496 ` 497 row := e.QueryRow(query, repoAt, pullId) 498 499 var pull Pull 500 var createdAt string 501 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 502 err := row.Scan( 503 &pull.OwnerDid, 504 &pull.PullId, 505 &createdAt, 506 &pull.Title, 507 &pull.State, 508 &pull.TargetBranch, 509 &pull.RepoAt, 510 &pull.Body, 511 &pull.Rkey, 512 &sourceBranch, 513 &sourceRepoAt, 514 &stackId, 515 &changeId, 516 &parentChangeId, 517 ) 518 if err != nil { 519 return nil, err 520 } 521 522 createdTime, err := time.Parse(time.RFC3339, createdAt) 523 if err != nil { 524 return nil, err 525 } 526 pull.Created = createdTime 527 528 // populate source 529 if sourceBranch.Valid { 530 pull.PullSource = &PullSource{ 531 Branch: sourceBranch.String, 532 } 533 if sourceRepoAt.Valid { 534 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 535 if err != nil { 536 return nil, err 537 } 538 pull.PullSource.RepoAt = &sourceRepoAtParsed 539 } 540 } 541 542 if stackId.Valid { 543 pull.StackId = stackId.String 544 } 545 if changeId.Valid { 546 pull.ChangeId = changeId.String 547 } 548 if parentChangeId.Valid { 549 pull.ParentChangeId = parentChangeId.String 550 } 551 552 submissionsQuery := ` 553 select 554 id, pull_id, repo_at, round_number, patch, created, source_rev 555 from 556 pull_submissions 557 where 558 repo_at = ? and pull_id = ? 559 ` 560 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 561 if err != nil { 562 return nil, err 563 } 564 defer submissionsRows.Close() 565 566 submissionsMap := make(map[int]*PullSubmission) 567 568 for submissionsRows.Next() { 569 var submission PullSubmission 570 var submissionCreatedStr string 571 var submissionSourceRev sql.NullString 572 err := submissionsRows.Scan( 573 &submission.ID, 574 &submission.PullId, 575 &submission.RepoAt, 576 &submission.RoundNumber, 577 &submission.Patch, 578 &submissionCreatedStr, 579 &submissionSourceRev, 580 ) 581 if err != nil { 582 return nil, err 583 } 584 585 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 586 if err != nil { 587 return nil, err 588 } 589 submission.Created = submissionCreatedTime 590 591 if submissionSourceRev.Valid { 592 submission.SourceRev = submissionSourceRev.String 593 } 594 595 submissionsMap[submission.ID] = &submission 596 } 597 if err = submissionsRows.Close(); err != nil { 598 return nil, err 599 } 600 if len(submissionsMap) == 0 { 601 return &pull, nil 602 } 603 604 var args []any 605 for k := range submissionsMap { 606 args = append(args, k) 607 } 608 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 609 commentsQuery := fmt.Sprintf(` 610 select 611 id, 612 pull_id, 613 submission_id, 614 repo_at, 615 owner_did, 616 comment_at, 617 body, 618 created 619 from 620 pull_comments 621 where 622 submission_id IN (%s) 623 order by 624 created asc 625 `, inClause) 626 commentsRows, err := e.Query(commentsQuery, args...) 627 if err != nil { 628 return nil, err 629 } 630 defer commentsRows.Close() 631 632 for commentsRows.Next() { 633 var comment PullComment 634 var commentCreatedStr string 635 err := commentsRows.Scan( 636 &comment.ID, 637 &comment.PullId, 638 &comment.SubmissionId, 639 &comment.RepoAt, 640 &comment.OwnerDid, 641 &comment.CommentAt, 642 &comment.Body, 643 &commentCreatedStr, 644 ) 645 if err != nil { 646 return nil, err 647 } 648 649 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 650 if err != nil { 651 return nil, err 652 } 653 comment.Created = commentCreatedTime 654 655 // Add the comment to its submission 656 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 657 submission.Comments = append(submission.Comments, comment) 658 } 659 660 } 661 if err = commentsRows.Err(); err != nil { 662 return nil, err 663 } 664 665 var pullSourceRepo *Repo 666 if pull.PullSource != nil { 667 if pull.PullSource.RepoAt != nil { 668 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 669 if err != nil { 670 log.Printf("failed to get repo by at uri: %v", err) 671 } else { 672 pull.PullSource.Repo = pullSourceRepo 673 } 674 } 675 } 676 677 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 678 for _, submission := range submissionsMap { 679 pull.Submissions[submission.RoundNumber] = submission 680 } 681 682 return &pull, nil 683} 684 685// timeframe here is directly passed into the sql query filter, and any 686// timeframe in the past should be negative; e.g.: "-3 months" 687func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 688 var pulls []Pull 689 690 rows, err := e.Query(` 691 select 692 p.owner_did, 693 p.repo_at, 694 p.pull_id, 695 p.created, 696 p.title, 697 p.state, 698 r.did, 699 r.name, 700 r.knot, 701 r.rkey, 702 r.created 703 from 704 pulls p 705 join 706 repos r on p.repo_at = r.at_uri 707 where 708 p.owner_did = ? and p.created >= date ('now', ?) 709 order by 710 p.created desc`, did, timeframe) 711 if err != nil { 712 return nil, err 713 } 714 defer rows.Close() 715 716 for rows.Next() { 717 var pull Pull 718 var repo Repo 719 var pullCreatedAt, repoCreatedAt string 720 err := rows.Scan( 721 &pull.OwnerDid, 722 &pull.RepoAt, 723 &pull.PullId, 724 &pullCreatedAt, 725 &pull.Title, 726 &pull.State, 727 &repo.Did, 728 &repo.Name, 729 &repo.Knot, 730 &repo.Rkey, 731 &repoCreatedAt, 732 ) 733 if err != nil { 734 return nil, err 735 } 736 737 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 738 if err != nil { 739 return nil, err 740 } 741 pull.Created = pullCreatedTime 742 743 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 744 if err != nil { 745 return nil, err 746 } 747 repo.Created = repoCreatedTime 748 749 pull.Repo = &repo 750 751 pulls = append(pulls, pull) 752 } 753 754 if err := rows.Err(); err != nil { 755 return nil, err 756 } 757 758 return pulls, nil 759} 760 761func NewPullComment(e Execer, comment *PullComment) (int64, error) { 762 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 763 res, err := e.Exec( 764 query, 765 comment.OwnerDid, 766 comment.RepoAt, 767 comment.SubmissionId, 768 comment.CommentAt, 769 comment.PullId, 770 comment.Body, 771 ) 772 if err != nil { 773 return 0, err 774 } 775 776 i, err := res.LastInsertId() 777 if err != nil { 778 return 0, err 779 } 780 781 return i, nil 782} 783 784func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 785 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId) 786 return err 787} 788 789func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 790 err := SetPullState(e, repoAt, pullId, PullClosed) 791 return err 792} 793 794func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 795 err := SetPullState(e, repoAt, pullId, PullOpen) 796 return err 797} 798 799func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 800 err := SetPullState(e, repoAt, pullId, PullMerged) 801 return err 802} 803 804func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 805 newRoundNumber := len(pull.Submissions) 806 _, err := e.Exec(` 807 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 808 values (?, ?, ?, ?, ?) 809 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 810 811 return err 812} 813 814type PullCount struct { 815 Open int 816 Merged int 817 Closed int 818} 819 820func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 821 row := e.QueryRow(` 822 select 823 count(case when state = ? then 1 end) as open_count, 824 count(case when state = ? then 1 end) as merged_count, 825 count(case when state = ? then 1 end) as closed_count 826 from pulls 827 where repo_at = ?`, 828 PullOpen, 829 PullMerged, 830 PullClosed, 831 repoAt, 832 ) 833 834 var count PullCount 835 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil { 836 return PullCount{0, 0, 0}, err 837 } 838 839 return count, nil 840}