this repo has no description
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "time"
8
9 "github.com/bluesky-social/indigo/atproto/syntax"
10)
11
// PullState is the lifecycle state of a pull request. Values are written to
// and filtered on the pulls.state column as plain integers, so the numeric
// values below are part of the stored data and must not change.
type PullState int

// Pull request states, with the same values the original iota sequence
// produced (closed = 0, open = 1, merged = 2).
const (
	PullClosed PullState = 0
	PullOpen   PullState = 1
	PullMerged PullState = 2
)

// String returns the lowercase name of the state. Any value outside the
// known set is reported as "closed".
func (p PullState) String() string {
	if p == PullOpen {
		return "open"
	}
	if p == PullMerged {
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether p is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether p is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether p is exactly PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }
42
43type Pull struct {
44 // ids
45 ID int
46 PullId int
47
48 // at ids
49 RepoAt syntax.ATURI
50 OwnerDid string
51 Rkey string
52 PullAt syntax.ATURI
53
54 // content
55 Title string
56 Body string
57 TargetBranch string
58 State PullState
59 Submissions []*PullSubmission
60
61 // meta
62 Created time.Time
63}
64
65type PullSubmission struct {
66 // ids
67 ID int
68 PullId int
69
70 // at ids
71 RepoAt syntax.ATURI
72
73 // content
74 RoundNumber int
75 Patch string
76 Comments []PullComment
77
78 // meta
79 Created time.Time
80}
81
// PullComment is a comment attached to one submission of a pull request, as
// stored in the pull_comments table.
type PullComment struct {
	// Identifiers. SubmissionId links the comment to a PullSubmission.ID.
	ID, PullId, SubmissionId int

	// AT Protocol identifiers, kept as plain strings.
	RepoAt, OwnerDid, CommentAt string

	// Content.
	Body string

	// Metadata.
	Created time.Time
}
99
100func (p *Pull) LatestPatch() string {
101 latestSubmission := p.Submissions[len(p.Submissions)-1]
102 return latestSubmission.Patch
103}
104
105func NewPull(tx *sql.Tx, pull *Pull) error {
106 defer tx.Rollback()
107
108 _, err := tx.Exec(`
109 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
110 values (?, 1)
111 `, pull.RepoAt)
112 if err != nil {
113 return err
114 }
115
116 var nextId int
117 err = tx.QueryRow(`
118 update repo_pull_seqs
119 set next_pull_id = next_pull_id + 1
120 where repo_at = ?
121 returning next_pull_id - 1
122 `, pull.RepoAt).Scan(&nextId)
123 if err != nil {
124 return err
125 }
126
127 pull.PullId = nextId
128 pull.State = PullOpen
129
130 _, err = tx.Exec(`
131 insert into pulls (repo_at, owner_did, pull_id, title, target_branch, body, rkey, state)
132 values (?, ?, ?, ?, ?, ?, ?, ?)
133 `, pull.RepoAt, pull.OwnerDid, pull.PullId, pull.Title, pull.TargetBranch, pull.Body, pull.Rkey, pull.State)
134 if err != nil {
135 return err
136 }
137
138 _, err = tx.Exec(`
139 insert into pull_submissions (pull_id, repo_at, round_number, patch)
140 values (?, ?, ?, ?)
141 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch)
142 if err != nil {
143 return err
144 }
145
146 if err := tx.Commit(); err != nil {
147 return err
148 }
149
150 return nil
151}
152
153func SetPullAt(e Execer, repoAt syntax.ATURI, pullId int, pullAt string) error {
154 _, err := e.Exec(`update pulls set pull_at = ? where repo_at = ? and pull_id = ?`, pullAt, repoAt, pullId)
155 return err
156}
157
158func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (string, error) {
159 var pullAt string
160 err := e.QueryRow(`select pull_at from pulls where repo_at = ? and pull_id = ?`, repoAt, pullId).Scan(&pullAt)
161 return pullAt, err
162}
163
164func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
165 var pullId int
166 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
167 return pullId - 1, err
168}
169
170func GetPulls(e Execer, repoAt syntax.ATURI, state PullState) ([]Pull, error) {
171 var pulls []Pull
172
173 rows, err := e.Query(`
174 select
175 owner_did,
176 pull_id,
177 created,
178 title,
179 state,
180 target_branch,
181 pull_at,
182 body,
183 rkey
184 from
185 pulls
186 where
187 repo_at = ? and state = ?
188 order by
189 created desc`, repoAt, state)
190 if err != nil {
191 return nil, err
192 }
193 defer rows.Close()
194
195 for rows.Next() {
196 var pull Pull
197 var createdAt string
198 err := rows.Scan(
199 &pull.OwnerDid,
200 &pull.PullId,
201 &createdAt,
202 &pull.Title,
203 &pull.State,
204 &pull.TargetBranch,
205 &pull.PullAt,
206 &pull.Body,
207 &pull.Rkey,
208 )
209 if err != nil {
210 return nil, err
211 }
212
213 createdTime, err := time.Parse(time.RFC3339, createdAt)
214 if err != nil {
215 return nil, err
216 }
217 pull.Created = createdTime
218
219 pulls = append(pulls, pull)
220 }
221
222 if err := rows.Err(); err != nil {
223 return nil, err
224 }
225
226 return pulls, nil
227}
228
229func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
230 query := `
231 select
232 owner_did,
233 pull_id,
234 created,
235 title,
236 state,
237 target_branch,
238 pull_at,
239 repo_at,
240 body,
241 rkey
242 from
243 pulls
244 where
245 repo_at = ? and pull_id = ?
246 `
247 row := e.QueryRow(query, repoAt, pullId)
248
249 var pull Pull
250 var createdAt string
251 err := row.Scan(
252 &pull.OwnerDid,
253 &pull.PullId,
254 &createdAt,
255 &pull.Title,
256 &pull.State,
257 &pull.TargetBranch,
258 &pull.PullAt,
259 &pull.RepoAt,
260 &pull.Body,
261 &pull.Rkey,
262 )
263 if err != nil {
264 return nil, err
265 }
266
267 createdTime, err := time.Parse(time.RFC3339, createdAt)
268 if err != nil {
269 return nil, err
270 }
271 pull.Created = createdTime
272
273 submissionsQuery := `
274 select
275 id, pull_id, repo_at, round_number, patch, created
276 from
277 pull_submissions
278 where
279 repo_at = ? and pull_id = ?
280 `
281 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
282 if err != nil {
283 return nil, err
284 }
285 defer submissionsRows.Close()
286
287 submissionsMap := make(map[int]*PullSubmission)
288
289 for submissionsRows.Next() {
290 var submission PullSubmission
291 var submissionCreatedStr string
292 err := submissionsRows.Scan(
293 &submission.ID,
294 &submission.PullId,
295 &submission.RepoAt,
296 &submission.RoundNumber,
297 &submission.Patch,
298 &submissionCreatedStr,
299 )
300 if err != nil {
301 return nil, err
302 }
303
304 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
305 if err != nil {
306 return nil, err
307 }
308 submission.Created = submissionCreatedTime
309
310 submissionsMap[submission.ID] = &submission
311 }
312 if err = submissionsRows.Close(); err != nil {
313 return nil, err
314 }
315 if len(submissionsMap) == 0 {
316 return &pull, nil
317 }
318
319 var args []any
320 for k := range submissionsMap {
321 args = append(args, k)
322 }
323 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
324 commentsQuery := fmt.Sprintf(`
325 select
326 id,
327 pull_id,
328 submission_id,
329 repo_at,
330 owner_did,
331 comment_at,
332 body,
333 created
334 from
335 pull_comments
336 where
337 submission_id IN (%s)
338 order by
339 created asc
340 `, inClause)
341 commentsRows, err := e.Query(commentsQuery, args...)
342 if err != nil {
343 return nil, err
344 }
345 defer commentsRows.Close()
346
347 for commentsRows.Next() {
348 var comment PullComment
349 var commentCreatedStr string
350 err := commentsRows.Scan(
351 &comment.ID,
352 &comment.PullId,
353 &comment.SubmissionId,
354 &comment.RepoAt,
355 &comment.OwnerDid,
356 &comment.CommentAt,
357 &comment.Body,
358 &commentCreatedStr,
359 )
360 if err != nil {
361 return nil, err
362 }
363
364 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
365 if err != nil {
366 return nil, err
367 }
368 comment.Created = commentCreatedTime
369
370 // Add the comment to its submission
371 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
372 submission.Comments = append(submission.Comments, comment)
373 }
374
375 }
376 if err = commentsRows.Err(); err != nil {
377 return nil, err
378 }
379
380 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
381 for _, submission := range submissionsMap {
382 pull.Submissions[submission.RoundNumber] = submission
383 }
384
385 return &pull, nil
386}
387
388func NewPullComment(e Execer, comment *PullComment) (int64, error) {
389 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
390 res, err := e.Exec(
391 query,
392 comment.OwnerDid,
393 comment.RepoAt,
394 comment.SubmissionId,
395 comment.CommentAt,
396 comment.PullId,
397 comment.Body,
398 )
399 if err != nil {
400 return 0, err
401 }
402
403 i, err := res.LastInsertId()
404 if err != nil {
405 return 0, err
406 }
407
408 return i, nil
409}
410
411func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
412 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
413 return err
414}
415
416func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
417 err := SetPullState(e, repoAt, pullId, PullClosed)
418 return err
419}
420
421func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
422 err := SetPullState(e, repoAt, pullId, PullOpen)
423 return err
424}
425
426func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
427 err := SetPullState(e, repoAt, pullId, PullMerged)
428 return err
429}
430
431func ResubmitPull(e Execer, pull *Pull, newPatch string) error {
432 newRoundNumber := len(pull.Submissions)
433 _, err := e.Exec(`
434 insert into pull_submissions (pull_id, repo_at, round_number, patch)
435 values (?, ?, ?, ?)
436 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch)
437
438 return err
439}
440
// PullCount holds the number of a repository's pulls in each state, as
// computed by GetPullCount.
type PullCount struct {
	Open, Merged, Closed int
}
446
447func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
448 row := e.QueryRow(`
449 select
450 count(case when state = ? then 1 end) as open_count,
451 count(case when state = ? then 1 end) as merged_count,
452 count(case when state = ? then 1 end) as closed_count
453 from pulls
454 where repo_at = ?`,
455 PullOpen,
456 PullMerged,
457 PullClosed,
458 repoAt,
459 )
460
461 var count PullCount
462 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
463 return PullCount{0, 0, 0}, err
464 }
465
466 return count, nil
467}