this repo has no description
1package knotclient
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "math/rand"
8 "net/url"
9 "sync"
10 "time"
11
12 "tangled.sh/tangled.sh/core/log"
13
14 "github.com/gorilla/websocket"
15)
16
17type ProcessFunc func(source EventSource, message Message) error
18
19type Message struct {
20 Rkey string
21 Nsid string
22 // do not full deserialize this portion of the message, processFunc can do that
23 EventJson json.RawMessage
24}
25
26type ConsumerConfig struct {
27 Sources map[EventSource]struct{}
28 ProcessFunc ProcessFunc
29 RetryInterval time.Duration
30 MaxRetryInterval time.Duration
31 ConnectionTimeout time.Duration
32 WorkerCount int
33 QueueSize int
34 Logger *slog.Logger
35 Dev bool
36}
37
// EventSource identifies one knot whose event stream is consumed. It is a
// comparable value and is used as a map key by the consumer.
type EventSource struct {
	// Knot is inserted verbatim between the URL scheme and "/events" when the
	// stream URL is built.
	Knot string
}

// NewEventSource returns the EventSource for the given knot.
func NewEventSource(knot string) EventSource {
	var src EventSource
	src.Knot = knot
	return src
}
47
48type EventConsumer struct {
49 cfg ConsumerConfig
50 wg sync.WaitGroup
51 dialer *websocket.Dialer
52 connMap sync.Map
53 jobQueue chan job
54 logger *slog.Logger
55 randSource *rand.Rand
56
57 // rw lock over edits to consumer config
58 mu sync.RWMutex
59}
60
61func (e *EventConsumer) buildUrl(s EventSource, cursor string) (*url.URL, error) {
62 scheme := "wss"
63 if e.cfg.Dev {
64 scheme = "ws"
65 }
66
67 u, err := url.Parse(scheme + "://" + s.Knot + "/events")
68 if err != nil {
69 return nil, err
70 }
71
72 if cursor != "" {
73 query := url.Values{}
74 query.Add("cursor", cursor)
75 u.RawQuery = query.Encode()
76 }
77 return u, nil
78}
79
80type job struct {
81 source EventSource
82 message []byte
83}
84
// NewEventConsumer returns an EventConsumer for cfg, with defaults applied to
// unset fields:
//
//   - RetryInterval: 15 minutes
//   - MaxRetryInterval: 1 hour
//   - ConnectionTimeout: 10 seconds
//   - WorkerCount: 5
//   - QueueSize: 100
//   - Logger: log.New("eventconsumer")
//
// Zero and negative values both count as unset. A negative QueueSize would
// otherwise panic in make, and negative durations would break dialing and the
// reconnect back-off.
//
// cfg.Sources is copied, and a nil map becomes an empty one. As a result,
// AddSource never writes to a nil map or mutates the caller's map.
//
// No goroutines are started until Start is called.
func NewEventConsumer(cfg ConsumerConfig) *EventConsumer {
	if cfg.RetryInterval <= 0 {
		cfg.RetryInterval = 15 * time.Minute
	}
	if cfg.ConnectionTimeout <= 0 {
		cfg.ConnectionTimeout = 10 * time.Second
	}
	if cfg.WorkerCount <= 0 {
		cfg.WorkerCount = 5
	}
	if cfg.MaxRetryInterval <= 0 {
		cfg.MaxRetryInterval = 1 * time.Hour
	}
	if cfg.Logger == nil {
		cfg.Logger = log.New("eventconsumer")
	}
	if cfg.QueueSize <= 0 {
		cfg.QueueSize = 100
	}

	// Own a private copy of the source set; AddSource writes into it.
	sources := make(map[EventSource]struct{}, len(cfg.Sources))
	for s := range cfg.Sources {
		sources[s] = struct{}{}
	}
	cfg.Sources = sources

	return &EventConsumer{
		cfg:        cfg,
		dialer:     websocket.DefaultDialer,
		jobQueue:   make(chan job, cfg.QueueSize),
		logger:     cfg.Logger,
		randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}
112
// Start launches cfg.WorkerCount worker goroutines and one connection loop per
// source in cfg.Sources, then returns immediately.
//
// Workers run until ctx is cancelled or the job queue is closed. Connection
// loops run until ctx is cancelled.
func (c *EventConsumer) Start(ctx context.Context) {
	// Hold the read lock so iterating Sources cannot race with AddSource,
	// which writes to the map under the write lock.
	c.mu.RLock()
	defer c.mu.RUnlock()

	for range c.cfg.WorkerCount {
		c.wg.Add(1)
		go c.worker(ctx)
	}

	for source := range c.cfg.Sources {
		c.wg.Add(1)
		go c.startConnectionLoop(ctx, source)
	}
}
126
127func (c *EventConsumer) Stop() {
128 c.connMap.Range(func(_, val any) bool {
129 if conn, ok := val.(*websocket.Conn); ok {
130 conn.Close()
131 }
132 return true
133 })
134 c.wg.Wait()
135 close(c.jobQueue)
136}
137
// AddSource registers s and starts a connection loop for it bound to ctx.
// Adding a source that is already registered is a no-op. Otherwise two loops
// would stream the same knot, every event would be processed twice, and they
// would overwrite each other's connMap entry.
//
// NOTE(review): Start also launches a loop for every registered source, so a
// new source added before Start ends up with two loops. Call AddSource for new
// sources only after Start.
func (c *EventConsumer) AddSource(ctx context.Context, s EventSource) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if _, exists := c.cfg.Sources[s]; exists {
		return
	}
	if c.cfg.Sources == nil {
		c.cfg.Sources = make(map[EventSource]struct{})
	}
	c.cfg.Sources[s] = struct{}{}

	c.wg.Add(1)
	go c.startConnectionLoop(ctx, s)
}
145
// worker takes frames off the job queue, decodes each into a Message, and
// passes it to cfg.ProcessFunc. It handles one frame at a time until ctx is
// cancelled or the queue is closed.
//
// Frames that fail to decode, and messages that ProcessFunc rejects, are
// logged and skipped. A single bad frame therefore cannot retire the worker.
func (c *EventConsumer) worker(ctx context.Context) {
	defer c.wg.Done()
	for {
		select {
		case <-ctx.Done():
			return
		case j, ok := <-c.jobQueue:
			if !ok {
				return
			}

			var msg Message
			if err := json.Unmarshal(j.message, &msg); err != nil {
				// Skip the frame. Returning here would permanently shrink the
				// worker pool, and once every worker has hit a bad frame the
				// queue fills and the connections block.
				c.logger.Error("error deserializing message", "source", j.source, "err", err)
				continue
			}
			if err := c.cfg.ProcessFunc(j.source, msg); err != nil {
				c.logger.Error("error processing message", "source", j.source, "err", err)
			}
		}
	}
}
169
// startConnectionLoop keeps reconnecting to source until ctx is cancelled.
//
// After each dropped or failed connection, it waits the current retry
// interval plus up to 20% random jitter. It then doubles the interval, capped
// at cfg.MaxRetryInterval.
//
// NOTE(review): the interval is never reset after a successful connection. A
// long-lived connection that drops is retried with whatever back-off was
// reached earlier.
func (c *EventConsumer) startConnectionLoop(ctx context.Context, source EventSource) {
	defer c.wg.Done()

	retryInterval := c.cfg.RetryInterval
	for {
		if ctx.Err() != nil {
			return
		}

		if err := c.runConnection(ctx, source); err != nil {
			c.logger.Error("connection failed", "source", source, "err", err)
		}
		// runConnection also returns on cancellation; don't log a retry that
		// will never happen.
		if ctx.Err() != nil {
			return
		}

		delay := retryInterval
		// Jitter keeps many loops from reconnecting in lockstep. The top-level
		// math/rand functions are safe for concurrent use; the shared
		// c.randSource is a *rand.Rand, which every loop would race on. Int63n
		// panics on a non-positive argument, hence the guard.
		if span := int64(retryInterval) / 5; span > 0 {
			delay += time.Duration(rand.Int63n(span))
		}

		if retryInterval < c.cfg.MaxRetryInterval {
			retryInterval = min(retryInterval*2, c.cfg.MaxRetryInterval)
		}

		c.logger.Info("retrying connection", "source", source, "delay", delay)
		select {
		case <-time.After(delay):
		case <-ctx.Done():
			return
		}
	}
}
202
// runConnection dials source's event stream and forwards every text frame to
// the job queue; non-text frames are ignored.
//
// cfg.ConnectionTimeout bounds only the dial/handshake. While connected, the
// live connection is published in connMap so that Stop can close it.
//
// Once connected, it returns nil if ctx is cancelled. Otherwise it returns the
// error that ended the dial or the read loop.
func (c *EventConsumer) runConnection(ctx context.Context, source EventSource) error {
	// TODO: no cursor is persisted yet, so every (re)connect requests the
	// stream without a cursor parameter.
	u, err := c.buildUrl(source, "")
	if err != nil {
		return err
	}

	dialCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
	defer cancel()

	c.logger.Info("connecting", "url", u.String())
	conn, _, err := c.dialer.DialContext(dialCtx, u.String(), nil)
	if err != nil {
		return err
	}
	defer conn.Close()

	c.connMap.Store(source, conn)
	defer c.connMap.Delete(source)

	// ReadMessage does not observe ctx, so close the connection on
	// cancellation to unblock a pending read promptly.
	stop := context.AfterFunc(ctx, func() { conn.Close() })
	defer stop()

	c.logger.Info("connected", "source", source)

	for {
		if ctx.Err() != nil {
			return nil
		}

		msgType, msg, err := conn.ReadMessage()
		if err != nil {
			if ctx.Err() != nil {
				// The read failed because cancellation closed the connection.
				return nil
			}
			return err
		}
		if msgType != websocket.TextMessage {
			continue
		}

		select {
		case c.jobQueue <- job{source: source, message: msg}:
		case <-ctx.Done():
			return nil
		}
	}
}
244}