auto-reconnecting jetstream proxy
1package main
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "os"
10 "strings"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 "github.com/gorilla/websocket"
16)
17
// DEFAULT_POOL is the built-in list of Jetstream websocket endpoints. main
// uses it as the upstream pool unless the POOL environment variable is set.
var DEFAULT_POOL = []string{
	// bsky.network instances
	"wss://jetstream1.us-east.bsky.network",
	"wss://jetstream2.us-east.bsky.network",
	"wss://jetstream1.us-west.bsky.network",
	"wss://jetstream2.us-west.bsky.network",

	// hose.cam instances
	"wss://jetstream.fire.hose.cam",
	"wss://jetstream2.fr.hose.cam",

	// want yours here? contact me
}
27
28// Event represents a Jetstream event
29type Event struct {
30 Did string `json:"did"`
31 TimeUS int64 `json:"time_us"`
32 Kind string `json:"kind,omitempty"`
33 Commit *Commit `json:"commit,omitempty"`
34}
35
// Commit is the "commit" object nested inside an Event. Every field is
// optional in the JSON form.
type Commit struct {
	// Rev is decoded from the "rev" field.
	Rev string `json:"rev,omitempty"`

	// Operation is decoded from the "operation" field.
	Operation string `json:"operation,omitempty"`

	// Collection is decoded from the "collection" field; it is the value
	// matchesCollection compares against the client's wanted collections.
	Collection string `json:"collection,omitempty"`

	// RKey is decoded from the "rkey" field.
	RKey string `json:"rkey,omitempty"`

	// Record keeps the "record" field as undecoded JSON bytes.
	Record json.RawMessage `json:"record,omitempty"`

	// CID is decoded from the "cid" field.
	CID string `json:"cid,omitempty"`
}
45
46// Message wraps a Jetstream event with both parsed and raw forms
47type Message struct {
48 Event *Event
49 Raw []byte
50}
51
52// Broadcaster manages subscribers to Jetstream events
53type Broadcaster struct {
54 listeners []chan *Message
55 mu sync.Mutex
56 connected atomic.Bool
57 lastMessageTime atomic.Int64 // Unix timestamp in seconds
58}
59
60// Subscribe returns a new channel that will receive Jetstream events
61func (b *Broadcaster) Subscribe() chan *Message {
62 b.mu.Lock()
63 defer b.mu.Unlock()
64
65 // firehose can be more-than-1k events per second,
66 // prefer to create a large buffer for the subscribers
67 ch := make(chan *Message, 10000)
68 b.listeners = append(b.listeners, ch)
69 return ch
70}
71
72func (b *Broadcaster) Unsubscribe(ch chan *Message) {
73 b.mu.Lock()
74 defer b.mu.Unlock()
75
76 for i, listener := range b.listeners {
77 if listener == ch {
78 b.listeners = append(b.listeners[:i], b.listeners[i+1:]...)
79 close(ch)
80 break
81 }
82 }
83}
84
85func (b *Broadcaster) Broadcast(rawMessage []byte) {
86 b.lastMessageTime.Store(time.Now().Unix())
87
88 // Parse the event once
89 var event Event
90 if err := json.Unmarshal(rawMessage, &event); err != nil {
91 slog.Debug("Failed to parse event", slog.Any("error", err))
92 // Broadcast anyway with nil event
93 }
94
95 msg := &Message{
96 Event: &event,
97 Raw: rawMessage,
98 }
99
100 b.mu.Lock()
101 defer b.mu.Unlock()
102
103 for _, ch := range b.listeners {
104 select {
105 case ch <- msg:
106 // event sent successfully. we don't want to block
107 default:
108 // channel full, skip to avoid blocking
109 slog.Warn("jetstream broadcast: channel full, dropping event")
110 }
111 }
112}
113
// latencyResult is the outcome of probing a single upstream in
// findBestUpstream.
type latencyResult struct {
	url     string        // upstream that was probed
	latency time.Duration // value returned by measureLatency
	err     error         // non-nil when the probe failed
}
119
// measureLatency probes a websocket upstream over plain HTTP(S) and returns
// how long it took to get a response to GET /.
//
// url is a ws:// or wss:// address; the scheme is rewritten to http:// or
// https:// for the probe. An error is returned when the request fails or
// when the upstream answers with a 5xx status: a server reporting itself
// broken (for example this proxy's own health endpoint returning 503 while
// disconnected) must not be selected as the best upstream.
func measureLatency(url string) (time.Duration, error) {
	url = strings.Replace(url, "wss://", "https://", 1)
	// also support non-tls upstreams
	url = strings.Replace(url, "ws://", "http://", 1)

	client := &http.Client{
		Timeout: 20 * time.Second,
	}

	start := time.Now()
	// jetstream instances return the "Welcome to jetstream!" banner on / which
	// should be useful enough for latency
	resp, err := client.Get(url)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	// Take the timing before any further work so it reflects the round trip.
	elapsed := time.Since(start)

	if resp.StatusCode >= http.StatusInternalServerError {
		return 0, fmt.Errorf("upstream %s is unhealthy: %s", url, resp.Status)
	}

	return elapsed, nil
}
140
141var upgrader = websocket.Upgrader{
142 CheckOrigin: func(r *http.Request) bool {
143 return true // Allow all origins
144 },
145}
146
147// handleHealth returns 200 if connected to upstream
148func handleHealth(broadcaster *Broadcaster) http.HandlerFunc {
149 return func(w http.ResponseWriter, r *http.Request) {
150 if broadcaster.connected.Load() {
151 w.WriteHeader(http.StatusOK)
152 w.Write([]byte("Welcome to jetstream!"))
153 } else {
154 w.WriteHeader(http.StatusServiceUnavailable)
155 w.Write([]byte("Not connected to upstream"))
156 }
157 }
158}
159
160// matchesCollection checks if an event matches any of the wanted collections
161func matchesCollection(event *Event, wantedCollections []string) bool {
162 // Always pass through account and identity events
163 if event.Kind == "account" || event.Kind == "identity" {
164 return true
165 }
166
167 // If no wanted collections specified, pass everything
168 if len(wantedCollections) == 0 {
169 return true
170 }
171
172 // For commit events, check the collection
173 if event.Commit == nil {
174 return false
175 }
176
177 collection := event.Commit.Collection
178 for _, wanted := range wantedCollections {
179 // Support wildcard matching like "app.bsky.graph.*"
180 if strings.HasSuffix(wanted, ".*") {
181 prefix := strings.TrimSuffix(wanted, ".*")
182 if strings.HasPrefix(collection, prefix+".") || collection == prefix {
183 return true
184 }
185 } else if collection == wanted {
186 return true
187 }
188 }
189
190 return false
191}
192
193// handleSubscribe upgrades HTTP connection to websocket and streams events
194func handleSubscribe(broadcaster *Broadcaster) http.HandlerFunc {
195 return func(w http.ResponseWriter, r *http.Request) {
196 conn, err := upgrader.Upgrade(w, r, nil)
197 if err != nil {
198 slog.Error("Failed to upgrade connection", slog.Any("error", err))
199 return
200 }
201 defer conn.Close()
202
203 // Parse wantedCollections from query params
204 wantedCollections := r.URL.Query()["wantedCollections"]
205 if len(wantedCollections) > 100 {
206 slog.Warn("Client requested too many collections, limiting to 100", slog.Int("requested", len(wantedCollections)))
207 wantedCollections = wantedCollections[:100]
208 }
209
210 // Subscribe to broadcaster
211 ch := broadcaster.Subscribe()
212 defer broadcaster.Unsubscribe(ch)
213
214 if len(wantedCollections) > 0 {
215 slog.Info("Client connected", slog.String("remote", r.RemoteAddr), slog.Any("wantedCollections", wantedCollections))
216 } else {
217 slog.Info("Client connected", slog.String("remote", r.RemoteAddr))
218 }
219
220 // Stream events to client
221 for msg := range ch {
222 // If filtering is enabled, check the event
223 if len(wantedCollections) > 0 && msg.Event != nil {
224 // Check if event matches wanted collections
225 if !matchesCollection(msg.Event, wantedCollections) {
226 continue
227 }
228 }
229
230 err := conn.WriteMessage(websocket.TextMessage, msg.Raw)
231 if err != nil {
232 slog.Debug("Client disconnected", slog.String("remote", r.RemoteAddr), slog.Any("error", err))
233 break
234 }
235 }
236
237 slog.Info("Client disconnected", slog.String("remote", r.RemoteAddr))
238 }
239}
240
241// raceUpstreams connects to all upstreams simultaneously and returns the first one to deliver a message
242func raceUpstreams(pool []string) (string, error) {
243 slog.Info("Racing upstreams to find fastest message delivery")
244
245 type result struct {
246 url string
247 duration time.Duration
248 }
249
250 results := make(chan result, len(pool))
251 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
252 defer cancel()
253
254 for _, url := range pool {
255 go func(u string) {
256 start := time.Now()
257 conn, _, err := websocket.DefaultDialer.DialContext(ctx, u+"/subscribe", nil)
258 if err != nil {
259 slog.Debug("Failed to connect during race", slog.String("url", u), slog.Any("error", err))
260 return
261 }
262 defer conn.Close()
263
264 // Wait for first message
265 _, _, err = conn.ReadMessage()
266 if err != nil {
267 slog.Debug("Failed to read message during race", slog.String("url", u), slog.Any("error", err))
268 return
269 }
270
271 duration := time.Since(start)
272 select {
273 case results <- result{url: u, duration: duration}:
274 case <-ctx.Done():
275 }
276 }(url)
277 }
278
279 select {
280 case res := <-results:
281 slog.Info("Race winner", slog.String("url", res.url), slog.Duration("time_to_first_message", res.duration))
282 return res.url, nil
283 case <-ctx.Done():
284 return "", fmt.Errorf("no upstream delivered a message within timeout")
285 }
286}
287
288// watchdog monitors message activity and triggers reconnection if stalled
289// this is useful if bsky relay is down, which means all the jetstreams will be
290// reachable but not send any messages. we should swap to alternate relay infrastructure
291// in this case (via raceUpstreams)
292func watchdog(broadcaster *Broadcaster, trigger chan struct{}) {
293 ticker := time.NewTicker(5 * time.Second)
294 defer ticker.Stop()
295
296 for range ticker.C {
297 if !broadcaster.connected.Load() {
298 continue
299 }
300
301 lastMsg := broadcaster.lastMessageTime.Load()
302 if lastMsg == 0 {
303 // No messages received yet
304 continue
305 }
306
307 timeSinceLastMsg := time.Since(time.Unix(lastMsg, 0))
308 if timeSinceLastMsg > 20*time.Second {
309 slog.Warn("No messages received", slog.Duration("duration", timeSinceLastMsg))
310 select {
311 case trigger <- struct{}{}:
312 // Trigger sent
313 default:
314 // Trigger already pending
315 }
316 }
317 }
318}
319
320// connectToUpstream maintains a connection to the upstream websocket and broadcasts messages
321func connectToUpstream(pool []string, broadcaster *Broadcaster, watchdogTrigger <-chan struct{}) {
322 backoff := 50 * time.Millisecond
323 maxBackoff := 20 * time.Second
324 var currentUpstream string
325 var raceTriggered bool
326
327 for {
328 if raceTriggered {
329 slog.Info("Watchdog triggered, racing upstreams")
330 // Watchdog triggered - race all upstreams
331 bestUpstream, err := raceUpstreams(pool)
332 if err != nil {
333 slog.Error("Failed to race upstreams", slog.Any("error", err))
334 time.Sleep(backoff)
335 backoff *= 2
336 if backoff > maxBackoff {
337 backoff = maxBackoff
338 }
339 continue
340 }
341 currentUpstream = bestUpstream
342 backoff = 50 * time.Millisecond // Reset backoff
343 raceTriggered = false
344 } else {
345 // Find best upstream (re-evaluate on each connection attempt)
346 bestUpstream, err := findBestUpstream(pool)
347 if err != nil {
348 slog.Error("Failed to find best upstream", slog.Any("error", err))
349 time.Sleep(backoff)
350 backoff *= 2
351 if backoff > maxBackoff {
352 backoff = maxBackoff
353 }
354 continue
355 }
356
357 if bestUpstream != currentUpstream {
358 slog.Info("Switching to new upstream", slog.String("url", bestUpstream))
359 currentUpstream = bestUpstream
360 }
361 }
362
363 slog.Info("Connecting to upstream", slog.String("url", currentUpstream))
364
365 conn, _, err := websocket.DefaultDialer.Dial(currentUpstream+"/subscribe", nil)
366 if err != nil {
367 slog.Error("Failed to connect to upstream", slog.String("url", currentUpstream), slog.Any("error", err))
368 broadcaster.connected.Store(false)
369 time.Sleep(backoff)
370 backoff *= 2
371 if backoff > maxBackoff {
372 backoff = maxBackoff
373 }
374 continue
375 }
376
377 slog.Info("Connected to upstream", slog.String("url", currentUpstream))
378 broadcaster.connected.Store(true)
379 broadcaster.lastMessageTime.Store(time.Now().Unix())
380 backoff = 50 * time.Millisecond // Reset backoff on successful connection
381
382 // Read messages from upstream and broadcast them
383 readDone := make(chan struct{})
384 go func() {
385 defer close(readDone)
386 for {
387 messageType, message, err := conn.ReadMessage()
388 if err != nil {
389 slog.Error("Error reading from upstream", slog.Any("error", err))
390 return
391 }
392
393 // Only broadcast text/binary messages
394 if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage {
395 broadcaster.Broadcast(message)
396 }
397 }
398 }()
399
400 // Wait for either read error or watchdog trigger
401 select {
402 case <-readDone:
403 // Normal disconnection
404 case <-watchdogTrigger:
405 // Watchdog triggered disconnection
406 slog.Info("Watchdog triggered disconnection")
407 raceTriggered = true
408 conn.Close()
409 <-readDone // Wait for read goroutine to finish
410 }
411
412 broadcaster.connected.Store(false)
413
414 if !raceTriggered {
415 // Connection lost, will re-evaluate best upstream on next iteration
416 slog.Info("Connection lost, finding new upstream", slog.Duration("backoff", backoff))
417 time.Sleep(backoff)
418 backoff *= 2
419 if backoff > maxBackoff {
420 backoff = maxBackoff
421 }
422 }
423 }
424}
425
426func findBestUpstream(pool []string) (string, error) {
427
428 // Measure latency concurrently
429 results := make(chan latencyResult, len(pool))
430 var wg sync.WaitGroup
431
432 for _, url := range pool {
433 wg.Add(1)
434 go func(u string) {
435 defer wg.Done()
436 latency, err := measureLatency(u)
437 results <- latencyResult{url: u, latency: latency, err: err}
438 }(url)
439 }
440
441 wg.Wait()
442 close(results)
443
444 // Find the best connection
445 var best latencyResult
446 best.latency = time.Hour // Start with a very high latency
447
448 slog.Debug("Latency results:")
449 var err error
450 for result := range results {
451 if result.err != nil {
452 slog.Debug("connection error", slog.String("url", result.url), slog.Any("error", result.err))
453 continue
454 }
455 slog.Debug("latency measured", slog.String("url", result.url), slog.Duration("latency", result.latency))
456
457 if result.latency < best.latency {
458 best = result
459 }
460 }
461
462 if best.err == nil && best.latency < time.Hour {
463 slog.Debug("Best connection", slog.String("url", best.url), slog.Duration("latency", best.latency))
464 return best.url, nil
465 } else {
466 slog.Debug("No valid connections found")
467 return "", err
468 }
469
470}
471
472func main() {
473 maybePool := os.Getenv("POOL")
474 pool := DEFAULT_POOL
475 if maybePool != "" {
476 pool = strings.Split(maybePool, ",")
477 }
478
479 if os.Getenv("DEBUG") == "1" {
480 slog.SetLogLoggerLevel(slog.LevelDebug)
481 }
482
483 envPort := os.Getenv("PORT")
484 port := envPort
485 if envPort == "" {
486 port = "8096"
487 }
488
489 envHost := os.Getenv("HOST")
490 host := envHost
491 if envHost == "" {
492 // should be running on the same hardware as your service
493 host = "127.0.0.1"
494 }
495
496 bindAddr := fmt.Sprintf("%s:%s", host, port)
497
498 // Create broadcaster and start upstream connection
499 // connectToUpstream will continuously find the best upstream and reconnect on failures
500 broadcaster := &Broadcaster{}
501 watchdogTrigger := make(chan struct{}, 1)
502
503 go watchdog(broadcaster, watchdogTrigger)
504 go connectToUpstream(pool, broadcaster, watchdogTrigger)
505
506 // Setup HTTP server
507 http.HandleFunc("/", handleHealth(broadcaster))
508 http.HandleFunc("/subscribe", handleSubscribe(broadcaster))
509
510 slog.Info("Starting proxy server", slog.String("bind", bindAddr))
511 if err := http.ListenAndServe(bindAddr, nil); err != nil {
512 slog.Error("Server failed", slog.Any("error", err))
513 panic(err)
514 }
515
516 // TODO (future) let zlib compression be env'd
517 // TODO: the proxy subscribes to all lexicons, but then filters out at client level. add env var for lex filtering too
518}