This repository has no description.
Branch: main · 278 lines · 7.2 kB · view raw
1package tapclient 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "sync" 11 "testing" 12 "time" 13 14 "github.com/gorilla/websocket" 15 "github.com/stretchr/testify/require" 16) 17 18func init() { 19 initialBackoff = 0 20} 21 22var upgrader = websocket.Upgrader{} 23 24func TestWebsocket(t *testing.T) { 25 t.Parallel() 26 ctx := t.Context() 27 require := require.New(t) 28 29 events := []Event{ 30 {ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}}, 31 {ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.like"}}, 32 {ID: 3, Type: eventTypeIdentity, identity: &IdentityEvent{DID: "did:plc:3", Handle: "user3.test"}}, 33 } 34 35 var received []*Event 36 var mu sync.Mutex 37 var wg sync.WaitGroup 38 wg.Add(len(events)) 39 40 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 41 conn, err := upgrader.Upgrade(w, r, nil) 42 if err != nil { 43 return 44 } 45 defer conn.Close() 46 47 for _, ev := range events { 48 buf, _ := json.Marshal(ev) 49 conn.WriteMessage(websocket.TextMessage, buf) 50 time.Sleep(10 * time.Millisecond) 51 } 52 53 time.Sleep(50 * time.Millisecond) 54 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 55 })) 56 defer server.Close() 57 58 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 59 60 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 61 mu.Lock() 62 received = append(received, ev) 63 mu.Unlock() 64 wg.Done() 65 return nil 66 }, WithLogger(nil)) 67 require.NoError(err) 68 69 go ws.Run(ctx) 70 wg.Wait() 71 72 require.Len(received, 3) 73 for i, ev := range received { 74 require.Equal(uint64(i+1), ev.ID) 75 76 switch i { 77 case 0, 1: 78 switch pl := ev.Payload().(type) { 79 case *RecordEvent: 80 require.NotNil(events[i].record) 81 
require.Equal(events[i].record.Collection, pl.Collection) 82 require.Equal(events[i].Type, eventTypeRecord) 83 default: 84 require.FailNow("incorrect payload type, want %T got %T", &RecordEvent{}, ev.Payload()) 85 } 86 87 case 2: 88 switch pl := ev.Payload().(type) { 89 case *IdentityEvent: 90 require.NotNil(events[i].identity) 91 require.Equal(events[i].identity.Handle, pl.Handle) 92 require.Equal(events[i].Type, eventTypeIdentity) 93 default: 94 require.FailNow("incorrect payload type, want %T got %T", &IdentityEvent{}, ev.Payload()) 95 } 96 } 97 } 98} 99 100func TestWebsocketWithAcks(t *testing.T) { 101 t.Parallel() 102 103 t.Run("ack sent on success", func(t *testing.T) { 104 t.Parallel() 105 ctx := t.Context() 106 require := require.New(t) 107 108 recordEvent := Event{ 109 ID: 42, 110 Type: eventTypeRecord, 111 record: &RecordEvent{ 112 DID: "did:plc:ack", 113 Collection: "app.bsky.feed.like", 114 Rkey: "ack", 115 Action: "create", 116 }, 117 } 118 119 var receivedAck *Event 120 var wg sync.WaitGroup 121 wg.Add(1) 122 123 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 124 conn, err := upgrader.Upgrade(w, r, nil) 125 if err != nil { 126 return 127 } 128 defer conn.Close() 129 130 buf, _ := json.Marshal(recordEvent) 131 conn.WriteMessage(websocket.TextMessage, buf) 132 133 _, ackBuf, err := conn.ReadMessage() 134 if err == nil { 135 receivedAck = &Event{} 136 json.Unmarshal(ackBuf, receivedAck) 137 } 138 wg.Done() 139 140 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 141 })) 142 defer server.Close() 143 144 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 145 146 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 147 return nil 148 }, WithLogger(nil), WithAcks()) 149 require.NoError(err) 150 151 go ws.Run(ctx) 152 wg.Wait() 153 154 require.NotNil(receivedAck) 155 require.Equal(eventTypeACK, receivedAck.Type) 156 
require.Equal(recordEvent.ID, receivedAck.ID) 157 }) 158 159 t.Run("ack not sent on error", func(t *testing.T) { 160 t.Parallel() 161 ctx := t.Context() 162 require := require.New(t) 163 164 recordEvent := Event{ 165 ID: 99, 166 Type: eventTypeRecord, 167 record: &RecordEvent{ 168 DID: "did:plc:noack", 169 Collection: "app.bsky.feed.post", 170 Rkey: "noack", 171 Action: "create", 172 }, 173 } 174 175 var receivedAck bool 176 var wg sync.WaitGroup 177 wg.Add(1) 178 179 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 180 conn, err := upgrader.Upgrade(w, r, nil) 181 if err != nil { 182 return 183 } 184 defer conn.Close() 185 186 buf, _ := json.Marshal(recordEvent) 187 conn.WriteMessage(websocket.TextMessage, buf) 188 189 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 190 _, _, err = conn.ReadMessage() 191 receivedAck = err == nil 192 wg.Done() 193 194 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 195 })) 196 defer server.Close() 197 198 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 199 200 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 201 return errors.New("processing failed") 202 }, WithLogger(nil), WithAcks()) 203 require.NoError(err) 204 205 go ws.Run(ctx) 206 wg.Wait() 207 208 require.False(receivedAck, "expected no ACK when handler returns error") 209 }) 210} 211 212func TestWebsocketNonRetryableError(t *testing.T) { 213 t.Parallel() 214 ctx := t.Context() 215 require := require.New(t) 216 217 events := []Event{ 218 {ID: 1, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:1", Collection: "app.bsky.feed.post"}}, 219 {ID: 2, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:2", Collection: "app.bsky.feed.post"}}, 220 {ID: 3, Type: eventTypeRecord, record: &RecordEvent{DID: "did:plc:3", Collection: "app.bsky.feed.post"}}, 221 } 222 223 var callCounts sync.Map 224 var wg sync.WaitGroup 225 
wg.Add(len(events)) 226 227 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 228 conn, err := upgrader.Upgrade(w, r, nil) 229 if err != nil { 230 return 231 } 232 defer conn.Close() 233 234 for _, ev := range events { 235 buf, _ := json.Marshal(ev) 236 conn.WriteMessage(websocket.TextMessage, buf) 237 time.Sleep(10 * time.Millisecond) 238 } 239 240 time.Sleep(100 * time.Millisecond) 241 conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 242 })) 243 defer server.Close() 244 245 wsURL := "ws://" + strings.TrimPrefix(server.URL, "http://") 246 247 ws, err := NewWebsocket(wsURL, func(ctx context.Context, ev *Event) error { 248 val, _ := callCounts.LoadOrStore(ev.ID, new(int)) 249 count := val.(*int) 250 *count++ 251 252 if ev.ID == 2 { 253 if *count == 1 { 254 wg.Done() 255 } 256 return NewNonRetryableError(errors.New("bad input, do not retry")) 257 } 258 259 wg.Done() 260 return nil 261 }, WithLogger(nil)) 262 require.NoError(err) 263 264 go ws.Run(ctx) 265 wg.Wait() 266 267 // event 1: should be called exactly once (success) 268 val1, _ := callCounts.Load(uint64(1)) 269 require.Equal(1, *val1.(*int), "event 1 should be processed once") 270 271 // event 2: should be called exactly once (non-retryable error, no retry) 272 val2, _ := callCounts.Load(uint64(2)) 273 require.Equal(1, *val2.(*int), "event 2 with NonRetryableError should not be retried") 274 275 // event 3: should be called exactly once (success, proving we moved on after non-retryable) 276 val3, _ := callCounts.Load(uint64(3)) 277 require.Equal(1, *val3.(*int), "event 3 should be processed after non-retryable error") 278}