An open source supporter broker powered by high-fives. high-five.atprotofans.com/
at main 495 lines 11 kB view raw
1package websocket 2 3import ( 4 "encoding/json" 5 "log" 6 "sync" 7 "time" 8 9 "github.com/gorilla/websocket" 10) 11 12const ( 13 writeWait = 10 * time.Second 14 pongWait = 60 * time.Second 15 pingPeriod = (pongWait * 9) / 10 16 maxMessageSize = 4096 17 sessionTimeout = 25 * time.Minute // Disconnect after 25 minutes 18) 19 20// MessageType represents the type of websocket message. 21type MessageType string 22 23const ( 24 // Client -> Server messages 25 MsgTypeHighFive MessageType = "high_five" 26 27 // Server -> Client messages 28 MsgTypeIdentityJoined MessageType = "identity_joined" 29 MsgTypeIdentityLeft MessageType = "identity_left" 30 MsgTypePopulateRoom MessageType = "populate_room" 31 MsgTypeHighFiveEvent MessageType = "high_five_event" 32 MsgTypeError MessageType = "error" 33 MsgTypeSessionTimeout MessageType = "session_timeout" 34) 35 36// Message represents a websocket message. 37type Message struct { 38 Type MessageType `json:"type"` 39 Payload map[string]interface{} `json:"payload,omitempty"` 40} 41 42// ClientInfo contains public information about a connected client. 43type ClientInfo struct { 44 DID string `json:"did"` 45 Handle string `json:"handle"` 46} 47 48// Client represents a connected websocket client. 49type Client struct { 50 hub *Hub 51 conn *websocket.Conn 52 send chan []byte 53 DID string 54 Handle string 55 SessionID string 56 IsReconnect bool 57 ConnectedAt time.Time 58 sessionTimer *time.Timer 59} 60 61// Hub maintains the set of active clients and broadcasts messages. 62type Hub struct { 63 clients map[*Client]bool 64 didClients map[string]*Client // Map DID to client 65 broadcast chan []byte 66 register chan *Client 67 unregister chan *Client 68 mu sync.RWMutex 69 70 // Callbacks for handling messages 71 onClientJoin func(client *Client) error 72 onClientLeave func(client *Client) error 73 onHighFive func(giver *Client, subjectDID string) error 74} 75 76// NewHub creates a new Hub. 77func NewHub() *Hub { 78 return &Hub{ 79 clients: make(map[*Client]bool), 80 didClients: make(map[string]*Client), 81 broadcast: make(chan []byte), 82 register: make(chan *Client), 83 unregister: make(chan *Client), 84 } 85} 86 87// SetCallbacks sets the message handler callbacks. 88func (h *Hub) SetCallbacks( 89 onClientJoin func(client *Client) error, 90 onClientLeave func(client *Client) error, 91 onHighFive func(giver *Client, subjectDID string) error, 92) { 93 h.onClientJoin = onClientJoin 94 h.onClientLeave = onClientLeave 95 h.onHighFive = onHighFive 96} 97 98// Run starts the hub's main loop. 99func (h *Hub) Run() { 100 for { 101 select { 102 case client := <-h.register: 103 h.mu.Lock() 104 h.clients[client] = true 105 if client.DID != "" { 106 h.didClients[client.DID] = client 107 } 108 h.mu.Unlock() 109 log.Printf("Client registered: %s (%s)", client.DID, client.Handle) 110 111 // Call the join callback 112 if h.onClientJoin != nil { 113 if err := h.onClientJoin(client); err != nil { 114 log.Printf("error in client join callback: %v", err) 115 } 116 } 117 118 case client := <-h.unregister: 119 h.mu.Lock() 120 if _, ok := h.clients[client]; ok { 121 delete(h.clients, client) 122 if client.DID != "" { 123 delete(h.didClients, client.DID) 124 } 125 close(client.send) 126 } 127 h.mu.Unlock() 128 log.Printf("Client unregistered: %s (%s)", client.DID, client.Handle) 129 130 // Call the leave callback 131 if h.onClientLeave != nil { 132 if err := h.onClientLeave(client); err != nil { 133 log.Printf("error in client leave callback: %v", err) 134 } 135 } 136 137 case message := <-h.broadcast: 138 h.mu.RLock() 139 for client := range h.clients { 140 select { 141 case client.send <- message: 142 default: 143 h.mu.RUnlock() 144 h.mu.Lock() 145 close(client.send) 146 delete(h.clients, client) 147 if client.DID != "" { 148 delete(h.didClients, client.DID) 149 } 150 h.mu.Unlock() 151 h.mu.RLock() 152 } 153 } 154 h.mu.RUnlock() 155 } 156 } 157} 158 159// GetClientByDID returns the client for a given DID. 160func (h *Hub) GetClientByDID(did string) *Client { 161 h.mu.RLock() 162 defer h.mu.RUnlock() 163 return h.didClients[did] 164} 165 166// IsClientConnected checks if a client with the given DID is connected. 167func (h *Hub) IsClientConnected(did string) bool { 168 h.mu.RLock() 169 defer h.mu.RUnlock() 170 _, ok := h.didClients[did] 171 return ok 172} 173 174// SendToClient sends a message to a specific client by DID. 175func (h *Hub) SendToClient(did string, msg *Message) error { 176 h.mu.RLock() 177 client, ok := h.didClients[did] 178 h.mu.RUnlock() 179 180 if !ok { 181 return nil // Client not connected, not an error 182 } 183 184 data, err := json.Marshal(msg) 185 if err != nil { 186 return err 187 } 188 189 select { 190 case client.send <- data: 191 return nil 192 default: 193 return nil // Channel full, skip 194 } 195} 196 197// Register registers a new client. 198func (h *Hub) Register(client *Client) { 199 h.register <- client 200} 201 202// Unregister unregisters a client. 203func (h *Hub) Unregister(client *Client) { 204 h.unregister <- client 205} 206 207// NewClient creates a new client. 208func NewClient(hub *Hub, conn *websocket.Conn, did, handle, sessionID string, isReconnect bool) *Client { 209 c := &Client{ 210 hub: hub, 211 conn: conn, 212 send: make(chan []byte, 256), 213 DID: did, 214 Handle: handle, 215 SessionID: sessionID, 216 IsReconnect: isReconnect, 217 ConnectedAt: time.Now(), 218 } 219 220 // Start session timeout timer 221 c.sessionTimer = time.AfterFunc(sessionTimeout, func() { 222 c.handleSessionTimeout() 223 }) 224 225 return c 226} 227 228// handleSessionTimeout sends a timeout message and closes the connection. 229func (c *Client) handleSessionTimeout() { 230 log.Printf("Session timeout for client: %s (%s)", c.DID, c.Handle) 231 232 // Send session timeout message 233 c.SendMessage(&Message{ 234 Type: MsgTypeSessionTimeout, 235 Payload: map[string]interface{}{ 236 "message": "Your session has expired after 25 minutes. Start a new session to continue giving and receiving high-fives!", 237 }, 238 }) 239 240 // Give client time to receive the message before closing 241 time.Sleep(100 * time.Millisecond) 242 243 // Close the connection 244 c.conn.Close() 245} 246 247// StopSessionTimer stops the session timeout timer. 248func (c *Client) StopSessionTimer() { 249 if c.sessionTimer != nil { 250 c.sessionTimer.Stop() 251 } 252} 253 254// ReadPump pumps messages from the websocket connection to the hub. 255func (c *Client) ReadPump() { 256 defer func() { 257 c.StopSessionTimer() 258 c.hub.Unregister(c) 259 c.conn.Close() 260 }() 261 262 c.conn.SetReadLimit(maxMessageSize) 263 c.conn.SetReadDeadline(time.Now().Add(pongWait)) 264 c.conn.SetPongHandler(func(string) error { 265 c.conn.SetReadDeadline(time.Now().Add(pongWait)) 266 return nil 267 }) 268 269 for { 270 _, message, err := c.conn.ReadMessage() 271 if err != nil { 272 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 273 log.Printf("websocket error: %v", err) 274 } 275 break 276 } 277 278 var msg Message 279 if err := json.Unmarshal(message, &msg); err != nil { 280 log.Printf("failed to unmarshal message: %v", err) 281 continue 282 } 283 284 c.handleMessage(&msg) 285 } 286} 287 288// handleMessage processes incoming messages. 289func (c *Client) handleMessage(msg *Message) { 290 var err error 291 292 switch msg.Type { 293 case MsgTypeHighFive: 294 subjectDID, _ := msg.Payload["subject"].(string) 295 if c.hub.onHighFive != nil { 296 err = c.hub.onHighFive(c, subjectDID) 297 } 298 299 default: 300 log.Printf("unknown message type: %s", msg.Type) 301 } 302 303 if err != nil { 304 log.Printf("error handling message %s: %v", msg.Type, err) 305 c.SendError(err.Error()) 306 } 307} 308 309// WritePump pumps messages from the hub to the websocket connection. 310func (c *Client) WritePump() { 311 ticker := time.NewTicker(pingPeriod) 312 defer func() { 313 ticker.Stop() 314 c.conn.Close() 315 }() 316 317 for { 318 select { 319 case message, ok := <-c.send: 320 c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 321 if !ok { 322 c.conn.WriteMessage(websocket.CloseMessage, []byte{}) 323 return 324 } 325 326 w, err := c.conn.NextWriter(websocket.TextMessage) 327 if err != nil { 328 return 329 } 330 w.Write(message) 331 332 if err := w.Close(); err != nil { 333 return 334 } 335 336 case <-ticker.C: 337 c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 338 if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 339 return 340 } 341 } 342 } 343} 344 345// SendMessage sends a message to this client. 346func (c *Client) SendMessage(msg *Message) error { 347 data, err := json.Marshal(msg) 348 if err != nil { 349 return err 350 } 351 352 select { 353 case c.send <- data: 354 return nil 355 default: 356 return nil 357 } 358} 359 360// SendError sends an error message to this client. 361func (c *Client) SendError(errMsg string) { 362 c.SendMessage(&Message{ 363 Type: MsgTypeError, 364 Payload: map[string]interface{}{ 365 "error": errMsg, 366 }, 367 }) 368} 369 370// BroadcastToAll sends a message to all connected clients. 371func (h *Hub) BroadcastToAll(msg *Message) { 372 data, err := json.Marshal(msg) 373 if err != nil { 374 log.Printf("failed to marshal broadcast message: %v", err) 375 return 376 } 377 378 h.mu.RLock() 379 defer h.mu.RUnlock() 380 381 for client := range h.clients { 382 select { 383 case client.send <- data: 384 default: 385 // Channel full, skip this client 386 } 387 } 388} 389 390// BroadcastToOthers sends a message to all clients except the one with the given DID. 391func (h *Hub) BroadcastToOthers(excludeDID string, msg *Message) { 392 data, err := json.Marshal(msg) 393 if err != nil { 394 log.Printf("failed to marshal broadcast message: %v", err) 395 return 396 } 397 398 h.mu.RLock() 399 defer h.mu.RUnlock() 400 401 for client := range h.clients { 402 if client.DID == excludeDID { 403 continue 404 } 405 select { 406 case client.send <- data: 407 default: 408 // Channel full, skip this client 409 } 410 } 411} 412 413// GetAllClients returns information about all connected clients. 414func (h *Hub) GetAllClients() []ClientInfo { 415 h.mu.RLock() 416 defer h.mu.RUnlock() 417 418 clients := make([]ClientInfo, 0, len(h.clients)) 419 for client := range h.clients { 420 clients = append(clients, ClientInfo{ 421 DID: client.DID, 422 Handle: client.Handle, 423 }) 424 } 425 return clients 426} 427 428// GetOtherClients returns information about all connected clients except the given DID. 429func (h *Hub) GetOtherClients(excludeDID string) []ClientInfo { 430 h.mu.RLock() 431 defer h.mu.RUnlock() 432 433 clients := make([]ClientInfo, 0, len(h.clients)) 434 for client := range h.clients { 435 if client.DID == excludeDID { 436 continue 437 } 438 clients = append(clients, ClientInfo{ 439 DID: client.DID, 440 Handle: client.Handle, 441 }) 442 } 443 return clients 444} 445 446// GetClientCount returns the number of connected clients. 447func (h *Hub) GetClientCount() int { 448 h.mu.RLock() 449 defer h.mu.RUnlock() 450 return len(h.clients) 451} 452 453// StartRoomPopulation sends populate_room messages to a new client in batches. 454func (h *Hub) StartRoomPopulation(newClient *Client) { 455 go func() { 456 // Get all other clients 457 clients := h.GetOtherClients(newClient.DID) 458 459 // Shuffle for random order 460 for i := len(clients) - 1; i > 0; i-- { 461 j := int(time.Now().UnixNano()) % (i + 1) 462 clients[i], clients[j] = clients[j], clients[i] 463 } 464 465 // Send in batches of 5 every 3 seconds 466 batchSize := 5 467 for i := 0; i < len(clients); i += batchSize { 468 end := i + batchSize 469 if end > len(clients) { 470 end = len(clients) 471 } 472 batch := clients[i:end] 473 474 // Convert to interface slice for payload 475 identities := make([]map[string]string, len(batch)) 476 for j, c := range batch { 477 identities[j] = map[string]string{ 478 "did": c.DID, 479 "handle": c.Handle, 480 } 481 } 482 483 newClient.SendMessage(&Message{ 484 Type: MsgTypePopulateRoom, 485 Payload: map[string]interface{}{ 486 "identities": identities, 487 }, 488 }) 489 490 if end < len(clients) { 491 time.Sleep(3 * time.Second) 492 } 493 } 494 }() 495}