this repo has no description

vendor in WIP tapclient from indigo

+857
+37
tapclient/doc.go
··· 1 + // Package tap provides a client for consuming atproto events from a tap websocket. 2 + // 3 + // (this is jcalabro code from https://github.com/bluesky-social/indigo/pull/1241) 4 + // 5 + // The client handles connection management, automatic reconnection with backoff, 6 + // and optional message acknowledgements. 7 + // 8 + // Basic usage: 9 + // 10 + // handler := func(ctx context.Context, ev *tap.Event) error { 11 + // switch payload := ev.Payload().(type) { 12 + // case *tap.RecordEvent: 13 + // fmt.Printf("record.Action: %s\n", payload.Action) 14 + // fmt.Printf("record.Collection: %s\n", payload.Collection) 15 + // case *tap.IdentityEvent: 16 + // fmt.Printf("identity.DID: %s\n", payload.DID) 17 + // fmt.Printf("identity.Handle: %s\n", payload.Handle) 18 + // } 19 + // return nil 20 + // } 21 + // 22 + // ws, err := tap.NewWebsocket("wss://example.com/tap", handler, 23 + // tap.WithLogger(slog.Default()), 24 + // tap.WithAcks(), 25 + // ) 26 + // if err != nil { 27 + // // handle error... 28 + // } 29 + // 30 + // if err := ws.Run(ctx); err != nil { 31 + // // handle error... 32 + // } 33 + // 34 + // Returning an error from the handler will cause the message to be retried with 35 + // exponential backoff. To skip retries for permanent failures, wrap the error 36 + // with [NewNonRetryableError]. 37 + package tapclient
+114
tapclient/event.go
··· 1 + package tapclient 2 + 3 + import ( 4 + "encoding/json" 5 + "fmt" 6 + ) 7 + 8 + const ( 9 + eventTypeACK = "ack" 10 + eventTypeRecord = "record" 11 + eventTypeIdentity = "identity" 12 + ) 13 + 14 + // Event represents an atproto event from tap. Use a type switch on the Payload() method to access event data. 15 + type Event struct { 16 + ID uint64 17 + Type string 18 + 19 + record *RecordEvent 20 + identity *IdentityEvent 21 + } 22 + 23 + // RecordEvent represents a record creation, update, or deletion in a repository 24 + type RecordEvent struct { 25 + DID string `json:"did"` 26 + Collection string `json:"collection"` 27 + Rkey string `json:"rkey"` 28 + Action string `json:"action"` 29 + CID string `json:"cid"` 30 + Record json.RawMessage `json:"record"` 31 + Live bool `json:"live"` 32 + } 33 + 34 + // IdentityEvent represents an account status change 35 + type IdentityEvent struct { 36 + DID string `json:"did"` 37 + Handle string `json:"handle"` 38 + IsActive bool `json:"isActive"` 39 + Status string `json:"status"` 40 + } 41 + 42 + func (e *Event) UnmarshalJSON(data []byte) error { 43 + event := struct { 44 + ID uint64 `json:"id"` 45 + Type string `json:"type"` 46 + Record json.RawMessage `json:"record,omitempty"` 47 + Identity json.RawMessage `json:"identity,omitempty"` 48 + }{} 49 + 50 + if err := json.Unmarshal(data, &event); err != nil { 51 + return fmt.Errorf("failed to unmarshal tap event: %w", err) 52 + } 53 + 54 + e.ID = event.ID 55 + e.Type = event.Type 56 + 57 + switch event.Type { 58 + case eventTypeRecord: 59 + e.record = &RecordEvent{} 60 + if err := json.Unmarshal(event.Record, e.record); err != nil { 61 + return fmt.Errorf("failed to unmarshal tap record event: %w", err) 62 + } 63 + case eventTypeIdentity: 64 + e.identity = &IdentityEvent{} 65 + if err := json.Unmarshal(event.Identity, e.identity); err != nil { 66 + return fmt.Errorf("failed to unmarshal tap identity event: %w", err) 67 + } 68 + default: 69 + return fmt.Errorf("unknown event type %q", event.Type) 70 + } 71 + 72 + return nil 73 + } 74 + 75 + func (e Event) MarshalJSON() ([]byte, error) { 76 + event := struct { 77 + ID uint64 `json:"id"` 78 + Type string `json:"type"` 79 + Record *RecordEvent `json:"record,omitempty"` 80 + Identity *IdentityEvent `json:"identity,omitempty"` 81 + }{ 82 + ID: e.ID, 83 + Type: e.Type, 84 + Record: e.record, 85 + Identity: e.identity, 86 + } 87 + 88 + buf, err := json.Marshal(event) 89 + if err != nil { 90 + return nil, fmt.Errorf("failed to marshal tap event: %w", err) 91 + } 92 + 93 + return buf, nil 94 + } 95 + 96 + // Payload returns the typed event data as either *RecordEvent or *IdentityEvent. 97 + func (e *Event) Payload() any { 98 + switch e.Type { 99 + case eventTypeRecord: 100 + return e.record 101 + case eventTypeIdentity: 102 + return e.identity 103 + } 104 + 105 + return nil // unreachable 106 + } 107 + 108 + // Constructs a new ACK object to be serialized and sent back to tap 109 + func NewACKPayload(id uint64) *Event { 110 + return &Event{ 111 + Type: eventTypeACK, 112 + ID: id, 113 + } 114 + }
+108
tapclient/event_test.go
··· 1 + package tapclient 2 + 3 + import ( 4 + "encoding/json" 5 + "testing" 6 + 7 + "github.com/stretchr/testify/require" 8 + ) 9 + 10 + func TestEventJSON(t *testing.T) { 11 + t.Parallel() 12 + 13 + t.Run("marshal/unmarshal record", func(t *testing.T) { 14 + t.Parallel() 15 + require := require.New(t) 16 + 17 + original := Event{ 18 + ID: 123, 19 + Type: eventTypeRecord, 20 + record: &RecordEvent{ 21 + DID: "did:plc:test", 22 + Collection: "app.bsky.feed.post", 23 + Rkey: "abc123", 24 + Action: "create", 25 + CID: "bafytest", 26 + Record: json.RawMessage(`{"text":"hello"}`), 27 + Live: true, 28 + }, 29 + } 30 + 31 + buf, err := json.Marshal(original) 32 + require.NoError(err) 33 + 34 + var decoded Event 35 + require.NoError(json.Unmarshal(buf, &decoded)) 36 + require.Equal(original.ID, decoded.ID) 37 + require.Equal(original.Type, decoded.Type) 38 + 39 + payload, ok := decoded.Payload().(*RecordEvent) 40 + require.True(ok) 41 + require.Equal(original.record.DID, payload.DID) 42 + require.Equal(original.record.Collection, payload.Collection) 43 + require.Equal(original.record.Rkey, payload.Rkey) 44 + require.Equal(original.record.Action, payload.Action) 45 + require.Equal(original.record.CID, payload.CID) 46 + require.Equal(original.record.Live, payload.Live) 47 + require.JSONEq(string(original.record.Record), string(payload.Record)) 48 + }) 49 + 50 + t.Run("marshal/unmarshal identity", func(t *testing.T) { 51 + t.Parallel() 52 + require := require.New(t) 53 + 54 + original := Event{ 55 + ID: 456, 56 + Type: eventTypeIdentity, 57 + identity: &IdentityEvent{ 58 + DID: "did:plc:user", 59 + Handle: "test.bsky.social", 60 + IsActive: true, 61 + Status: "active", 62 + }, 63 + } 64 + 65 + buf, err := json.Marshal(original) 66 + require.NoError(err) 67 + 68 + var decoded Event 69 + require.NoError(json.Unmarshal(buf, &decoded)) 70 + require.Equal(original.ID, decoded.ID) 71 + require.Equal(original.Type, decoded.Type) 72 + 73 + payload, ok := decoded.Payload().(*IdentityEvent) 74 + require.True(ok) 75 + require.Equal(original.identity.DID, payload.DID) 76 + require.Equal(original.identity.Handle, payload.Handle) 77 + require.Equal(original.identity.IsActive, payload.IsActive) 78 + require.Equal(original.identity.Status, payload.Status) 79 + }) 80 + 81 + t.Run("unmarshal from raw json", func(t *testing.T) { 82 + t.Parallel() 83 + require := require.New(t) 84 + 85 + recordJSON := `{"id":1,"type":"record","record":{"did":"did:plc:abc","collection":"app.bsky.feed.like","rkey":"xyz","action":"create","cid":"mycid","record":{"subject":"at://did:plc:foo/app.bsky.feed.post/xyz"},"live":false}}` 86 + var recordEvent Event 87 + require.NoError(json.Unmarshal([]byte(recordJSON), &recordEvent)) 88 + require.Equal(uint64(1), recordEvent.ID) 89 + require.Equal(eventTypeRecord, recordEvent.Type) 90 + 91 + identityJSON := `{"id":2,"type":"identity","identity":{"did":"did:plc:def","handle":"foo.test","isActive":true,"status":"active"}}` 92 + var identEvent Event 93 + require.NoError(json.Unmarshal([]byte(identityJSON), &identEvent)) 94 + require.Equal(uint64(2), identEvent.ID) 95 + require.Equal(eventTypeIdentity, identEvent.Type) 96 + }) 97 + 98 + t.Run("unmarshal unknown type", func(t *testing.T) { 99 + t.Parallel() 100 + require := require.New(t) 101 + 102 + badJSON := `{"id":1,"type":"unknown"}` 103 + var ev Event 104 + err := json.Unmarshal([]byte(badJSON), &ev) 105 + require.Error(err) 106 + require.Contains(err.Error(), "unknown event type") 107 + }) 108 + }
+320
tapclient/websocket.go
··· 1 + package tapclient 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "io" 9 + "log/slog" 10 + "net/url" 11 + "sync" 12 + "time" 13 + 14 + "github.com/gorilla/websocket" 15 + ) 16 + 17 + var ( 18 + initialBackoff = 500 * time.Millisecond 19 + ) 20 + 21 + // A thin error wrapper that indicates to the tap client consumer loop that a message 22 + // should not be retried (i.e. invalid user input that will surely fail again on retry). 23 + type NonRetryableError struct { 24 + err error 25 + } 26 + 27 + func NewNonRetryableError(err error) *NonRetryableError { 28 + return &NonRetryableError{err: err} 29 + } 30 + 31 + func (err *NonRetryableError) Error() string { 32 + if err.err != nil { 33 + return err.err.Error() 34 + } 35 + return "" 36 + } 37 + 38 + // Websocket implements a tap consumer that reads via a websocket 39 + type Websocket struct { 40 + log *slog.Logger 41 + 42 + addr string 43 + sendAcks bool 44 + maxErrs int 45 + 46 + connectTimeout time.Duration 47 + readTimeout time.Duration 48 + writeTimeout time.Duration 49 + 50 + handler WebsocketHandlerFunc 51 + } 52 + 53 + // Defines an option for the tap websocket consumer 54 + type WebsocketOption func(*Websocket) 55 + 56 + // Defines the log/slog logger to use throughout the lifecycle of the websocket 57 + // consumer. Pass nil to disable logging. 58 + func WithLogger(logger *slog.Logger) func(*Websocket) { 59 + return func(ws *Websocket) { 60 + ws.log = logger 61 + 62 + if ws.log == nil { 63 + // write to io.Discard if a nil logger is passed 64 + ws.log = slog.New(slog.NewTextHandler(io.Discard, nil)) 65 + } 66 + } 67 + } 68 + 69 + // Sets the connect timeout for connecting to the websocket 70 + func WithConnectTimeout(timeout time.Duration) func(*Websocket) { 71 + return func(ws *Websocket) { 72 + ws.connectTimeout = timeout 73 + } 74 + } 75 + 76 + // Sets the read timeout for reading data from the websocket 77 + func WithReadTimeout(timeout time.Duration) func(*Websocket) { 78 + return func(ws *Websocket) { 79 + ws.readTimeout = timeout 80 + } 81 + } 82 + 83 + // Sets the write timeout for writing data to the websocket 84 + func WithWriteTimeout(timeout time.Duration) func(*Websocket) { 85 + return func(ws *Websocket) { 86 + ws.writeTimeout = timeout 87 + } 88 + } 89 + 90 + // Controls how many times the loop will attempt to reconnect to the websocket in a row before giving up 91 + func WithMaxConsecutiveErrors(numErrs int) func(*Websocket) { 92 + return func(ws *Websocket) { 93 + ws.maxErrs = numErrs 94 + } 95 + } 96 + 97 + // Turns on message acknowledgements 98 + func WithAcks() func(*Websocket) { 99 + return func(ws *Websocket) { 100 + ws.sendAcks = true 101 + } 102 + } 103 + 104 + // Defines an option for the tap websocket consumer. A nil error indicates that an ACK will be sent to tap 105 + // if WithAcks() is provided. 106 + type WebsocketHandlerFunc func(context.Context, *Event) error 107 + 108 + // Initializes a tap websocket consumer 109 + func NewWebsocket(addr string, handler WebsocketHandlerFunc, opts ...WebsocketOption) (*Websocket, error) { 110 + u, err := url.Parse(addr) 111 + if err != nil { 112 + return nil, fmt.Errorf("failed to parse websocket url %q: %w", addr, err) 113 + } 114 + 115 + switch u.Scheme { 116 + case "ws", "wss": // ok 117 + default: 118 + return nil, fmt.Errorf("invalid websocket protocol scheme: wanted ws:// or wss://, got %q", u.Scheme) 119 + } 120 + 121 + if handler == nil { 122 + return nil, fmt.Errorf("a websocket message handler func is required") 123 + } 124 + 125 + ws := &Websocket{ 126 + log: slog.Default().WithGroup("tap"), 127 + 128 + addr: addr, 129 + sendAcks: false, 130 + maxErrs: 10, 131 + 132 + connectTimeout: 30 * time.Second, 133 + readTimeout: 30 * time.Second, 134 + writeTimeout: 30 * time.Second, 135 + 136 + handler: handler, 137 + } 138 + 139 + for _, opt := range opts { 140 + opt(ws) 141 + } 142 + 143 + return ws, nil 144 + } 145 + 146 + // Connects to and beings the main tap websocket consumer loop 147 + func (ws *Websocket) Run(ctx context.Context) error { 148 + for errCount := 0; ; { 149 + select { 150 + case <-ctx.Done(): 151 + ws.log.Debug("websocket ingester shutting down") 152 + return nil 153 + default: 154 + } 155 + 156 + err := ws.runOnce(ctx) 157 + if errors.Is(err, context.Canceled) { 158 + ws.log.Debug("websocket ingester shutting down") 159 + return nil 160 + } 161 + 162 + if err == nil { 163 + errCount = 0 164 + ws.log.Debug("websocket connection closed normally, reconnecting") 165 + continue 166 + } 167 + 168 + errCount++ 169 + ws.log.Error("websocket connection failed", "err", err, "consecutive_errors", errCount) 170 + 171 + if errCount >= ws.maxErrs { 172 + return fmt.Errorf("websocket connection failed %d consecutive times: %w", errCount, err) 173 + } 174 + 175 + ws.log.Warn("retrying websocket connection", "consecutive_errors", errCount) 176 + if sleepMaybeExit(ctx, errCount) { 177 + return nil 178 + } 179 + } 180 + } 181 + 182 + func (ws *Websocket) close(conn *websocket.Conn) { 183 + if err := conn.Close(); err != nil { 184 + ws.log.Error("failed to close websocket connection", "err", err) 185 + } 186 + } 187 + 188 + func (ws *Websocket) runOnce(ctx context.Context) error { 189 + dialer := websocket.Dialer{HandshakeTimeout: ws.connectTimeout} 190 + conn, _, err := dialer.DialContext(ctx, ws.addr, nil) 191 + if err != nil { 192 + return fmt.Errorf("failed to connect to websocket at %q: %w", ws.addr, err) 193 + } 194 + 195 + var closeOnce sync.Once 196 + closeConn := func() { 197 + closeOnce.Do(func() { 198 + if err := conn.Close(); err != nil { 199 + ws.log.Error("failed to close websocket connection", "err", err) 200 + } 201 + }) 202 + } 203 + defer closeConn() 204 + 205 + ws.log.Debug("connected to websocket", "addr", ws.addr) 206 + 207 + go func() { 208 + <-ctx.Done() 209 + closeConn() 210 + }() 211 + 212 + for { 213 + if done(ctx) { 214 + return nil 215 + } 216 + 217 + if err := conn.SetReadDeadline(time.Now().Add(ws.readTimeout)); err != nil { 218 + return fmt.Errorf("failed to set websocket read deadline: %w", err) 219 + } 220 + 221 + _, buf, err := conn.ReadMessage() 222 + if err != nil { 223 + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { 224 + return nil // normal remote closure 225 + } 226 + 227 + if ctx.Err() != nil { 228 + return ctx.Err() 229 + } 230 + 231 + return fmt.Errorf("failed to read websocket message: %w", err) 232 + } 233 + 234 + var ev Event 235 + if err := json.Unmarshal(buf, &ev); err != nil { 236 + ws.log.Warn("failed to unmarshal event json", "err", err) 237 + continue 238 + } 239 + 240 + // indefinitely retry messages that failed to process unless a non-retryable error occurrs 241 + for errCount := 0; ; errCount++ { 242 + if done(ctx) { 243 + break 244 + } 245 + 246 + err := ws.handler(ctx, &ev) 247 + if err == nil { 248 + break 249 + } 250 + 251 + ws.log.Error("failed to process event", "err", err) 252 + if sleepMaybeExit(ctx, errCount) { 253 + return nil 254 + } 255 + 256 + var nr *NonRetryableError 257 + if errors.As(err, &nr) { 258 + ws.log.Error("handled non-retryable error", "id", ev.ID, "err", err) 259 + break 260 + } 261 + } 262 + 263 + if ws.sendAcks { 264 + ws.ack(ctx, conn, &ev) 265 + } 266 + } 267 + } 268 + 269 + // Indefinitely tries acking the message with the tap server 270 + func (ws *Websocket) ack(ctx context.Context, conn *websocket.Conn, ev *Event) { 271 + for errCount := 0; ; errCount++ { 272 + if done(ctx) { 273 + return 274 + } 275 + 276 + if err := conn.SetWriteDeadline(time.Now().Add(ws.writeTimeout)); err != nil { 277 + ws.log.Warn("failed to set write deadline on ack", "err", err) 278 + } 279 + 280 + err := conn.WriteJSON(NewACKPayload(ev.ID)) 281 + if err == nil { 282 + return 283 + } 284 + 285 + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { 286 + return // normal remote closure 287 + } 288 + 289 + ws.log.Error("failed to send ack", "err", err) 290 + 291 + if sleepMaybeExit(ctx, errCount) { 292 + return 293 + } 294 + } 295 + } 296 + 297 + func done(ctx context.Context) bool { 298 + select { 299 + case <-ctx.Done(): 300 + return true 301 + default: 302 + return false 303 + } 304 + } 305 + 306 + func sleepMaybeExit(ctx context.Context, errCount int) bool { 307 + select { 308 + case <-ctx.Done(): 309 + return true // shutdown received during a backoff sleep means that we're done 310 + case <-time.After(backoffDuration(errCount)): 311 + return false 312 + } 313 + } 314 + 315 + func backoffDuration(errCount int) time.Duration { 316 + multiplier := 1 << errCount 317 + waitFor := initialBackoff * time.Duration(multiplier) 318 + 319 + return min(waitFor, 10*time.Second) 320 + }
+278
tapclient/websocket_test.go
··· 1 + package tapclient 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "net/http" 8 + "net/http/httptest" 9 + "strings" 10 + "sync" 11 + "testing" 12 + "time" 13 + 14 + "github.com/gorilla/websocket" 15 + "github.com/stretchr/testify/require" 16 + ) 17 + 18 + func init() { 19 + initialBackoff = 0 20 + } 21 + 22 + var upgrader = websocket.Upgrader{} 23 + 24 + func TestWebsocket(t *testing.T) { 25 + t.Parallel() 26 + ctx := t.Context() 27 + require := require.New(t) 28 + 29 + events := []Event{ 30 + {ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}}, 31 + {ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.like"}}, 32 + {ID: 3, Type: eventTypeIdentity, identity: &IdentityEvent{DID: "did:plc:3", Handle: "user3.test"}}, 33 + } 34 + 35 + var received []*Event 36 + var mu sync.Mutex 37 + var wg sync.WaitGroup 38 + wg.Add(len(events)) 39 + 40 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 41 + conn, err := upgrader.Upgrade(w, r, nil) 42 + if err != nil { 43 + return 44 + } 45 + defer conn.Close() 46 + 47 + for _, ev := range events { 48 + buf, _ := json.Marshal(ev) 49 + conn.WriteMessage(websocket.TextMessage, buf) 50 + time.Sleep(10 * time.Millisecond) 51 + } 52 + 53 + time.Sleep(50 * time.Millisecond) 54 + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 55 + })) 56 + defer server.Close() 57 + 58 + wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 59 + 60 + ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 61 + mu.Lock() 62 + received = append(received, ev) 63 + mu.Unlock() 64 + wg.Done() 65 + return nil 66 + }, WithLogger(nil)) 67 + require.NoError(err) 68 + 69 + go ws.Run(ctx) 70 + wg.Wait() 71 + 72 + require.Len(received, 3) 73 + for i, ev := range received { 74 + require.Equal(uint64(i+1), ev.ID) 75 + 76 + switch i { 77 + case 0, 1: 78 + switch pl := ev.Payload().(type) { 79 + case *RecordEvent: 80 + require.NotNil(events[i].record) 81 + require.Equal(events[i].record.Collection, pl.Collection) 82 + require.Equal(events[i].Type, eventTypeRecord) 83 + default: 84 + require.FailNow("incorrect payload type, want %T got %T", &RecordEvent{}, ev.Payload()) 85 + } 86 + 87 + case 2: 88 + switch pl := ev.Payload().(type) { 89 + case *IdentityEvent: 90 + require.NotNil(events[i].identity) 91 + require.Equal(events[i].identity.Handle, pl.Handle) 92 + require.Equal(events[i].Type, eventTypeIdentity) 93 + default: 94 + require.FailNow("incorrect payload type, want %T got %T", &IdentityEvent{}, ev.Payload()) 95 + } 96 + } 97 + } 98 + } 99 + 100 + func TestWebsocketWithAcks(t *testing.T) { 101 + t.Parallel() 102 + 103 + t.Run("ack sent on success", func(t *testing.T) { 104 + t.Parallel() 105 + ctx := t.Context() 106 + require := require.New(t) 107 + 108 + recordEvent := Event{ 109 + ID: 42, 110 + Type: eventTypeRecord, 111 + record: &RecordEvent{ 112 + DID: "did:plc:ack", 113 + Collection: "app.bsky.feed.like", 114 + Rkey: "ack", 115 + Action: "create", 116 + }, 117 + } 118 + 119 + var receivedAck *Event 120 + var wg sync.WaitGroup 121 + wg.Add(1) 122 + 123 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 124 + conn, err := upgrader.Upgrade(w, r, nil) 125 + if err != nil { 126 + return 127 + } 128 + defer conn.Close() 129 + 130 + buf, _ := json.Marshal(recordEvent) 131 + conn.WriteMessage(websocket.TextMessage, buf) 132 + 133 + _, ackBuf, err := conn.ReadMessage() 134 + if err == nil { 135 + receivedAck = &Event{} 136 + json.Unmarshal(ackBuf, receivedAck) 137 + } 138 + wg.Done() 139 + 140 + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 141 + })) 142 + defer server.Close() 143 + 144 + wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 145 + 146 + ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 147 + return nil 148 + }, WithLogger(nil), WithAcks()) 149 + require.NoError(err) 150 + 151 + go ws.Run(ctx) 152 + wg.Wait() 153 + 154 + require.NotNil(receivedAck) 155 + require.Equal(eventTypeACK, receivedAck.Type) 156 + require.Equal(recordEvent.ID, receivedAck.ID) 157 + }) 158 + 159 + t.Run("ack not sent on error", func(t *testing.T) { 160 + t.Parallel() 161 + ctx := t.Context() 162 + require := require.New(t) 163 + 164 + recordEvent := Event{ 165 + ID: 99, 166 + Type: eventTypeRecord, 167 + record: &RecordEvent{ 168 + DID: "did:plc:noack", 169 + Collection: "app.bsky.feed.post", 170 + Rkey: "noack", 171 + Action: "create", 172 + }, 173 + } 174 + 175 + var receivedAck bool 176 + var wg sync.WaitGroup 177 + wg.Add(1) 178 + 179 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 180 + conn, err := upgrader.Upgrade(w, r, nil) 181 + if err != nil { 182 + return 183 + } 184 + defer conn.Close() 185 + 186 + buf, _ := json.Marshal(recordEvent) 187 + conn.WriteMessage(websocket.TextMessage, buf) 188 + 189 + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 190 + _, _, err = conn.ReadMessage() 191 + receivedAck = err == nil 192 + wg.Done() 193 + 194 + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 195 + })) 196 + defer server.Close() 197 + 198 + wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 199 + 200 + ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 201 + return errors.New("processing failed") 202 + }, WithLogger(nil), WithAcks()) 203 + require.NoError(err) 204 + 205 + go ws.Run(ctx) 206 + wg.Wait() 207 + 208 + require.False(receivedAck, "expected no ACK when handler returns error") 209 + }) 210 + } 211 + 212 + func TestWebsocketNonRetryableError(t *testing.T) { 213 + t.Parallel() 214 + ctx := t.Context() 215 + require := require.New(t) 216 + 217 + events := []Event{ 218 + {ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}}, 219 + {ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.post"}}, 220 + {ID: 3, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:3", Collection: "app.bsky.feed.post"}}, 221 + } 222 + 223 + var callCounts sync.Map 224 + var wg sync.WaitGroup 225 + wg.Add(len(events)) 226 + 227 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 228 + conn, err := upgrader.Upgrade(w, r, nil) 229 + if err != nil { 230 + return 231 + } 232 + defer conn.Close() 233 + 234 + for _, ev := range events { 235 + buf, _ := json.Marshal(ev) 236 + conn.WriteMessage(websocket.TextMessage, buf) 237 + time.Sleep(10 * time.Millisecond) 238 + } 239 + 240 + time.Sleep(100 * time.Millisecond) 241 + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 242 + })) 243 + defer server.Close() 244 + 245 + wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 246 + 247 + ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 248 + val, _ := callCounts.LoadOrStore(ev.ID, new(int)) 249 + count := val.(*int) 250 + *count++ 251 + 252 + if ev.ID == 2 { 253 + if *count == 1 { 254 + wg.Done() 255 + } 256 + return NewNonRetryableError(errors.New("bad input, do not retry")) 257 + } 258 + 259 + wg.Done() 260 + return nil 261 + }, WithLogger(nil)) 262 + require.NoError(err) 263 + 264 + go ws.Run(ctx) 265 + wg.Wait() 266 + 267 + // event 1: should be called exactly once (success) 268 + val1, _ := callCounts.Load(uint64(1)) 269 + require.Equal(1, *val1.(*int), "event 1 should be processed once") 270 + 271 + // event 2: should be called exactly once (non-retryable error, no retry) 272 + val2, _ := callCounts.Load(uint64(2)) 273 + require.Equal(1, *val2.(*int), "event 2 with NonRetryableError should not be retried") 274 + 275 + // event 3: should be called exactly once (success, proving we moved on after non-retryable) 276 + val3, _ := callCounts.Load(uint64(3)) 277 + require.Equal(1, *val3.(*int), "event 3 should be processed after non-retryable error") 278 + }