package db

import (
	"database/sql"
	"fmt"
	"strings"
	"time"

	"github.com/bluesky-social/indigo/atproto/syntax"
	"tangled.org/core/appview/models"
	"tangled.org/core/appview/pagination"
)

// NewPullRound stores a new submission (round) of a pull request.
//
// Pass pullId == 0 to open a new pull: the next per-repo pull id is taken
// from repo_pull_seqs (creating the sequence row on first use) and a pulls2
// row is inserted in the models.PullOpen state before the round is stored.
// A non-zero pullId only skips that step; the value itself is not otherwise
// used, the round is linked through round.Did and round.Rkey.
//
// On success round.Id and round.Round are filled in from pull_rounds_view
// and round.References are stored through putReferences.
func NewPullRound(tx *sql.Tx, pullId int, round *models.PullRound) error {
	// Create a new PR when pullId isn't set.
	if pullId == 0 {
		// Ensure the per-repo sequence exists.
		_, err := tx.Exec(`
			insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
			values (?, 1)
		`, round.Target.RepoAt)
		if err != nil {
			return fmt.Errorf("ensure pull sequence: %w", err)
		}

		// Bump the sequence and read back the value it held before the bump.
		err = tx.QueryRow(`
			update repo_pull_seqs
			set next_pull_id = next_pull_id + 1
			where repo_at = ?
			returning next_pull_id - 1
		`, round.Target.RepoAt).Scan(&pullId)
		if err != nil {
			return fmt.Errorf("allocate pull id: %w", err)
		}

		_, err = tx.Exec(
			`insert into pulls2 (pull_id, did, rkey, state) values (?, ?, ?, ?)`,
			pullId,
			round.Did,
			round.Rkey,
			models.PullOpen,
		)
		if err != nil {
			return fmt.Errorf("insert pull: %w", err)
		}
	}

	// A nil source is stored as NULL in both source columns.
	var sourceRepoAt, sourceBranch *string
	if round.Source != nil {
		x := round.Source.RepoAt.String()
		sourceRepoAt = &x
		sourceBranch = &round.Source.Branch
	}

	_, err := tx.Exec(
		`insert into pull_rounds (
			pull_did,
			pull_rkey,
			cid,
			target_repo_at,
			target_branch,
			source_repo_at,
			source_branch,
			patch,
			title,
			body,
			created
		) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		round.Did,
		round.Rkey,
		round.Cid,
		round.Target.RepoAt,
		round.Target.Branch,
		sourceRepoAt,
		sourceBranch,
		round.Patch,
		round.Title,
		round.Body,
		round.Created.Format(time.RFC3339),
	)
	if err != nil {
		return fmt.Errorf("insert pull submission: %w", err)
	}

	err = tx.QueryRow(
		`select id, round from pull_rounds_view
		where did = ? and rkey = ? and cid = ?`,
		round.Did,
		round.Rkey,
		round.Cid,
	).Scan(&round.Id, &round.Round)
	if err != nil {
		return fmt.Errorf("get id and round number: %w", err)
	}

	if err := putReferences(tx, round.AtUri(), &round.Cid, round.References); err != nil {
		return fmt.Errorf("put reference_links: %w", err)
	}

	return nil
}

// SetPullState2 sets the state of the pull identified by pullAt.
//
// Deleted and merged pulls are left untouched: the update only matches rows
// whose current state is neither models.PullDeleted nor models.PullMerged.
// No error is reported when nothing matches.
func SetPullState2(e Execer, pullAt syntax.ATURI, state models.PullState) error {
	_, err := e.Exec(
		// Both exclusions must hold, so they are joined with "and"; joining
		// them with "or" is always true and would guard nothing.
		`update pulls2 set state = ? where at_uri = ? and (state <> ? and state <> ?)`,
		state,
		pullAt,
		models.PullDeleted, // only update state of non-deleted pulls
		models.PullMerged,  // only update state of non-merged pulls
	)
	return err
}

// ClosePull2 marks the pull at pullAt as closed; see SetPullState2.
func ClosePull2(e Execer, pullAt syntax.ATURI) error {
	return SetPullState2(e, pullAt, models.PullClosed)
}

// ReopenPull2 marks the pull at pullAt as open; see SetPullState2.
func ReopenPull2(e Execer, pullAt syntax.ATURI) error {
	return SetPullState2(e, pullAt, models.PullOpen)
}

// MergePull2 marks the pull at pullAt as merged; see SetPullState2.
func MergePull2(e Execer, pullAt syntax.ATURI) error {
	return SetPullState2(e, pullAt, models.PullMerged)
}

// DeletePull2 marks the pull at pullAt as deleted; see SetPullState2.
func DeletePull2(e Execer, pullAt syntax.ATURI) error {
	return SetPullState2(e, pullAt, models.PullDeleted)
}

// GetPulls2 is GetPullsPaginated with a zero pagination.Page, which applies
// no paging.
func GetPulls2(e Execer, filters ...filter) ([]*models.Pull2, error) {
	return GetPullsPaginated(e, pagination.Page{}, filters...)
}

// GetPullsPaginated loads rows from pull_rounds that match filters, newest id
// first, and groups them by pull AT-URI.
//
// Filter conditions are joined with "and". When page.Limit > 0 only the rows
// numbered page.Offset+1 through page.Offset+page.Limit are read; note that
// these numbers count rounds, not pulls.
//
// NOTE: the function is unfinished. Query, scan and iteration failures are
// returned as errors, but after a successful read it panics with
// "unimplemented" rather than returning the grouped pulls.
func GetPullsPaginated(e Execer, page pagination.Page, filters ...filter) ([]*models.Pull2, error) {
	pullsMap := make(map[syntax.ATURI]*models.Pull2)

	var conditions []string
	var args []any
	for _, f := range filters {
		conditions = append(conditions, f.Condition())
		args = append(args, f.Arg()...)
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = " where " + strings.Join(conditions, " and ")
	}

	// Paging arguments go after the filter arguments, matching the order of
	// the placeholders in the final query.
	pageClause := ""
	if page.Limit > 0 {
		pLower := FilterGte("row_num", page.Offset+1)
		pUpper := FilterLte("row_num", page.Offset+page.Limit)
		args = append(args, pLower.Arg()...)
		args = append(args, pUpper.Arg()...)
		pageClause = " where " + pLower.Condition() + " and " + pUpper.Condition()
	}

	query := fmt.Sprintf(
		`select * from (
			select
				id,
				did,
				rkey,
				cid,
				target_repo_at,
				target_branch,
				source_repo_at,
				source_branch,
				patch,
				title,
				body,
				created,
				row_number() over (order by id desc) as row_num
			from
				pull_rounds
			%s
		) ranked_pull_rounds
		%s`,
		whereClause,
		pageClause,
	)

	rows, err := e.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("query pulls table: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var round models.PullRound
		var sourceRepoAt, sourceBranch sql.NullString
		var created string
		// "select *" also yields row_num, so it needs a destination even
		// though the value is not used.
		var rowNum int64
		err := rows.Scan(
			&round.Id,
			&round.Did,
			&round.Rkey,
			&round.Cid,
			&round.Target.RepoAt,
			&round.Target.Branch,
			// Destinations follow the select list: repo first, then branch.
			&sourceRepoAt,
			&sourceBranch,
			&round.Patch,
			&round.Title,
			&round.Body,
			&created,
			&rowNum,
		)
		if err != nil {
			return nil, fmt.Errorf("scan row: %w", err)
		}

		if sourceBranch.Valid && sourceRepoAt.Valid {
			round.Source = &models.PullSource2{}
			round.Source.Branch = sourceBranch.String
			round.Source.RepoAt = syntax.ATURI(sourceRepoAt.String)
		}

		// Best effort: an unparsable timestamp leaves Created at its zero
		// value instead of failing the whole listing.
		createdAtTime, _ := time.Parse(time.RFC3339, created)
		round.Created = createdAtTime

		pullAt := round.AtUri()
		pull, ok := pullsMap[pullAt]
		if !ok {
			pull = &models.Pull2{
				Did:  round.Did,
				Rkey: round.Rkey,
			}
			pullsMap[pullAt] = pull
		}
		pull.Submissions = append(pull.Submissions, &round)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate rows: %w", err)
	}

	// TODO: fetch pulls (id, pull_id, state) from (did, rkey)
	panic("unimplemented")
}