// Auto-reconnecting Jetstream proxy.
// (Captured from a source viewer: branch "mistress", 518 lines, 14 kB.)
1package main 2 3import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "log/slog" 8 "net/http" 9 "os" 10 "strings" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 "github.com/gorilla/websocket" 16) 17 18var DEFAULT_POOL = []string{ 19 "wss://jetstream1.us-east.bsky.network", 20 "wss://jetstream2.us-east.bsky.network", 21 "wss://jetstream1.us-west.bsky.network", 22 "wss://jetstream2.us-west.bsky.network", 23 "wss://jetstream.fire.hose.cam", 24 "wss://jetstream2.fr.hose.cam", 25 // want yours here? contact me 26} 27 28// Event represents a Jetstream event 29type Event struct { 30 Did string `json:"did"` 31 TimeUS int64 `json:"time_us"` 32 Kind string `json:"kind,omitempty"` 33 Commit *Commit `json:"commit,omitempty"` 34} 35 36// Commit represents a commit event 37type Commit struct { 38 Rev string `json:"rev,omitempty"` 39 Operation string `json:"operation,omitempty"` 40 Collection string `json:"collection,omitempty"` 41 RKey string `json:"rkey,omitempty"` 42 Record json.RawMessage `json:"record,omitempty"` 43 CID string `json:"cid,omitempty"` 44} 45 46// Message wraps a Jetstream event with both parsed and raw forms 47type Message struct { 48 Event *Event 49 Raw []byte 50} 51 52// Broadcaster manages subscribers to Jetstream events 53type Broadcaster struct { 54 listeners []chan *Message 55 mu sync.Mutex 56 connected atomic.Bool 57 lastMessageTime atomic.Int64 // Unix timestamp in seconds 58} 59 60// Subscribe returns a new channel that will receive Jetstream events 61func (b *Broadcaster) Subscribe() chan *Message { 62 b.mu.Lock() 63 defer b.mu.Unlock() 64 65 // firehose can be more-than-1k events per second, 66 // prefer to create a large buffer for the subscribers 67 ch := make(chan *Message, 10000) 68 b.listeners = append(b.listeners, ch) 69 return ch 70} 71 72func (b *Broadcaster) Unsubscribe(ch chan *Message) { 73 b.mu.Lock() 74 defer b.mu.Unlock() 75 76 for i, listener := range b.listeners { 77 if listener == ch { 78 b.listeners = append(b.listeners[:i], 
b.listeners[i+1:]...) 79 close(ch) 80 break 81 } 82 } 83} 84 85func (b *Broadcaster) Broadcast(rawMessage []byte) { 86 b.lastMessageTime.Store(time.Now().Unix()) 87 88 // Parse the event once 89 var event Event 90 if err := json.Unmarshal(rawMessage, &event); err != nil { 91 slog.Debug("Failed to parse event", slog.Any("error", err)) 92 // Broadcast anyway with nil event 93 } 94 95 msg := &Message{ 96 Event: &event, 97 Raw: rawMessage, 98 } 99 100 b.mu.Lock() 101 defer b.mu.Unlock() 102 103 for _, ch := range b.listeners { 104 select { 105 case ch <- msg: 106 // event sent successfully. we don't want to block 107 default: 108 // channel full, skip to avoid blocking 109 slog.Warn("jetstream broadcast: channel full, dropping event") 110 } 111 } 112} 113 114type latencyResult struct { 115 url string 116 latency time.Duration 117 err error 118} 119 120func measureLatency(url string) (time.Duration, error) { 121 url = strings.Replace(url, "wss://", "https://", 1) 122 // also support non-tls upstreams 123 url = strings.Replace(url, "ws://", "http://", 1) 124 125 client := &http.Client{ 126 Timeout: 20 * time.Second, 127 } 128 129 start := time.Now() 130 // jetstream instances return the "Welcome to jetstream!" 
banner on / which 131 // should be useful enough for latency 132 resp, err := client.Get(url) 133 if err != nil { 134 return 0, err 135 } 136 defer resp.Body.Close() 137 138 return time.Since(start), nil 139} 140 141var upgrader = websocket.Upgrader{ 142 CheckOrigin: func(r *http.Request) bool { 143 return true // Allow all origins 144 }, 145} 146 147// handleHealth returns 200 if connected to upstream 148func handleHealth(broadcaster *Broadcaster) http.HandlerFunc { 149 return func(w http.ResponseWriter, r *http.Request) { 150 if broadcaster.connected.Load() { 151 w.WriteHeader(http.StatusOK) 152 w.Write([]byte("Welcome to jetstream!")) 153 } else { 154 w.WriteHeader(http.StatusServiceUnavailable) 155 w.Write([]byte("Not connected to upstream")) 156 } 157 } 158} 159 160// matchesCollection checks if an event matches any of the wanted collections 161func matchesCollection(event *Event, wantedCollections []string) bool { 162 // Always pass through account and identity events 163 if event.Kind == "account" || event.Kind == "identity" { 164 return true 165 } 166 167 // If no wanted collections specified, pass everything 168 if len(wantedCollections) == 0 { 169 return true 170 } 171 172 // For commit events, check the collection 173 if event.Commit == nil { 174 return false 175 } 176 177 collection := event.Commit.Collection 178 for _, wanted := range wantedCollections { 179 // Support wildcard matching like "app.bsky.graph.*" 180 if strings.HasSuffix(wanted, ".*") { 181 prefix := strings.TrimSuffix(wanted, ".*") 182 if strings.HasPrefix(collection, prefix+".") || collection == prefix { 183 return true 184 } 185 } else if collection == wanted { 186 return true 187 } 188 } 189 190 return false 191} 192 193// handleSubscribe upgrades HTTP connection to websocket and streams events 194func handleSubscribe(broadcaster *Broadcaster) http.HandlerFunc { 195 return func(w http.ResponseWriter, r *http.Request) { 196 conn, err := upgrader.Upgrade(w, r, nil) 197 if err != nil { 
198 slog.Error("Failed to upgrade connection", slog.Any("error", err)) 199 return 200 } 201 defer conn.Close() 202 203 // Parse wantedCollections from query params 204 wantedCollections := r.URL.Query()["wantedCollections"] 205 if len(wantedCollections) > 100 { 206 slog.Warn("Client requested too many collections, limiting to 100", slog.Int("requested", len(wantedCollections))) 207 wantedCollections = wantedCollections[:100] 208 } 209 210 // Subscribe to broadcaster 211 ch := broadcaster.Subscribe() 212 defer broadcaster.Unsubscribe(ch) 213 214 if len(wantedCollections) > 0 { 215 slog.Info("Client connected", slog.String("remote", r.RemoteAddr), slog.Any("wantedCollections", wantedCollections)) 216 } else { 217 slog.Info("Client connected", slog.String("remote", r.RemoteAddr)) 218 } 219 220 // Stream events to client 221 for msg := range ch { 222 // If filtering is enabled, check the event 223 if len(wantedCollections) > 0 && msg.Event != nil { 224 // Check if event matches wanted collections 225 if !matchesCollection(msg.Event, wantedCollections) { 226 continue 227 } 228 } 229 230 err := conn.WriteMessage(websocket.TextMessage, msg.Raw) 231 if err != nil { 232 slog.Debug("Client disconnected", slog.String("remote", r.RemoteAddr), slog.Any("error", err)) 233 break 234 } 235 } 236 237 slog.Info("Client disconnected", slog.String("remote", r.RemoteAddr)) 238 } 239} 240 241// raceUpstreams connects to all upstreams simultaneously and returns the first one to deliver a message 242func raceUpstreams(pool []string) (string, error) { 243 slog.Info("Racing upstreams to find fastest message delivery") 244 245 type result struct { 246 url string 247 duration time.Duration 248 } 249 250 results := make(chan result, len(pool)) 251 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 252 defer cancel() 253 254 for _, url := range pool { 255 go func(u string) { 256 start := time.Now() 257 conn, _, err := websocket.DefaultDialer.DialContext(ctx, 
u+"/subscribe", nil) 258 if err != nil { 259 slog.Debug("Failed to connect during race", slog.String("url", u), slog.Any("error", err)) 260 return 261 } 262 defer conn.Close() 263 264 // Wait for first message 265 _, _, err = conn.ReadMessage() 266 if err != nil { 267 slog.Debug("Failed to read message during race", slog.String("url", u), slog.Any("error", err)) 268 return 269 } 270 271 duration := time.Since(start) 272 select { 273 case results <- result{url: u, duration: duration}: 274 case <-ctx.Done(): 275 } 276 }(url) 277 } 278 279 select { 280 case res := <-results: 281 slog.Info("Race winner", slog.String("url", res.url), slog.Duration("time_to_first_message", res.duration)) 282 return res.url, nil 283 case <-ctx.Done(): 284 return "", fmt.Errorf("no upstream delivered a message within timeout") 285 } 286} 287 288// watchdog monitors message activity and triggers reconnection if stalled 289// this is useful if bsky relay is down, which means all the jetstreams will be 290// reachable but not send any messages. 
we should swap to alternate relay infrastructure 291// in this case (via raceUpstreams) 292func watchdog(broadcaster *Broadcaster, trigger chan struct{}) { 293 ticker := time.NewTicker(5 * time.Second) 294 defer ticker.Stop() 295 296 for range ticker.C { 297 if !broadcaster.connected.Load() { 298 continue 299 } 300 301 lastMsg := broadcaster.lastMessageTime.Load() 302 if lastMsg == 0 { 303 // No messages received yet 304 continue 305 } 306 307 timeSinceLastMsg := time.Since(time.Unix(lastMsg, 0)) 308 if timeSinceLastMsg > 20*time.Second { 309 slog.Warn("No messages received", slog.Duration("duration", timeSinceLastMsg)) 310 select { 311 case trigger <- struct{}{}: 312 // Trigger sent 313 default: 314 // Trigger already pending 315 } 316 } 317 } 318} 319 320// connectToUpstream maintains a connection to the upstream websocket and broadcasts messages 321func connectToUpstream(pool []string, broadcaster *Broadcaster, watchdogTrigger <-chan struct{}) { 322 backoff := 50 * time.Millisecond 323 maxBackoff := 20 * time.Second 324 var currentUpstream string 325 var raceTriggered bool 326 327 for { 328 if raceTriggered { 329 slog.Info("Watchdog triggered, racing upstreams") 330 // Watchdog triggered - race all upstreams 331 bestUpstream, err := raceUpstreams(pool) 332 if err != nil { 333 slog.Error("Failed to race upstreams", slog.Any("error", err)) 334 time.Sleep(backoff) 335 backoff *= 2 336 if backoff > maxBackoff { 337 backoff = maxBackoff 338 } 339 continue 340 } 341 currentUpstream = bestUpstream 342 backoff = 50 * time.Millisecond // Reset backoff 343 raceTriggered = false 344 } else { 345 // Find best upstream (re-evaluate on each connection attempt) 346 bestUpstream, err := findBestUpstream(pool) 347 if err != nil { 348 slog.Error("Failed to find best upstream", slog.Any("error", err)) 349 time.Sleep(backoff) 350 backoff *= 2 351 if backoff > maxBackoff { 352 backoff = maxBackoff 353 } 354 continue 355 } 356 357 if bestUpstream != currentUpstream { 358 
slog.Info("Switching to new upstream", slog.String("url", bestUpstream)) 359 currentUpstream = bestUpstream 360 } 361 } 362 363 slog.Info("Connecting to upstream", slog.String("url", currentUpstream)) 364 365 conn, _, err := websocket.DefaultDialer.Dial(currentUpstream+"/subscribe", nil) 366 if err != nil { 367 slog.Error("Failed to connect to upstream", slog.String("url", currentUpstream), slog.Any("error", err)) 368 broadcaster.connected.Store(false) 369 time.Sleep(backoff) 370 backoff *= 2 371 if backoff > maxBackoff { 372 backoff = maxBackoff 373 } 374 continue 375 } 376 377 slog.Info("Connected to upstream", slog.String("url", currentUpstream)) 378 broadcaster.connected.Store(true) 379 broadcaster.lastMessageTime.Store(time.Now().Unix()) 380 backoff = 50 * time.Millisecond // Reset backoff on successful connection 381 382 // Read messages from upstream and broadcast them 383 readDone := make(chan struct{}) 384 go func() { 385 defer close(readDone) 386 for { 387 messageType, message, err := conn.ReadMessage() 388 if err != nil { 389 slog.Error("Error reading from upstream", slog.Any("error", err)) 390 return 391 } 392 393 // Only broadcast text/binary messages 394 if messageType == websocket.TextMessage || messageType == websocket.BinaryMessage { 395 broadcaster.Broadcast(message) 396 } 397 } 398 }() 399 400 // Wait for either read error or watchdog trigger 401 select { 402 case <-readDone: 403 // Normal disconnection 404 case <-watchdogTrigger: 405 // Watchdog triggered disconnection 406 slog.Info("Watchdog triggered disconnection") 407 raceTriggered = true 408 conn.Close() 409 <-readDone // Wait for read goroutine to finish 410 } 411 412 broadcaster.connected.Store(false) 413 414 if !raceTriggered { 415 // Connection lost, will re-evaluate best upstream on next iteration 416 slog.Info("Connection lost, finding new upstream", slog.Duration("backoff", backoff)) 417 time.Sleep(backoff) 418 backoff *= 2 419 if backoff > maxBackoff { 420 backoff = maxBackoff 421 } 
422 } 423 } 424} 425 426func findBestUpstream(pool []string) (string, error) { 427 428 // Measure latency concurrently 429 results := make(chan latencyResult, len(pool)) 430 var wg sync.WaitGroup 431 432 for _, url := range pool { 433 wg.Add(1) 434 go func(u string) { 435 defer wg.Done() 436 latency, err := measureLatency(u) 437 results <- latencyResult{url: u, latency: latency, err: err} 438 }(url) 439 } 440 441 wg.Wait() 442 close(results) 443 444 // Find the best connection 445 var best latencyResult 446 best.latency = time.Hour // Start with a very high latency 447 448 slog.Debug("Latency results:") 449 var err error 450 for result := range results { 451 if result.err != nil { 452 slog.Debug("connection error", slog.String("url", result.url), slog.Any("error", result.err)) 453 continue 454 } 455 slog.Debug("latency measured", slog.String("url", result.url), slog.Duration("latency", result.latency)) 456 457 if result.latency < best.latency { 458 best = result 459 } 460 } 461 462 if best.err == nil && best.latency < time.Hour { 463 slog.Debug("Best connection", slog.String("url", best.url), slog.Duration("latency", best.latency)) 464 return best.url, nil 465 } else { 466 slog.Debug("No valid connections found") 467 return "", err 468 } 469 470} 471 472func main() { 473 maybePool := os.Getenv("POOL") 474 pool := DEFAULT_POOL 475 if maybePool != "" { 476 pool = strings.Split(maybePool, ",") 477 } 478 479 if os.Getenv("DEBUG") == "1" { 480 slog.SetLogLoggerLevel(slog.LevelDebug) 481 } 482 483 envPort := os.Getenv("PORT") 484 port := envPort 485 if envPort == "" { 486 port = "8096" 487 } 488 489 envHost := os.Getenv("HOST") 490 host := envHost 491 if envHost == "" { 492 // should be running on the same hardware as your service 493 host = "127.0.0.1" 494 } 495 496 bindAddr := fmt.Sprintf("%s:%s", host, port) 497 498 // Create broadcaster and start upstream connection 499 // connectToUpstream will continuously find the best upstream and reconnect on failures 500 
broadcaster := &Broadcaster{} 501 watchdogTrigger := make(chan struct{}, 1) 502 503 go watchdog(broadcaster, watchdogTrigger) 504 go connectToUpstream(pool, broadcaster, watchdogTrigger) 505 506 // Setup HTTP server 507 http.HandleFunc("/", handleHealth(broadcaster)) 508 http.HandleFunc("/subscribe", handleSubscribe(broadcaster)) 509 510 slog.Info("Starting proxy server", slog.String("bind", bindAddr)) 511 if err := http.ListenAndServe(bindAddr, nil); err != nil { 512 slog.Error("Server failed", slog.Any("error", err)) 513 panic(err) 514 } 515 516 // TODO (future) let zlib compression be env'd 517 // TODO: the proxy subscribes to all lexicons, but then filters out at client level. add env var for lex filtering too 518}