this repo has no description
at main 320 lines 7.4 kB view raw
1package tapclient 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "log/slog" 10 "net/url" 11 "sync" 12 "time" 13 14 "github.com/gorilla/websocket" 15) 16 17var ( 18 initialBackoff = 500 * time.Millisecond 19) 20 21// A thin error wrapper that indicates to the tap client consumer loop that a message 22// should not be retried (i.e. invalid user input that will surely fail again on retry). 23type NonRetryableError struct { 24 err error 25} 26 27func NewNonRetryableError(err error) *NonRetryableError { 28 return &NonRetryableError{err: err} 29} 30 31func (err *NonRetryableError) Error() string { 32 if err.err != nil { 33 return err.err.Error() 34 } 35 return "" 36} 37 38// Websocket implements a tap consumer that reads via a websocket 39type Websocket struct { 40 log *slog.Logger 41 42 addr string 43 sendAcks bool 44 maxErrs int 45 46 connectTimeout time.Duration 47 readTimeout time.Duration 48 writeTimeout time.Duration 49 50 handler WebsocketHandlerFunc 51} 52 53// Defines an option for the tap websocket consumer 54type WebsocketOption func(*Websocket) 55 56// Defines the log/slog logger to use throughout the lifecycle of the websocket 57// consumer. Pass nil to disable logging. 58func WithLogger(logger *slog.Logger) func(*Websocket) { 59 return func(ws *Websocket) { 60 ws.log = logger 61 62 if ws.log == nil { 63 // write to io.Discard if a nil logger is passed 64 ws.log = slog.New(slog.NewTextHandler(io.Discard, nil)) 65 } 66 } 67} 68 69// Sets the connect timeout for connecting to the websocket 70func WithConnectTimeout(timeout time.Duration) func(*Websocket) { 71 return func(ws *Websocket) { 72 ws.connectTimeout = timeout 73 } 74} 75 76// Sets the read timeout for reading data from the websocket 77func WithReadTimeout(timeout time.Duration) func(*Websocket) { 78 return func(ws *Websocket) { 79 ws.readTimeout = timeout 80 } 81} 82 83// Sets the write timeout for writing data to the websocket 84func WithWriteTimeout(timeout time.Duration) func(*Websocket) { 85 return func(ws *Websocket) { 86 ws.writeTimeout = timeout 87 } 88} 89 90// Controls how many times the loop will attempt to reconnect to the websocket in a row before giving up 91func WithMaxConsecutiveErrors(numErrs int) func(*Websocket) { 92 return func(ws *Websocket) { 93 ws.maxErrs = numErrs 94 } 95} 96 97// Turns on message acknowledgements 98func WithAcks() func(*Websocket) { 99 return func(ws *Websocket) { 100 ws.sendAcks = true 101 } 102} 103 104// Defines an option for the tap websocket consumer. A nil error indicates that an ACK will be sent to tap 105// if WithAcks() is provided. 106type WebsocketHandlerFunc func(context.Context, *Event) error 107 108// Initializes a tap websocket consumer 109func NewWebsocket(addr string, handler WebsocketHandlerFunc, opts ...WebsocketOption) (*Websocket, error) { 110 u, err := url.Parse(addr) 111 if err != nil { 112 return nil, fmt.Errorf("failed to parse websocket url %q: %w", addr, err) 113 } 114 115 switch u.Scheme { 116 case "ws", "wss": // ok 117 default: 118 return nil, fmt.Errorf("invalid websocket protocol scheme: wanted ws:// or wss://, got %q", u.Scheme) 119 } 120 121 if handler == nil { 122 return nil, fmt.Errorf("a websocket message handler func is required") 123 } 124 125 ws := &Websocket{ 126 log: slog.Default().WithGroup("tap"), 127 128 addr: addr, 129 sendAcks: false, 130 maxErrs: 10, 131 132 connectTimeout: 30 * time.Second, 133 readTimeout: 30 * time.Second, 134 writeTimeout: 30 * time.Second, 135 136 handler: handler, 137 } 138 139 for _, opt := range opts { 140 opt(ws) 141 } 142 143 return ws, nil 144} 145 146// Connects to and beings the main tap websocket consumer loop 147func (ws *Websocket) Run(ctx context.Context) error { 148 for errCount := 0; ; { 149 select { 150 case <-ctx.Done(): 151 ws.log.Debug("websocket ingester shutting down") 152 return nil 153 default: 154 } 155 156 err := ws.runOnce(ctx) 157 if errors.Is(err, context.Canceled) { 158 ws.log.Debug("websocket ingester shutting down") 159 return nil 160 } 161 162 if err == nil { 163 errCount = 0 164 ws.log.Debug("websocket connection closed normally, reconnecting") 165 continue 166 } 167 168 errCount++ 169 ws.log.Error("websocket connection failed", "err", err, "consecutive_errors", errCount) 170 171 if errCount >= ws.maxErrs { 172 return fmt.Errorf("websocket connection failed %d consecutive times: %w", errCount, err) 173 } 174 175 ws.log.Warn("retrying websocket connection", "consecutive_errors", errCount) 176 if sleepMaybeExit(ctx, errCount) { 177 return nil 178 } 179 } 180} 181 182func (ws *Websocket) close(conn *websocket.Conn) { 183 if err := conn.Close(); err != nil { 184 ws.log.Error("failed to close websocket connection", "err", err) 185 } 186} 187 188func (ws *Websocket) runOnce(ctx context.Context) error { 189 dialer := websocket.Dialer{HandshakeTimeout: ws.connectTimeout} 190 conn, _, err := dialer.DialContext(ctx, ws.addr, nil) 191 if err != nil { 192 return fmt.Errorf("failed to connect to websocket at %q: %w", ws.addr, err) 193 } 194 195 var closeOnce sync.Once 196 closeConn := func() { 197 closeOnce.Do(func() { 198 if err := conn.Close(); err != nil { 199 ws.log.Error("failed to close websocket connection", "err", err) 200 } 201 }) 202 } 203 defer closeConn() 204 205 ws.log.Debug("connected to websocket", "addr", ws.addr) 206 207 go func() { 208 <-ctx.Done() 209 closeConn() 210 }() 211 212 for { 213 if done(ctx) { 214 return nil 215 } 216 217 if err := conn.SetReadDeadline(time.Now().Add(ws.readTimeout)); err != nil { 218 return fmt.Errorf("failed to set websocket read deadline: %w", err) 219 } 220 221 _, buf, err := conn.ReadMessage() 222 if err != nil { 223 if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { 224 return nil // normal remote closure 225 } 226 227 if ctx.Err() != nil { 228 return ctx.Err() 229 } 230 231 return fmt.Errorf("failed to read websocket message: %w", err) 232 } 233 234 var ev Event 235 if err := json.Unmarshal(buf, &ev); err != nil { 236 ws.log.Warn("failed to unmarshal event json", "err", err) 237 continue 238 } 239 240 // indefinitely retry messages that failed to process unless a non-retryable error occurrs 241 for errCount := 0; ; errCount++ { 242 if done(ctx) { 243 break 244 } 245 246 err := ws.handler(ctx, &ev) 247 if err == nil { 248 break 249 } 250 251 ws.log.Error("failed to process event", "err", err) 252 if sleepMaybeExit(ctx, errCount) { 253 return nil 254 } 255 256 var nr *NonRetryableError 257 if errors.As(err, &nr) { 258 ws.log.Error("handled non-retryable error", "id", ev.ID, "err", err) 259 break 260 } 261 } 262 263 if ws.sendAcks { 264 ws.ack(ctx, conn, &ev) 265 } 266 } 267} 268 269// Indefinitely tries acking the message with the tap server 270func (ws *Websocket) ack(ctx context.Context, conn *websocket.Conn, ev *Event) { 271 for errCount := 0; ; errCount++ { 272 if done(ctx) { 273 return 274 } 275 276 if err := conn.SetWriteDeadline(time.Now().Add(ws.writeTimeout)); err != nil { 277 ws.log.Warn("failed to set write deadline on ack", "err", err) 278 } 279 280 err := conn.WriteJSON(NewACKPayload(ev.ID)) 281 if err == nil { 282 return 283 } 284 285 if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { 286 return // normal remote closure 287 } 288 289 ws.log.Error("failed to send ack", "err", err) 290 291 if sleepMaybeExit(ctx, errCount) { 292 return 293 } 294 } 295} 296 297func done(ctx context.Context) bool { 298 select { 299 case <-ctx.Done(): 300 return true 301 default: 302 return false 303 } 304} 305 306func sleepMaybeExit(ctx context.Context, errCount int) bool { 307 select { 308 case <-ctx.Done(): 309 return true // shutdown received during a backoff sleep means that we're done 310 case <-time.After(backoffDuration(errCount)): 311 return false 312 } 313} 314 315func backoffDuration(errCount int) time.Duration { 316 multiplier := 1 << errCount 317 waitFor := initialBackoff * time.Duration(multiplier) 318 319 return min(waitFor, 10*time.Second) 320}