norm: an ORM-free SQL experience.
1package norm
2
3import (
4 "database/sql"
5 "iter"
6 "reflect"
7)
8
// Scanner yields the rows of a *sql.Rows result set as values of type T.
// Construct one with NewScanner; the zero value holds no rows.
type Scanner[T any] struct {
	// rows is the result set being consumed.
	rows *sql.Rows

	// scanFunc returns the Scan destinations for a single row. When nil,
	// Scan falls back to reflection over T's fields.
	scanFunc func(*T) []any

	// onError is not read by any code in this file.
	// NOTE(review): confirm whether another file of the package uses it.
	onError func(error)
}
14
15func NewScanner[T any](rows *sql.Rows) Scanner[T] {
16 return Scanner[T]{
17 rows: rows,
18 }
19}
20
21func (s Scanner[T]) ScanWith(fn func(*T) []any) iter.Seq2[T, error] {
22 s.scanFunc = fn
23 return s.Scan()
24}
25
26func (s Scanner[T]) Scan() iter.Seq2[T, error] {
27 // If no custom scan function provided, use reflection-based default
28 if s.scanFunc == nil {
29 s.scanFunc = func(dest *T) []any {
30 elem := reflect.ValueOf(dest).Elem()
31 numCols := elem.NumField()
32 columns := make([]any, numCols)
33
34 for i := range numCols {
35 field := elem.Field(i)
36 columns[i] = field.Addr().Interface()
37 }
38
39 return columns
40 }
41 }
42
43 return func(yield func(T, error) bool) {
44 for s.rows.Next() {
45 var data T
46 columns := s.scanFunc(&data)
47 err := s.rows.Scan(columns...)
48
49 if !yield(data, err) {
50 return
51 }
52 }
53 }
54}
55
56func (s *Scanner[T]) Close() error {
57 return s.rows.Close()
58}
59
60func ScanAll[T any](rows *sql.Rows, dest *[]T) error {
61 scanner := NewScanner[T](rows)
62 defer scanner.Close()
63
64 for elem, err := range scanner.Scan() {
65 if err != nil {
66 return err
67 }
68 *dest = append(*dest, elem)
69 }
70
71 return nil
72}
73
// Scan reads the single row held by row into dest.
//
// How columns map onto dest:
//   - When T is a struct, columns go to T's fields in declaration order. The
//     query must select one column per field, and every field must be
//     exported, because reflect panics when asked for a pointer to an
//     unexported field.
//   - Any other T (int64, string, []byte, a pointer, ...) receives the row's
//     single column directly.
//
// Errors from row, including sql.ErrNoRows, are returned unchanged. A nil
// dest is passed through to row.Scan, which reports it as an error instead
// of this function panicking.
func Scan[T any](row *sql.Row, dest *T) error {
	elem := reflect.ValueOf(dest).Elem()
	if elem.Kind() != reflect.Struct {
		// Non-struct targets, and a nil dest (whose Elem is invalid), are a
		// single destination. NumField would panic on them.
		return row.Scan(dest)
	}

	columns := make([]any, elem.NumField())
	for i := range columns {
		columns[i] = elem.Field(i).Addr().Interface()
	}
	return row.Scan(columns...)
}