this repo has no description
1package tapclient
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "net/url"
11 "sync"
12 "time"
13
14 "github.com/gorilla/websocket"
15)
16
// initialBackoff is the base retry delay: backoffDuration returns it for an
// error count of 0 and doubles it for each additional count, up to a cap. It is
// a var rather than a const, presumably so tests can shorten it (TODO confirm).
var initialBackoff = 500 * time.Millisecond
20
// NonRetryableError is a thin error wrapper that tells the tap client consumer
// loop a message should not be retried (i.e. invalid user input that will
// surely fail again on retry). The consumer detects it with errors.As, so a
// handler may return it directly or wrapped in further context.
type NonRetryableError struct {
	err error
}

// NewNonRetryableError wraps err so the consumer loop stops retrying the
// message that produced it.
func NewNonRetryableError(err error) *NonRetryableError {
	return &NonRetryableError{err: err}
}

// Error returns the wrapped error's message, or "" when there is no wrapped
// error (or the receiver itself is nil).
func (e *NonRetryableError) Error() string {
	if e == nil || e.err == nil {
		return ""
	}
	return e.err.Error()
}

// Unwrap exposes the wrapped error so errors.Is and errors.As can match the
// underlying cause through the NonRetryableError.
func (e *NonRetryableError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.err
}
37
38// Websocket implements a tap consumer that reads via a websocket
39type Websocket struct {
40 log *slog.Logger
41
42 addr string
43 sendAcks bool
44 maxErrs int
45
46 connectTimeout time.Duration
47 readTimeout time.Duration
48 writeTimeout time.Duration
49
50 handler WebsocketHandlerFunc
51}
52
53// Defines an option for the tap websocket consumer
54type WebsocketOption func(*Websocket)
55
56// Defines the log/slog logger to use throughout the lifecycle of the websocket
57// consumer. Pass nil to disable logging.
58func WithLogger(logger *slog.Logger) func(*Websocket) {
59 return func(ws *Websocket) {
60 ws.log = logger
61
62 if ws.log == nil {
63 // write to io.Discard if a nil logger is passed
64 ws.log = slog.New(slog.NewTextHandler(io.Discard, nil))
65 }
66 }
67}
68
69// Sets the connect timeout for connecting to the websocket
70func WithConnectTimeout(timeout time.Duration) func(*Websocket) {
71 return func(ws *Websocket) {
72 ws.connectTimeout = timeout
73 }
74}
75
76// Sets the read timeout for reading data from the websocket
77func WithReadTimeout(timeout time.Duration) func(*Websocket) {
78 return func(ws *Websocket) {
79 ws.readTimeout = timeout
80 }
81}
82
83// Sets the write timeout for writing data to the websocket
84func WithWriteTimeout(timeout time.Duration) func(*Websocket) {
85 return func(ws *Websocket) {
86 ws.writeTimeout = timeout
87 }
88}
89
90// Controls how many times the loop will attempt to reconnect to the websocket in a row before giving up
91func WithMaxConsecutiveErrors(numErrs int) func(*Websocket) {
92 return func(ws *Websocket) {
93 ws.maxErrs = numErrs
94 }
95}
96
97// Turns on message acknowledgements
98func WithAcks() func(*Websocket) {
99 return func(ws *Websocket) {
100 ws.sendAcks = true
101 }
102}
103
104// Defines an option for the tap websocket consumer. A nil error indicates that an ACK will be sent to tap
105// if WithAcks() is provided.
106type WebsocketHandlerFunc func(context.Context, *Event) error
107
108// Initializes a tap websocket consumer
109func NewWebsocket(addr string, handler WebsocketHandlerFunc, opts ...WebsocketOption) (*Websocket, error) {
110 u, err := url.Parse(addr)
111 if err != nil {
112 return nil, fmt.Errorf("failed to parse websocket url %q: %w", addr, err)
113 }
114
115 switch u.Scheme {
116 case "ws", "wss": // ok
117 default:
118 return nil, fmt.Errorf("invalid websocket protocol scheme: wanted ws:// or wss://, got %q", u.Scheme)
119 }
120
121 if handler == nil {
122 return nil, fmt.Errorf("a websocket message handler func is required")
123 }
124
125 ws := &Websocket{
126 log: slog.Default().WithGroup("tap"),
127
128 addr: addr,
129 sendAcks: false,
130 maxErrs: 10,
131
132 connectTimeout: 30 * time.Second,
133 readTimeout: 30 * time.Second,
134 writeTimeout: 30 * time.Second,
135
136 handler: handler,
137 }
138
139 for _, opt := range opts {
140 opt(ws)
141 }
142
143 return ws, nil
144}
145
146// Connects to and beings the main tap websocket consumer loop
147func (ws *Websocket) Run(ctx context.Context) error {
148 for errCount := 0; ; {
149 select {
150 case <-ctx.Done():
151 ws.log.Debug("websocket ingester shutting down")
152 return nil
153 default:
154 }
155
156 err := ws.runOnce(ctx)
157 if errors.Is(err, context.Canceled) {
158 ws.log.Debug("websocket ingester shutting down")
159 return nil
160 }
161
162 if err == nil {
163 errCount = 0
164 ws.log.Debug("websocket connection closed normally, reconnecting")
165 continue
166 }
167
168 errCount++
169 ws.log.Error("websocket connection failed", "err", err, "consecutive_errors", errCount)
170
171 if errCount >= ws.maxErrs {
172 return fmt.Errorf("websocket connection failed %d consecutive times: %w", errCount, err)
173 }
174
175 ws.log.Warn("retrying websocket connection", "consecutive_errors", errCount)
176 if sleepMaybeExit(ctx, errCount) {
177 return nil
178 }
179 }
180}
181
182func (ws *Websocket) close(conn *websocket.Conn) {
183 if err := conn.Close(); err != nil {
184 ws.log.Error("failed to close websocket connection", "err", err)
185 }
186}
187
188func (ws *Websocket) runOnce(ctx context.Context) error {
189 dialer := websocket.Dialer{HandshakeTimeout: ws.connectTimeout}
190 conn, _, err := dialer.DialContext(ctx, ws.addr, nil)
191 if err != nil {
192 return fmt.Errorf("failed to connect to websocket at %q: %w", ws.addr, err)
193 }
194
195 var closeOnce sync.Once
196 closeConn := func() {
197 closeOnce.Do(func() {
198 if err := conn.Close(); err != nil {
199 ws.log.Error("failed to close websocket connection", "err", err)
200 }
201 })
202 }
203 defer closeConn()
204
205 ws.log.Debug("connected to websocket", "addr", ws.addr)
206
207 go func() {
208 <-ctx.Done()
209 closeConn()
210 }()
211
212 for {
213 if done(ctx) {
214 return nil
215 }
216
217 if err := conn.SetReadDeadline(time.Now().Add(ws.readTimeout)); err != nil {
218 return fmt.Errorf("failed to set websocket read deadline: %w", err)
219 }
220
221 _, buf, err := conn.ReadMessage()
222 if err != nil {
223 if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
224 return nil // normal remote closure
225 }
226
227 if ctx.Err() != nil {
228 return ctx.Err()
229 }
230
231 return fmt.Errorf("failed to read websocket message: %w", err)
232 }
233
234 var ev Event
235 if err := json.Unmarshal(buf, &ev); err != nil {
236 ws.log.Warn("failed to unmarshal event json", "err", err)
237 continue
238 }
239
240 // indefinitely retry messages that failed to process unless a non-retryable error occurrs
241 for errCount := 0; ; errCount++ {
242 if done(ctx) {
243 break
244 }
245
246 err := ws.handler(ctx, &ev)
247 if err == nil {
248 break
249 }
250
251 ws.log.Error("failed to process event", "err", err)
252 if sleepMaybeExit(ctx, errCount) {
253 return nil
254 }
255
256 var nr *NonRetryableError
257 if errors.As(err, &nr) {
258 ws.log.Error("handled non-retryable error", "id", ev.ID, "err", err)
259 break
260 }
261 }
262
263 if ws.sendAcks {
264 ws.ack(ctx, conn, &ev)
265 }
266 }
267}
268
269// Indefinitely tries acking the message with the tap server
270func (ws *Websocket) ack(ctx context.Context, conn *websocket.Conn, ev *Event) {
271 for errCount := 0; ; errCount++ {
272 if done(ctx) {
273 return
274 }
275
276 if err := conn.SetWriteDeadline(time.Now().Add(ws.writeTimeout)); err != nil {
277 ws.log.Warn("failed to set write deadline on ack", "err", err)
278 }
279
280 err := conn.WriteJSON(NewACKPayload(ev.ID))
281 if err == nil {
282 return
283 }
284
285 if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
286 return // normal remote closure
287 }
288
289 ws.log.Error("failed to send ack", "err", err)
290
291 if sleepMaybeExit(ctx, errCount) {
292 return
293 }
294 }
295}
296
// done reports, without blocking, whether ctx has already been cancelled or
// has passed its deadline.
func done(ctx context.Context) bool {
	// Err is non-nil exactly when Done is closed.
	return ctx.Err() != nil
}
305
306func sleepMaybeExit(ctx context.Context, errCount int) bool {
307 select {
308 case <-ctx.Done():
309 return true // shutdown received during a backoff sleep means that we're done
310 case <-time.After(backoffDuration(errCount)):
311 return false
312 }
313}
314
315func backoffDuration(errCount int) time.Duration {
316 multiplier := 1 << errCount
317 waitFor := initialBackoff * time.Duration(multiplier)
318
319 return min(waitFor, 10*time.Second)
320}