An open source supporter broker powered by high-fives.
high-five.atprotofans.com/
1package websocket
2
import (
	"encoding/json"
	"log"
	"math/rand"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)
11
// Per-connection tuning for websocket clients.
const (
	// writeWait bounds how long a single frame write may take.
	writeWait = 10 * time.Second

	// pongWait is the read deadline; ReadPump's pong handler pushes it
	// forward every time a pong arrives.
	pongWait = 60 * time.Second

	// pingPeriod is how often WritePump pings the peer. It stays below
	// pongWait so a live peer's pong lands before the read deadline expires.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize is the read limit, in bytes, for one inbound message.
	maxMessageSize = 4096

	// sessionTimeout is how long a client may stay connected before the
	// server ends the session.
	sessionTimeout = 25 * time.Minute
)
19
// MessageType tags every websocket Message; it is serialized as the
// message's "type" field.
type MessageType string

// Message types a client sends to the server.
const (
	MsgTypeHighFive MessageType = "high_five"
)

// Message types the server pushes to clients.
const (
	MsgTypeIdentityJoined MessageType = "identity_joined"
	MsgTypeIdentityLeft   MessageType = "identity_left"
	MsgTypePopulateRoom   MessageType = "populate_room"
	MsgTypeHighFiveEvent  MessageType = "high_five_event"
	MsgTypeError          MessageType = "error"
	MsgTypeSessionTimeout MessageType = "session_timeout"
)
35
36// Message represents a websocket message.
37type Message struct {
38 Type MessageType `json:"type"`
39 Payload map[string]interface{} `json:"payload,omitempty"`
40}
41
42// ClientInfo contains public information about a connected client.
43type ClientInfo struct {
44 DID string `json:"did"`
45 Handle string `json:"handle"`
46}
47
48// Client represents a connected websocket client.
49type Client struct {
50 hub *Hub
51 conn *websocket.Conn
52 send chan []byte
53 DID string
54 Handle string
55 SessionID string
56 IsReconnect bool
57 ConnectedAt time.Time
58 sessionTimer *time.Timer
59}
60
61// Hub maintains the set of active clients and broadcasts messages.
62type Hub struct {
63 clients map[*Client]bool
64 didClients map[string]*Client // Map DID to client
65 broadcast chan []byte
66 register chan *Client
67 unregister chan *Client
68 mu sync.RWMutex
69
70 // Callbacks for handling messages
71 onClientJoin func(client *Client) error
72 onClientLeave func(client *Client) error
73 onHighFive func(giver *Client, subjectDID string) error
74}
75
76// NewHub creates a new Hub.
77func NewHub() *Hub {
78 return &Hub{
79 clients: make(map[*Client]bool),
80 didClients: make(map[string]*Client),
81 broadcast: make(chan []byte),
82 register: make(chan *Client),
83 unregister: make(chan *Client),
84 }
85}
86
87// SetCallbacks sets the message handler callbacks.
88func (h *Hub) SetCallbacks(
89 onClientJoin func(client *Client) error,
90 onClientLeave func(client *Client) error,
91 onHighFive func(giver *Client, subjectDID string) error,
92) {
93 h.onClientJoin = onClientJoin
94 h.onClientLeave = onClientLeave
95 h.onHighFive = onHighFive
96}
97
98// Run starts the hub's main loop.
99func (h *Hub) Run() {
100 for {
101 select {
102 case client := <-h.register:
103 h.mu.Lock()
104 h.clients[client] = true
105 if client.DID != "" {
106 h.didClients[client.DID] = client
107 }
108 h.mu.Unlock()
109 log.Printf("Client registered: %s (%s)", client.DID, client.Handle)
110
111 // Call the join callback
112 if h.onClientJoin != nil {
113 if err := h.onClientJoin(client); err != nil {
114 log.Printf("error in client join callback: %v", err)
115 }
116 }
117
118 case client := <-h.unregister:
119 h.mu.Lock()
120 if _, ok := h.clients[client]; ok {
121 delete(h.clients, client)
122 if client.DID != "" {
123 delete(h.didClients, client.DID)
124 }
125 close(client.send)
126 }
127 h.mu.Unlock()
128 log.Printf("Client unregistered: %s (%s)", client.DID, client.Handle)
129
130 // Call the leave callback
131 if h.onClientLeave != nil {
132 if err := h.onClientLeave(client); err != nil {
133 log.Printf("error in client leave callback: %v", err)
134 }
135 }
136
137 case message := <-h.broadcast:
138 h.mu.RLock()
139 for client := range h.clients {
140 select {
141 case client.send <- message:
142 default:
143 h.mu.RUnlock()
144 h.mu.Lock()
145 close(client.send)
146 delete(h.clients, client)
147 if client.DID != "" {
148 delete(h.didClients, client.DID)
149 }
150 h.mu.Unlock()
151 h.mu.RLock()
152 }
153 }
154 h.mu.RUnlock()
155 }
156 }
157}
158
159// GetClientByDID returns the client for a given DID.
160func (h *Hub) GetClientByDID(did string) *Client {
161 h.mu.RLock()
162 defer h.mu.RUnlock()
163 return h.didClients[did]
164}
165
166// IsClientConnected checks if a client with the given DID is connected.
167func (h *Hub) IsClientConnected(did string) bool {
168 h.mu.RLock()
169 defer h.mu.RUnlock()
170 _, ok := h.didClients[did]
171 return ok
172}
173
174// SendToClient sends a message to a specific client by DID.
175func (h *Hub) SendToClient(did string, msg *Message) error {
176 h.mu.RLock()
177 client, ok := h.didClients[did]
178 h.mu.RUnlock()
179
180 if !ok {
181 return nil // Client not connected, not an error
182 }
183
184 data, err := json.Marshal(msg)
185 if err != nil {
186 return err
187 }
188
189 select {
190 case client.send <- data:
191 return nil
192 default:
193 return nil // Channel full, skip
194 }
195}
196
197// Register registers a new client.
198func (h *Hub) Register(client *Client) {
199 h.register <- client
200}
201
202// Unregister unregisters a client.
203func (h *Hub) Unregister(client *Client) {
204 h.unregister <- client
205}
206
207// NewClient creates a new client.
208func NewClient(hub *Hub, conn *websocket.Conn, did, handle, sessionID string, isReconnect bool) *Client {
209 c := &Client{
210 hub: hub,
211 conn: conn,
212 send: make(chan []byte, 256),
213 DID: did,
214 Handle: handle,
215 SessionID: sessionID,
216 IsReconnect: isReconnect,
217 ConnectedAt: time.Now(),
218 }
219
220 // Start session timeout timer
221 c.sessionTimer = time.AfterFunc(sessionTimeout, func() {
222 c.handleSessionTimeout()
223 })
224
225 return c
226}
227
228// handleSessionTimeout sends a timeout message and closes the connection.
229func (c *Client) handleSessionTimeout() {
230 log.Printf("Session timeout for client: %s (%s)", c.DID, c.Handle)
231
232 // Send session timeout message
233 c.SendMessage(&Message{
234 Type: MsgTypeSessionTimeout,
235 Payload: map[string]interface{}{
236 "message": "Your session has expired after 25 minutes. Start a new session to continue giving and receiving high-fives!",
237 },
238 })
239
240 // Give client time to receive the message before closing
241 time.Sleep(100 * time.Millisecond)
242
243 // Close the connection
244 c.conn.Close()
245}
246
247// StopSessionTimer stops the session timeout timer.
248func (c *Client) StopSessionTimer() {
249 if c.sessionTimer != nil {
250 c.sessionTimer.Stop()
251 }
252}
253
254// ReadPump pumps messages from the websocket connection to the hub.
255func (c *Client) ReadPump() {
256 defer func() {
257 c.StopSessionTimer()
258 c.hub.Unregister(c)
259 c.conn.Close()
260 }()
261
262 c.conn.SetReadLimit(maxMessageSize)
263 c.conn.SetReadDeadline(time.Now().Add(pongWait))
264 c.conn.SetPongHandler(func(string) error {
265 c.conn.SetReadDeadline(time.Now().Add(pongWait))
266 return nil
267 })
268
269 for {
270 _, message, err := c.conn.ReadMessage()
271 if err != nil {
272 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
273 log.Printf("websocket error: %v", err)
274 }
275 break
276 }
277
278 var msg Message
279 if err := json.Unmarshal(message, &msg); err != nil {
280 log.Printf("failed to unmarshal message: %v", err)
281 continue
282 }
283
284 c.handleMessage(&msg)
285 }
286}
287
288// handleMessage processes incoming messages.
289func (c *Client) handleMessage(msg *Message) {
290 var err error
291
292 switch msg.Type {
293 case MsgTypeHighFive:
294 subjectDID, _ := msg.Payload["subject"].(string)
295 if c.hub.onHighFive != nil {
296 err = c.hub.onHighFive(c, subjectDID)
297 }
298
299 default:
300 log.Printf("unknown message type: %s", msg.Type)
301 }
302
303 if err != nil {
304 log.Printf("error handling message %s: %v", msg.Type, err)
305 c.SendError(err.Error())
306 }
307}
308
309// WritePump pumps messages from the hub to the websocket connection.
310func (c *Client) WritePump() {
311 ticker := time.NewTicker(pingPeriod)
312 defer func() {
313 ticker.Stop()
314 c.conn.Close()
315 }()
316
317 for {
318 select {
319 case message, ok := <-c.send:
320 c.conn.SetWriteDeadline(time.Now().Add(writeWait))
321 if !ok {
322 c.conn.WriteMessage(websocket.CloseMessage, []byte{})
323 return
324 }
325
326 w, err := c.conn.NextWriter(websocket.TextMessage)
327 if err != nil {
328 return
329 }
330 w.Write(message)
331
332 if err := w.Close(); err != nil {
333 return
334 }
335
336 case <-ticker.C:
337 c.conn.SetWriteDeadline(time.Now().Add(writeWait))
338 if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
339 return
340 }
341 }
342 }
343}
344
345// SendMessage sends a message to this client.
346func (c *Client) SendMessage(msg *Message) error {
347 data, err := json.Marshal(msg)
348 if err != nil {
349 return err
350 }
351
352 select {
353 case c.send <- data:
354 return nil
355 default:
356 return nil
357 }
358}
359
360// SendError sends an error message to this client.
361func (c *Client) SendError(errMsg string) {
362 c.SendMessage(&Message{
363 Type: MsgTypeError,
364 Payload: map[string]interface{}{
365 "error": errMsg,
366 },
367 })
368}
369
370// BroadcastToAll sends a message to all connected clients.
371func (h *Hub) BroadcastToAll(msg *Message) {
372 data, err := json.Marshal(msg)
373 if err != nil {
374 log.Printf("failed to marshal broadcast message: %v", err)
375 return
376 }
377
378 h.mu.RLock()
379 defer h.mu.RUnlock()
380
381 for client := range h.clients {
382 select {
383 case client.send <- data:
384 default:
385 // Channel full, skip this client
386 }
387 }
388}
389
390// BroadcastToOthers sends a message to all clients except the one with the given DID.
391func (h *Hub) BroadcastToOthers(excludeDID string, msg *Message) {
392 data, err := json.Marshal(msg)
393 if err != nil {
394 log.Printf("failed to marshal broadcast message: %v", err)
395 return
396 }
397
398 h.mu.RLock()
399 defer h.mu.RUnlock()
400
401 for client := range h.clients {
402 if client.DID == excludeDID {
403 continue
404 }
405 select {
406 case client.send <- data:
407 default:
408 // Channel full, skip this client
409 }
410 }
411}
412
413// GetAllClients returns information about all connected clients.
414func (h *Hub) GetAllClients() []ClientInfo {
415 h.mu.RLock()
416 defer h.mu.RUnlock()
417
418 clients := make([]ClientInfo, 0, len(h.clients))
419 for client := range h.clients {
420 clients = append(clients, ClientInfo{
421 DID: client.DID,
422 Handle: client.Handle,
423 })
424 }
425 return clients
426}
427
428// GetOtherClients returns information about all connected clients except the given DID.
429func (h *Hub) GetOtherClients(excludeDID string) []ClientInfo {
430 h.mu.RLock()
431 defer h.mu.RUnlock()
432
433 clients := make([]ClientInfo, 0, len(h.clients))
434 for client := range h.clients {
435 if client.DID == excludeDID {
436 continue
437 }
438 clients = append(clients, ClientInfo{
439 DID: client.DID,
440 Handle: client.Handle,
441 })
442 }
443 return clients
444}
445
446// GetClientCount returns the number of connected clients.
447func (h *Hub) GetClientCount() int {
448 h.mu.RLock()
449 defer h.mu.RUnlock()
450 return len(h.clients)
451}
452
453// StartRoomPopulation sends populate_room messages to a new client in batches.
454func (h *Hub) StartRoomPopulation(newClient *Client) {
455 go func() {
456 // Get all other clients
457 clients := h.GetOtherClients(newClient.DID)
458
459 // Shuffle for random order
460 for i := len(clients) - 1; i > 0; i-- {
461 j := int(time.Now().UnixNano()) % (i + 1)
462 clients[i], clients[j] = clients[j], clients[i]
463 }
464
465 // Send in batches of 5 every 3 seconds
466 batchSize := 5
467 for i := 0; i < len(clients); i += batchSize {
468 end := i + batchSize
469 if end > len(clients) {
470 end = len(clients)
471 }
472 batch := clients[i:end]
473
474 // Convert to interface slice for payload
475 identities := make([]map[string]string, len(batch))
476 for j, c := range batch {
477 identities[j] = map[string]string{
478 "did": c.DID,
479 "handle": c.Handle,
480 }
481 }
482
483 newClient.SendMessage(&Message{
484 Type: MsgTypePopulateRoom,
485 Payload: map[string]interface{}{
486 "identities": identities,
487 },
488 })
489
490 if end < len(clients) {
491 time.Sleep(3 * time.Second)
492 }
493 }
494 }()
495}