tangled
alpha
login
or
join now
bnewbold.net
/
scrumble
2
fork
atom
this repo has no description
2
fork
atom
overview
issues
pulls
pipelines
vendor in WIP tapclient from indigo
bnewbold.net
1 month ago
909cff2e
77987282
+857
5 changed files
expand all
collapse all
unified
split
tapclient
doc.go
event.go
event_test.go
websocket.go
websocket_test.go
+37
tapclient/doc.go
···
1
1
+
// Package tap provides a client for consuming atproto events from a tap websocket.
2
2
+
//
3
3
+
// (this is jcalabro code from https://github.com/bluesky-social/indigo/pull/1241)
4
4
+
//
5
5
+
// The client handles connection management, automatic reconnection with backoff,
6
6
+
// and optional message acknowledgements.
7
7
+
//
8
8
+
// Basic usage:
9
9
+
//
10
10
+
// handler := func(ctx context.Context, ev *tap.Event) error {
11
11
+
// switch payload := ev.Payload().(type) {
12
12
+
// case *tap.RecordEvent:
13
13
+
// fmt.Printf("record.Action: %s\n", payload.Action)
14
14
+
// fmt.Printf("record.Collection: %s\n", payload.Collection)
15
15
+
// case *tap.IdentityEvent:
16
16
+
// fmt.Printf("identity.DID: %s\n", payload.DID)
17
17
+
// fmt.Printf("identity.Handle: %s\n", payload.Handle)
18
18
+
// }
19
19
+
// return nil
20
20
+
// }
21
21
+
//
22
22
+
// ws, err := tap.NewWebsocket("wss://example.com/tap", handler,
23
23
+
// tap.WithLogger(slog.Default()),
24
24
+
// tap.WithAcks(),
25
25
+
// )
26
26
+
// if err != nil {
27
27
+
// // handle error...
28
28
+
// }
29
29
+
//
30
30
+
// if err := ws.Run(ctx); err != nil {
31
31
+
// // handle error...
32
32
+
// }
33
33
+
//
34
34
+
// Returning an error from the handler will cause the message to be retried with
35
35
+
// exponential backoff. To skip retries for permanent failures, wrap the error
36
36
+
// with [NewNonRetryableError].
37
37
+
package tapclient
+114
tapclient/event.go
···
1
1
+
package tapclient
2
2
+
3
3
+
import (
4
4
+
"encoding/json"
5
5
+
"fmt"
6
6
+
)
7
7
+
8
8
+
const (
9
9
+
eventTypeACK = "ack"
10
10
+
eventTypeRecord = "record"
11
11
+
eventTypeIdentity = "identity"
12
12
+
)
13
13
+
14
14
+
// Event represents an atproto event from tap. Use a type switch on the Payload() method to access event data.
15
15
+
type Event struct {
16
16
+
ID uint64
17
17
+
Type string
18
18
+
19
19
+
record *RecordEvent
20
20
+
identity *IdentityEvent
21
21
+
}
22
22
+
23
23
+
// RecordEvent represents a record creation, update, or deletion in a repository
24
24
+
type RecordEvent struct {
25
25
+
DID string `json:"did"`
26
26
+
Collection string `json:"collection"`
27
27
+
Rkey string `json:"rkey"`
28
28
+
Action string `json:"action"`
29
29
+
CID string `json:"cid"`
30
30
+
Record json.RawMessage `json:"record"`
31
31
+
Live bool `json:"live"`
32
32
+
}
33
33
+
34
34
+
// IdentityEvent represents an account status change
35
35
+
type IdentityEvent struct {
36
36
+
DID string `json:"did"`
37
37
+
Handle string `json:"handle"`
38
38
+
IsActive bool `json:"isActive"`
39
39
+
Status string `json:"status"`
40
40
+
}
41
41
+
42
42
+
func (e *Event) UnmarshalJSON(data []byte) error {
43
43
+
event := struct {
44
44
+
ID uint64 `json:"id"`
45
45
+
Type string `json:"type"`
46
46
+
Record json.RawMessage `json:"record,omitempty"`
47
47
+
Identity json.RawMessage `json:"identity,omitempty"`
48
48
+
}{}
49
49
+
50
50
+
if err := json.Unmarshal(data, &event); err != nil {
51
51
+
return fmt.Errorf("failed to unmarshal tap event: %w", err)
52
52
+
}
53
53
+
54
54
+
e.ID = event.ID
55
55
+
e.Type = event.Type
56
56
+
57
57
+
switch event.Type {
58
58
+
case eventTypeRecord:
59
59
+
e.record = &RecordEvent{}
60
60
+
if err := json.Unmarshal(event.Record, e.record); err != nil {
61
61
+
return fmt.Errorf("failed to unmarshal tap record event: %w", err)
62
62
+
}
63
63
+
case eventTypeIdentity:
64
64
+
e.identity = &IdentityEvent{}
65
65
+
if err := json.Unmarshal(event.Identity, e.identity); err != nil {
66
66
+
return fmt.Errorf("failed to unmarshal tap identity event: %w", err)
67
67
+
}
68
68
+
default:
69
69
+
return fmt.Errorf("unknown event type %q", event.Type)
70
70
+
}
71
71
+
72
72
+
return nil
73
73
+
}
74
74
+
75
75
+
func (e Event) MarshalJSON() ([]byte, error) {
76
76
+
event := struct {
77
77
+
ID uint64 `json:"id"`
78
78
+
Type string `json:"type"`
79
79
+
Record *RecordEvent `json:"record,omitempty"`
80
80
+
Identity *IdentityEvent `json:"identity,omitempty"`
81
81
+
}{
82
82
+
ID: e.ID,
83
83
+
Type: e.Type,
84
84
+
Record: e.record,
85
85
+
Identity: e.identity,
86
86
+
}
87
87
+
88
88
+
buf, err := json.Marshal(event)
89
89
+
if err != nil {
90
90
+
return nil, fmt.Errorf("failed to marshal tap event: %w", err)
91
91
+
}
92
92
+
93
93
+
return buf, nil
94
94
+
}
95
95
+
96
96
+
// Payload returns the typed event data as either *RecordEvent or *IdentityEvent.
97
97
+
func (e *Event) Payload() any {
98
98
+
switch e.Type {
99
99
+
case eventTypeRecord:
100
100
+
return e.record
101
101
+
case eventTypeIdentity:
102
102
+
return e.identity
103
103
+
}
104
104
+
105
105
+
return nil // unreachable
106
106
+
}
107
107
+
108
108
+
// Constructs a new ACK object to be serialized and sent back to tap
109
109
+
func NewACKPayload(id uint64) *Event {
110
110
+
return &Event{
111
111
+
Type: eventTypeACK,
112
112
+
ID: id,
113
113
+
}
114
114
+
}
+108
tapclient/event_test.go
···
1
1
+
package tapclient
2
2
+
3
3
+
import (
4
4
+
"encoding/json"
5
5
+
"testing"
6
6
+
7
7
+
"github.com/stretchr/testify/require"
8
8
+
)
9
9
+
10
10
+
func TestEventJSON(t *testing.T) {
11
11
+
t.Parallel()
12
12
+
13
13
+
t.Run("marshal/unmarshal record", func(t *testing.T) {
14
14
+
t.Parallel()
15
15
+
require := require.New(t)
16
16
+
17
17
+
original := Event{
18
18
+
ID: 123,
19
19
+
Type: eventTypeRecord,
20
20
+
record: &RecordEvent{
21
21
+
DID: "did:plc:test",
22
22
+
Collection: "app.bsky.feed.post",
23
23
+
Rkey: "abc123",
24
24
+
Action: "create",
25
25
+
CID: "bafytest",
26
26
+
Record: json.RawMessage(`{"text":"hello"}`),
27
27
+
Live: true,
28
28
+
},
29
29
+
}
30
30
+
31
31
+
buf, err := json.Marshal(original)
32
32
+
require.NoError(err)
33
33
+
34
34
+
var decoded Event
35
35
+
require.NoError(json.Unmarshal(buf, &decoded))
36
36
+
require.Equal(original.ID, decoded.ID)
37
37
+
require.Equal(original.Type, decoded.Type)
38
38
+
39
39
+
payload, ok := decoded.Payload().(*RecordEvent)
40
40
+
require.True(ok)
41
41
+
require.Equal(original.record.DID, payload.DID)
42
42
+
require.Equal(original.record.Collection, payload.Collection)
43
43
+
require.Equal(original.record.Rkey, payload.Rkey)
44
44
+
require.Equal(original.record.Action, payload.Action)
45
45
+
require.Equal(original.record.CID, payload.CID)
46
46
+
require.Equal(original.record.Live, payload.Live)
47
47
+
require.JSONEq(string(original.record.Record), string(payload.Record))
48
48
+
})
49
49
+
50
50
+
t.Run("marshal/unmarshal identity", func(t *testing.T) {
51
51
+
t.Parallel()
52
52
+
require := require.New(t)
53
53
+
54
54
+
original := Event{
55
55
+
ID: 456,
56
56
+
Type: eventTypeIdentity,
57
57
+
identity: &IdentityEvent{
58
58
+
DID: "did:plc:user",
59
59
+
Handle: "test.bsky.social",
60
60
+
IsActive: true,
61
61
+
Status: "active",
62
62
+
},
63
63
+
}
64
64
+
65
65
+
buf, err := json.Marshal(original)
66
66
+
require.NoError(err)
67
67
+
68
68
+
var decoded Event
69
69
+
require.NoError(json.Unmarshal(buf, &decoded))
70
70
+
require.Equal(original.ID, decoded.ID)
71
71
+
require.Equal(original.Type, decoded.Type)
72
72
+
73
73
+
payload, ok := decoded.Payload().(*IdentityEvent)
74
74
+
require.True(ok)
75
75
+
require.Equal(original.identity.DID, payload.DID)
76
76
+
require.Equal(original.identity.Handle, payload.Handle)
77
77
+
require.Equal(original.identity.IsActive, payload.IsActive)
78
78
+
require.Equal(original.identity.Status, payload.Status)
79
79
+
})
80
80
+
81
81
+
t.Run("unmarshal from raw json", func(t *testing.T) {
82
82
+
t.Parallel()
83
83
+
require := require.New(t)
84
84
+
85
85
+
recordJSON := `{"id":1,"type":"record","record":{"did":"did:plc:abc","collection":"app.bsky.feed.like","rkey":"xyz","action":"create","cid":"mycid","record":{"subject":"at://did:plc:foo/app.bsky.feed.post/xyz"},"live":false}}`
86
86
+
var recordEvent Event
87
87
+
require.NoError(json.Unmarshal([]byte(recordJSON), &recordEvent))
88
88
+
require.Equal(uint64(1), recordEvent.ID)
89
89
+
require.Equal(eventTypeRecord, recordEvent.Type)
90
90
+
91
91
+
identityJSON := `{"id":2,"type":"identity","identity":{"did":"did:plc:def","handle":"foo.test","isActive":true,"status":"active"}}`
92
92
+
var identEvent Event
93
93
+
require.NoError(json.Unmarshal([]byte(identityJSON), &identEvent))
94
94
+
require.Equal(uint64(2), identEvent.ID)
95
95
+
require.Equal(eventTypeIdentity, identEvent.Type)
96
96
+
})
97
97
+
98
98
+
t.Run("unmarshal unknown type", func(t *testing.T) {
99
99
+
t.Parallel()
100
100
+
require := require.New(t)
101
101
+
102
102
+
badJSON := `{"id":1,"type":"unknown"}`
103
103
+
var ev Event
104
104
+
err := json.Unmarshal([]byte(badJSON), &ev)
105
105
+
require.Error(err)
106
106
+
require.Contains(err.Error(), "unknown event type")
107
107
+
})
108
108
+
}
+320
tapclient/websocket.go
···
1
1
+
package tapclient
2
2
+
3
3
+
import (
4
4
+
"context"
5
5
+
"encoding/json"
6
6
+
"errors"
7
7
+
"fmt"
8
8
+
"io"
9
9
+
"log/slog"
10
10
+
"net/url"
11
11
+
"sync"
12
12
+
"time"
13
13
+
14
14
+
"github.com/gorilla/websocket"
15
15
+
)
16
16
+
17
17
+
var (
18
18
+
initialBackoff = 500 * time.Millisecond
19
19
+
)
20
20
+
21
21
+
// A thin error wrapper that indicates to the tap client consumer loop that a message
22
22
+
// should not be retried (i.e. invalid user input that will surely fail again on retry).
23
23
+
type NonRetryableError struct {
24
24
+
err error
25
25
+
}
26
26
+
27
27
+
func NewNonRetryableError(err error) *NonRetryableError {
28
28
+
return &NonRetryableError{err: err}
29
29
+
}
30
30
+
31
31
+
func (err *NonRetryableError) Error() string {
32
32
+
if err.err != nil {
33
33
+
return err.err.Error()
34
34
+
}
35
35
+
return ""
36
36
+
}
37
37
+
38
38
+
// Websocket implements a tap consumer that reads via a websocket
39
39
+
type Websocket struct {
40
40
+
log *slog.Logger
41
41
+
42
42
+
addr string
43
43
+
sendAcks bool
44
44
+
maxErrs int
45
45
+
46
46
+
connectTimeout time.Duration
47
47
+
readTimeout time.Duration
48
48
+
writeTimeout time.Duration
49
49
+
50
50
+
handler WebsocketHandlerFunc
51
51
+
}
52
52
+
53
53
+
// Defines an option for the tap websocket consumer
54
54
+
type WebsocketOption func(*Websocket)
55
55
+
56
56
+
// Defines the log/slog logger to use throughout the lifecycle of the websocket
57
57
+
// consumer. Pass nil to disable logging.
58
58
+
func WithLogger(logger *slog.Logger) func(*Websocket) {
59
59
+
return func(ws *Websocket) {
60
60
+
ws.log = logger
61
61
+
62
62
+
if ws.log == nil {
63
63
+
// write to io.Discard if a nil logger is passed
64
64
+
ws.log = slog.New(slog.NewTextHandler(io.Discard, nil))
65
65
+
}
66
66
+
}
67
67
+
}
68
68
+
69
69
+
// Sets the connect timeout for connecting to the websocket
70
70
+
func WithConnectTimeout(timeout time.Duration) func(*Websocket) {
71
71
+
return func(ws *Websocket) {
72
72
+
ws.connectTimeout = timeout
73
73
+
}
74
74
+
}
75
75
+
76
76
+
// Sets the read timeout for reading data from the websocket
77
77
+
func WithReadTimeout(timeout time.Duration) func(*Websocket) {
78
78
+
return func(ws *Websocket) {
79
79
+
ws.readTimeout = timeout
80
80
+
}
81
81
+
}
82
82
+
83
83
+
// Sets the write timeout for writing data to the websocket
84
84
+
func WithWriteTimeout(timeout time.Duration) func(*Websocket) {
85
85
+
return func(ws *Websocket) {
86
86
+
ws.writeTimeout = timeout
87
87
+
}
88
88
+
}
89
89
+
90
90
+
// Controls how many times the loop will attempt to reconnect to the websocket in a row before giving up
91
91
+
func WithMaxConsecutiveErrors(numErrs int) func(*Websocket) {
92
92
+
return func(ws *Websocket) {
93
93
+
ws.maxErrs = numErrs
94
94
+
}
95
95
+
}
96
96
+
97
97
+
// Turns on message acknowledgements
98
98
+
func WithAcks() func(*Websocket) {
99
99
+
return func(ws *Websocket) {
100
100
+
ws.sendAcks = true
101
101
+
}
102
102
+
}
103
103
+
104
104
+
// Defines an option for the tap websocket consumer. A nil error indicates that an ACK will be sent to tap
105
105
+
// if WithAcks() is provided.
106
106
+
type WebsocketHandlerFunc func(context.Context, *Event) error
107
107
+
108
108
+
// Initializes a tap websocket consumer
109
109
+
func NewWebsocket(addr string, handler WebsocketHandlerFunc, opts ...WebsocketOption) (*Websocket, error) {
110
110
+
u, err := url.Parse(addr)
111
111
+
if err != nil {
112
112
+
return nil, fmt.Errorf("failed to parse websocket url %q: %w", addr, err)
113
113
+
}
114
114
+
115
115
+
switch u.Scheme {
116
116
+
case "ws", "wss": // ok
117
117
+
default:
118
118
+
return nil, fmt.Errorf("invalid websocket protocol scheme: wanted ws:// or wss://, got %q", u.Scheme)
119
119
+
}
120
120
+
121
121
+
if handler == nil {
122
122
+
return nil, fmt.Errorf("a websocket message handler func is required")
123
123
+
}
124
124
+
125
125
+
ws := &Websocket{
126
126
+
log: slog.Default().WithGroup("tap"),
127
127
+
128
128
+
addr: addr,
129
129
+
sendAcks: false,
130
130
+
maxErrs: 10,
131
131
+
132
132
+
connectTimeout: 30 * time.Second,
133
133
+
readTimeout: 30 * time.Second,
134
134
+
writeTimeout: 30 * time.Second,
135
135
+
136
136
+
handler: handler,
137
137
+
}
138
138
+
139
139
+
for _, opt := range opts {
140
140
+
opt(ws)
141
141
+
}
142
142
+
143
143
+
return ws, nil
144
144
+
}
145
145
+
146
146
+
// Connects to and beings the main tap websocket consumer loop
147
147
+
func (ws *Websocket) Run(ctx context.Context) error {
148
148
+
for errCount := 0; ; {
149
149
+
select {
150
150
+
case <-ctx.Done():
151
151
+
ws.log.Debug("websocket ingester shutting down")
152
152
+
return nil
153
153
+
default:
154
154
+
}
155
155
+
156
156
+
err := ws.runOnce(ctx)
157
157
+
if errors.Is(err, context.Canceled) {
158
158
+
ws.log.Debug("websocket ingester shutting down")
159
159
+
return nil
160
160
+
}
161
161
+
162
162
+
if err == nil {
163
163
+
errCount = 0
164
164
+
ws.log.Debug("websocket connection closed normally, reconnecting")
165
165
+
continue
166
166
+
}
167
167
+
168
168
+
errCount++
169
169
+
ws.log.Error("websocket connection failed", "err", err, "consecutive_errors", errCount)
170
170
+
171
171
+
if errCount >= ws.maxErrs {
172
172
+
return fmt.Errorf("websocket connection failed %d consecutive times: %w", errCount, err)
173
173
+
}
174
174
+
175
175
+
ws.log.Warn("retrying websocket connection", "consecutive_errors", errCount)
176
176
+
if sleepMaybeExit(ctx, errCount) {
177
177
+
return nil
178
178
+
}
179
179
+
}
180
180
+
}
181
181
+
182
182
+
func (ws *Websocket) close(conn *websocket.Conn) {
183
183
+
if err := conn.Close(); err != nil {
184
184
+
ws.log.Error("failed to close websocket connection", "err", err)
185
185
+
}
186
186
+
}
187
187
+
188
188
+
func (ws *Websocket) runOnce(ctx context.Context) error {
189
189
+
dialer := websocket.Dialer{HandshakeTimeout: ws.connectTimeout}
190
190
+
conn, _, err := dialer.DialContext(ctx, ws.addr, nil)
191
191
+
if err != nil {
192
192
+
return fmt.Errorf("failed to connect to websocket at %q: %w", ws.addr, err)
193
193
+
}
194
194
+
195
195
+
var closeOnce sync.Once
196
196
+
closeConn := func() {
197
197
+
closeOnce.Do(func() {
198
198
+
if err := conn.Close(); err != nil {
199
199
+
ws.log.Error("failed to close websocket connection", "err", err)
200
200
+
}
201
201
+
})
202
202
+
}
203
203
+
defer closeConn()
204
204
+
205
205
+
ws.log.Debug("connected to websocket", "addr", ws.addr)
206
206
+
207
207
+
go func() {
208
208
+
<-ctx.Done()
209
209
+
closeConn()
210
210
+
}()
211
211
+
212
212
+
for {
213
213
+
if done(ctx) {
214
214
+
return nil
215
215
+
}
216
216
+
217
217
+
if err := conn.SetReadDeadline(time.Now().Add(ws.readTimeout)); err != nil {
218
218
+
return fmt.Errorf("failed to set websocket read deadline: %w", err)
219
219
+
}
220
220
+
221
221
+
_, buf, err := conn.ReadMessage()
222
222
+
if err != nil {
223
223
+
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
224
224
+
return nil // normal remote closure
225
225
+
}
226
226
+
227
227
+
if ctx.Err() != nil {
228
228
+
return ctx.Err()
229
229
+
}
230
230
+
231
231
+
return fmt.Errorf("failed to read websocket message: %w", err)
232
232
+
}
233
233
+
234
234
+
var ev Event
235
235
+
if err := json.Unmarshal(buf, &ev); err != nil {
236
236
+
ws.log.Warn("failed to unmarshal event json", "err", err)
237
237
+
continue
238
238
+
}
239
239
+
240
240
+
// indefinitely retry messages that failed to process unless a non-retryable error occurrs
241
241
+
for errCount := 0; ; errCount++ {
242
242
+
if done(ctx) {
243
243
+
break
244
244
+
}
245
245
+
246
246
+
err := ws.handler(ctx, &ev)
247
247
+
if err == nil {
248
248
+
break
249
249
+
}
250
250
+
251
251
+
ws.log.Error("failed to process event", "err", err)
252
252
+
if sleepMaybeExit(ctx, errCount) {
253
253
+
return nil
254
254
+
}
255
255
+
256
256
+
var nr *NonRetryableError
257
257
+
if errors.As(err, &nr) {
258
258
+
ws.log.Error("handled non-retryable error", "id", ev.ID, "err", err)
259
259
+
break
260
260
+
}
261
261
+
}
262
262
+
263
263
+
if ws.sendAcks {
264
264
+
ws.ack(ctx, conn, &ev)
265
265
+
}
266
266
+
}
267
267
+
}
268
268
+
269
269
+
// Indefinitely tries acking the message with the tap server
270
270
+
func (ws *Websocket) ack(ctx context.Context, conn *websocket.Conn, ev *Event) {
271
271
+
for errCount := 0; ; errCount++ {
272
272
+
if done(ctx) {
273
273
+
return
274
274
+
}
275
275
+
276
276
+
if err := conn.SetWriteDeadline(time.Now().Add(ws.writeTimeout)); err != nil {
277
277
+
ws.log.Warn("failed to set write deadline on ack", "err", err)
278
278
+
}
279
279
+
280
280
+
err := conn.WriteJSON(NewACKPayload(ev.ID))
281
281
+
if err == nil {
282
282
+
return
283
283
+
}
284
284
+
285
285
+
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
286
286
+
return // normal remote closure
287
287
+
}
288
288
+
289
289
+
ws.log.Error("failed to send ack", "err", err)
290
290
+
291
291
+
if sleepMaybeExit(ctx, errCount) {
292
292
+
return
293
293
+
}
294
294
+
}
295
295
+
}
296
296
+
297
297
+
func done(ctx context.Context) bool {
298
298
+
select {
299
299
+
case <-ctx.Done():
300
300
+
return true
301
301
+
default:
302
302
+
return false
303
303
+
}
304
304
+
}
305
305
+
306
306
+
func sleepMaybeExit(ctx context.Context, errCount int) bool {
307
307
+
select {
308
308
+
case <-ctx.Done():
309
309
+
return true // shutdown received during a backoff sleep means that we're done
310
310
+
case <-time.After(backoffDuration(errCount)):
311
311
+
return false
312
312
+
}
313
313
+
}
314
314
+
315
315
+
func backoffDuration(errCount int) time.Duration {
316
316
+
multiplier := 1 << errCount
317
317
+
waitFor := initialBackoff * time.Duration(multiplier)
318
318
+
319
319
+
return min(waitFor, 10*time.Second)
320
320
+
}
+278
tapclient/websocket_test.go
···
1
1
+
package tapclient
2
2
+
3
3
+
import (
4
4
+
"context"
5
5
+
"encoding/json"
6
6
+
"errors"
7
7
+
"net/http"
8
8
+
"net/http/httptest"
9
9
+
"strings"
10
10
+
"sync"
11
11
+
"testing"
12
12
+
"time"
13
13
+
14
14
+
"github.com/gorilla/websocket"
15
15
+
"github.com/stretchr/testify/require"
16
16
+
)
17
17
+
18
18
+
func init() {
19
19
+
initialBackoff = 0
20
20
+
}
21
21
+
22
22
+
var upgrader = websocket.Upgrader{}
23
23
+
24
24
+
func TestWebsocket(t *testing.T) {
25
25
+
t.Parallel()
26
26
+
ctx := t.Context()
27
27
+
require := require.New(t)
28
28
+
29
29
+
events := []Event{
30
30
+
{ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}},
31
31
+
{ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.like"}},
32
32
+
{ID: 3, Type: eventTypeIdentity, identity: &IdentityEvent{DID: "did:plc:3", Handle: "user3.test"}},
33
33
+
}
34
34
+
35
35
+
var received []*Event
36
36
+
var mu sync.Mutex
37
37
+
var wg sync.WaitGroup
38
38
+
wg.Add(len(events))
39
39
+
40
40
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41
41
+
conn, err := upgrader.Upgrade(w, r, nil)
42
42
+
if err != nil {
43
43
+
return
44
44
+
}
45
45
+
defer conn.Close()
46
46
+
47
47
+
for _, ev := range events {
48
48
+
buf, _ := json.Marshal(ev)
49
49
+
conn.WriteMessage(websocket.TextMessage, buf)
50
50
+
time.Sleep(10 * time.Millisecond)
51
51
+
}
52
52
+
53
53
+
time.Sleep(50 * time.Millisecond)
54
54
+
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
55
55
+
}))
56
56
+
defer server.Close()
57
57
+
58
58
+
wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
59
59
+
60
60
+
ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
61
61
+
mu.Lock()
62
62
+
received = append(received, ev)
63
63
+
mu.Unlock()
64
64
+
wg.Done()
65
65
+
return nil
66
66
+
}, WithLogger(nil))
67
67
+
require.NoError(err)
68
68
+
69
69
+
go ws.Run(ctx)
70
70
+
wg.Wait()
71
71
+
72
72
+
require.Len(received, 3)
73
73
+
for i, ev := range received {
74
74
+
require.Equal(uint64(i+1), ev.ID)
75
75
+
76
76
+
switch i {
77
77
+
case 0, 1:
78
78
+
switch pl := ev.Payload().(type) {
79
79
+
case *RecordEvent:
80
80
+
require.NotNil(events[i].record)
81
81
+
require.Equal(events[i].record.Collection, pl.Collection)
82
82
+
require.Equal(events[i].Type, eventTypeRecord)
83
83
+
default:
84
84
+
require.FailNow("incorrect payload type, want %T got %T", &RecordEvent{}, ev.Payload())
85
85
+
}
86
86
+
87
87
+
case 2:
88
88
+
switch pl := ev.Payload().(type) {
89
89
+
case *IdentityEvent:
90
90
+
require.NotNil(events[i].identity)
91
91
+
require.Equal(events[i].identity.Handle, pl.Handle)
92
92
+
require.Equal(events[i].Type, eventTypeIdentity)
93
93
+
default:
94
94
+
require.FailNow("incorrect payload type, want %T got %T", &IdentityEvent{}, ev.Payload())
95
95
+
}
96
96
+
}
97
97
+
}
98
98
+
}
99
99
+
100
100
+
func TestWebsocketWithAcks(t *testing.T) {
101
101
+
t.Parallel()
102
102
+
103
103
+
t.Run("ack sent on success", func(t *testing.T) {
104
104
+
t.Parallel()
105
105
+
ctx := t.Context()
106
106
+
require := require.New(t)
107
107
+
108
108
+
recordEvent := Event{
109
109
+
ID: 42,
110
110
+
Type: eventTypeRecord,
111
111
+
record: &RecordEvent{
112
112
+
DID: "did:plc:ack",
113
113
+
Collection: "app.bsky.feed.like",
114
114
+
Rkey: "ack",
115
115
+
Action: "create",
116
116
+
},
117
117
+
}
118
118
+
119
119
+
var receivedAck *Event
120
120
+
var wg sync.WaitGroup
121
121
+
wg.Add(1)
122
122
+
123
123
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124
124
+
conn, err := upgrader.Upgrade(w, r, nil)
125
125
+
if err != nil {
126
126
+
return
127
127
+
}
128
128
+
defer conn.Close()
129
129
+
130
130
+
buf, _ := json.Marshal(recordEvent)
131
131
+
conn.WriteMessage(websocket.TextMessage, buf)
132
132
+
133
133
+
_, ackBuf, err := conn.ReadMessage()
134
134
+
if err == nil {
135
135
+
receivedAck = &Event{}
136
136
+
json.Unmarshal(ackBuf, receivedAck)
137
137
+
}
138
138
+
wg.Done()
139
139
+
140
140
+
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
141
141
+
}))
142
142
+
defer server.Close()
143
143
+
144
144
+
wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
145
145
+
146
146
+
ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
147
147
+
return nil
148
148
+
}, WithLogger(nil), WithAcks())
149
149
+
require.NoError(err)
150
150
+
151
151
+
go ws.Run(ctx)
152
152
+
wg.Wait()
153
153
+
154
154
+
require.NotNil(receivedAck)
155
155
+
require.Equal(eventTypeACK, receivedAck.Type)
156
156
+
require.Equal(recordEvent.ID, receivedAck.ID)
157
157
+
})
158
158
+
159
159
+
t.Run("ack not sent on error", func(t *testing.T) {
160
160
+
t.Parallel()
161
161
+
ctx := t.Context()
162
162
+
require := require.New(t)
163
163
+
164
164
+
recordEvent := Event{
165
165
+
ID: 99,
166
166
+
Type: eventTypeRecord,
167
167
+
record: &RecordEvent{
168
168
+
DID: "did:plc:noack",
169
169
+
Collection: "app.bsky.feed.post",
170
170
+
Rkey: "noack",
171
171
+
Action: "create",
172
172
+
},
173
173
+
}
174
174
+
175
175
+
var receivedAck bool
176
176
+
var wg sync.WaitGroup
177
177
+
wg.Add(1)
178
178
+
179
179
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
180
180
+
conn, err := upgrader.Upgrade(w, r, nil)
181
181
+
if err != nil {
182
182
+
return
183
183
+
}
184
184
+
defer conn.Close()
185
185
+
186
186
+
buf, _ := json.Marshal(recordEvent)
187
187
+
conn.WriteMessage(websocket.TextMessage, buf)
188
188
+
189
189
+
conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
190
190
+
_, _, err = conn.ReadMessage()
191
191
+
receivedAck = err == nil
192
192
+
wg.Done()
193
193
+
194
194
+
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
195
195
+
}))
196
196
+
defer server.Close()
197
197
+
198
198
+
wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
199
199
+
200
200
+
ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
201
201
+
return errors.New("processing failed")
202
202
+
}, WithLogger(nil), WithAcks())
203
203
+
require.NoError(err)
204
204
+
205
205
+
go ws.Run(ctx)
206
206
+
wg.Wait()
207
207
+
208
208
+
require.False(receivedAck, "expected no ACK when handler returns error")
209
209
+
})
210
210
+
}
211
211
+
212
212
+
func TestWebsocketNonRetryableError(t *testing.T) {
213
213
+
t.Parallel()
214
214
+
ctx := t.Context()
215
215
+
require := require.New(t)
216
216
+
217
217
+
events := []Event{
218
218
+
{ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}},
219
219
+
{ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.post"}},
220
220
+
{ID: 3, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:3", Collection: "app.bsky.feed.post"}},
221
221
+
}
222
222
+
223
223
+
var callCounts sync.Map
224
224
+
var wg sync.WaitGroup
225
225
+
wg.Add(len(events))
226
226
+
227
227
+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
228
228
+
conn, err := upgrader.Upgrade(w, r, nil)
229
229
+
if err != nil {
230
230
+
return
231
231
+
}
232
232
+
defer conn.Close()
233
233
+
234
234
+
for _, ev := range events {
235
235
+
buf, _ := json.Marshal(ev)
236
236
+
conn.WriteMessage(websocket.TextMessage, buf)
237
237
+
time.Sleep(10 * time.Millisecond)
238
238
+
}
239
239
+
240
240
+
time.Sleep(100 * time.Millisecond)
241
241
+
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
242
242
+
}))
243
243
+
defer server.Close()
244
244
+
245
245
+
wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
246
246
+
247
247
+
ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
248
248
+
val, _ := callCounts.LoadOrStore(ev.ID, new(int))
249
249
+
count := val.(*int)
250
250
+
*count++
251
251
+
252
252
+
if ev.ID == 2 {
253
253
+
if *count == 1 {
254
254
+
wg.Done()
255
255
+
}
256
256
+
return NewNonRetryableError(errors.New("bad input, do not retry"))
257
257
+
}
258
258
+
259
259
+
wg.Done()
260
260
+
return nil
261
261
+
}, WithLogger(nil))
262
262
+
require.NoError(err)
263
263
+
264
264
+
go ws.Run(ctx)
265
265
+
wg.Wait()
266
266
+
267
267
+
// event 1: should be called exactly once (success)
268
268
+
val1, _ := callCounts.Load(uint64(1))
269
269
+
require.Equal(1, *val1.(*int), "event 1 should be processed once")
270
270
+
271
271
+
// event 2: should be called exactly once (non-retryable error, no retry)
272
272
+
val2, _ := callCounts.Load(uint64(2))
273
273
+
require.Equal(1, *val2.(*int), "event 2 with NonRetryableError should not be retried")
274
274
+
275
275
+
// event 3: should be called exactly once (success, proving we moved on after non-retryable)
276
276
+
val3, _ := callCounts.Load(uint64(3))
277
277
+
require.Equal(1, *val3.(*int), "event 3 should be processed after non-retryable error")
278
278
+
}