this repo has no description
package tapclient

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/require"
)

// init zeroes the package-level initialBackoff, which is declared elsewhere in
// this package, before any test runs. Because it is an init function it runs
// once per test binary and therefore applies to every test in the package, not
// only the ones in this file.
// NOTE(review): presumably initialBackoff is the first delay the client waits
// before retrying or reconnecting, and zeroing it keeps these tests fast —
// confirm against its declaration.
func init() {
	initialBackoff = 0
}

22var upgrader = websocket.Upgrader{}
23
24func TestWebsocket(t *testing.T) {
25 t.Parallel()
26 ctx := t.Context()
27 require := require.New(t)
28
29 events := []Event{
30 {ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}},
31 {ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.like"}},
32 {ID: 3, Type: eventTypeIdentity, identity: &IdentityEvent{DID: "did:plc:3", Handle: "user3.test"}},
33 }
34
35 var received []*Event
36 var mu sync.Mutex
37 var wg sync.WaitGroup
38 wg.Add(len(events))
39
40 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41 conn, err := upgrader.Upgrade(w, r, nil)
42 if err != nil {
43 return
44 }
45 defer conn.Close()
46
47 for _, ev := range events {
48 buf, _ := json.Marshal(ev)
49 conn.WriteMessage(websocket.TextMessage, buf)
50 time.Sleep(10 * time.Millisecond)
51 }
52
53 time.Sleep(50 * time.Millisecond)
54 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
55 }))
56 defer server.Close()
57
58 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
59
60 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
61 mu.Lock()
62 received = append(received, ev)
63 mu.Unlock()
64 wg.Done()
65 return nil
66 }, WithLogger(nil))
67 require.NoError(err)
68
69 go ws.Run(ctx)
70 wg.Wait()
71
72 require.Len(received, 3)
73 for i, ev := range received {
74 require.Equal(uint64(i+1), ev.ID)
75
76 switch i {
77 case 0, 1:
78 switch pl := ev.Payload().(type) {
79 case *RecordEvent:
80 require.NotNil(events[i].record)
81 require.Equal(events[i].record.Collection, pl.Collection)
82 require.Equal(events[i].Type, eventTypeRecord)
83 default:
84 require.FailNow("incorrect payload type, want %T got %T", &RecordEvent{}, ev.Payload())
85 }
86
87 case 2:
88 switch pl := ev.Payload().(type) {
89 case *IdentityEvent:
90 require.NotNil(events[i].identity)
91 require.Equal(events[i].identity.Handle, pl.Handle)
92 require.Equal(events[i].Type, eventTypeIdentity)
93 default:
94 require.FailNow("incorrect payload type, want %T got %T", &IdentityEvent{}, ev.Payload())
95 }
96 }
97 }
98}
99
100func TestWebsocketWithAcks(t *testing.T) {
101 t.Parallel()
102
103 t.Run("ack sent on success", func(t *testing.T) {
104 t.Parallel()
105 ctx := t.Context()
106 require := require.New(t)
107
108 recordEvent := Event{
109 ID: 42,
110 Type: eventTypeRecord,
111 record: &RecordEvent{
112 DID: "did:plc:ack",
113 Collection: "app.bsky.feed.like",
114 Rkey: "ack",
115 Action: "create",
116 },
117 }
118
119 var receivedAck *Event
120 var wg sync.WaitGroup
121 wg.Add(1)
122
123 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124 conn, err := upgrader.Upgrade(w, r, nil)
125 if err != nil {
126 return
127 }
128 defer conn.Close()
129
130 buf, _ := json.Marshal(recordEvent)
131 conn.WriteMessage(websocket.TextMessage, buf)
132
133 _, ackBuf, err := conn.ReadMessage()
134 if err == nil {
135 receivedAck = &Event{}
136 json.Unmarshal(ackBuf, receivedAck)
137 }
138 wg.Done()
139
140 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
141 }))
142 defer server.Close()
143
144 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
145
146 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
147 return nil
148 }, WithLogger(nil), WithAcks())
149 require.NoError(err)
150
151 go ws.Run(ctx)
152 wg.Wait()
153
154 require.NotNil(receivedAck)
155 require.Equal(eventTypeACK, receivedAck.Type)
156 require.Equal(recordEvent.ID, receivedAck.ID)
157 })
158
159 t.Run("ack not sent on error", func(t *testing.T) {
160 t.Parallel()
161 ctx := t.Context()
162 require := require.New(t)
163
164 recordEvent := Event{
165 ID: 99,
166 Type: eventTypeRecord,
167 record: &RecordEvent{
168 DID: "did:plc:noack",
169 Collection: "app.bsky.feed.post",
170 Rkey: "noack",
171 Action: "create",
172 },
173 }
174
175 var receivedAck bool
176 var wg sync.WaitGroup
177 wg.Add(1)
178
179 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
180 conn, err := upgrader.Upgrade(w, r, nil)
181 if err != nil {
182 return
183 }
184 defer conn.Close()
185
186 buf, _ := json.Marshal(recordEvent)
187 conn.WriteMessage(websocket.TextMessage, buf)
188
189 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
190 _, _, err = conn.ReadMessage()
191 receivedAck = err == nil
192 wg.Done()
193
194 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
195 }))
196 defer server.Close()
197
198 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
199
200 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
201 return errors.New("processing failed")
202 }, WithLogger(nil), WithAcks())
203 require.NoError(err)
204
205 go ws.Run(ctx)
206 wg.Wait()
207
208 require.False(receivedAck, "expected no ACK when handler returns error")
209 })
210}
211
212func TestWebsocketNonRetryableError(t *testing.T) {
213 t.Parallel()
214 ctx := t.Context()
215 require := require.New(t)
216
217 events := []Event{
218 {ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}},
219 {ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.post"}},
220 {ID: 3, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:3", Collection: "app.bsky.feed.post"}},
221 }
222
223 var callCounts sync.Map
224 var wg sync.WaitGroup
225 wg.Add(len(events))
226
227 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
228 conn, err := upgrader.Upgrade(w, r, nil)
229 if err != nil {
230 return
231 }
232 defer conn.Close()
233
234 for _, ev := range events {
235 buf, _ := json.Marshal(ev)
236 conn.WriteMessage(websocket.TextMessage, buf)
237 time.Sleep(10 * time.Millisecond)
238 }
239
240 time.Sleep(100 * time.Millisecond)
241 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
242 }))
243 defer server.Close()
244
245 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://")
246
247 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error {
248 val, _ := callCounts.LoadOrStore(ev.ID, new(int))
249 count := val.(*int)
250 *count++
251
252 if ev.ID == 2 {
253 if *count == 1 {
254 wg.Done()
255 }
256 return NewNonRetryableError(errors.New("bad input, do not retry"))
257 }
258
259 wg.Done()
260 return nil
261 }, WithLogger(nil))
262 require.NoError(err)
263
264 go ws.Run(ctx)
265 wg.Wait()
266
267 // event 1: should be called exactly once (success)
268 val1, _ := callCounts.Load(uint64(1))
269 require.Equal(1, *val1.(*int), "event 1 should be processed once")
270
271 // event 2: should be called exactly once (non-retryable error, no retry)
272 val2, _ := callCounts.Load(uint64(2))
273 require.Equal(1, *val2.(*int), "event 2 with NonRetryableError should not be retried")
274
275 // event 3: should be called exactly once (success, proving we moved on after non-retryable)
276 val3, _ := callCounts.Load(uint64(3))
277 require.Equal(1, *val3.(*int), "event 3 should be processed after non-retryable error")
278}