// Package websocket implements a hub of websocket clients: it tracks the
// connected clients, relays JSON messages to them, and enforces a
// per-connection session timeout.
package websocket

import (
	"encoding/json"
	"log"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

const (
	// writeWait is the deadline applied to every write on the connection.
	writeWait = 10 * time.Second
	// pongWait is the read deadline; it is pushed forward on each pong.
	pongWait = 60 * time.Second
	// pingPeriod is the interval between pings. It is 90% of pongWait so a
	// ping is always sent before the read deadline expires.
	pingPeriod = (pongWait * 9) / 10
	// maxMessageSize is the read limit set on the connection, in bytes.
	maxMessageSize = 4096
	sessionTimeout = 25 * time.Minute // Disconnect after 25 minutes
)

// MessageType represents the type of websocket message.
type MessageType string

const (
	// Client -> Server messages
	MsgTypeHighFive MessageType = "high_five"

	// Server -> Client messages
	MsgTypeIdentityJoined MessageType = "identity_joined"
	MsgTypeIdentityLeft   MessageType = "identity_left"
	MsgTypePopulateRoom   MessageType = "populate_room"
	MsgTypeHighFiveEvent  MessageType = "high_five_event"
	MsgTypeError          MessageType = "error"
	MsgTypeSessionTimeout MessageType = "session_timeout"
)

// Message represents a websocket message. It is exchanged as JSON; Payload
// is omitted from the encoding when empty.
type Message struct {
	Type    MessageType            `json:"type"`
	Payload map[string]interface{} `json:"payload,omitempty"`
}

// ClientInfo contains public information about a connected client.
type ClientInfo struct {
	DID    string `json:"did"`
	Handle string `json:"handle"`
}

// Client represents a connected websocket client.
//
// send is the outbound queue: messages are enqueued on it with non-blocking
// sends and written to the connection by WritePump. sessionTimer is started
// by NewClient and fires handleSessionTimeout after sessionTimeout.
type Client struct {
	hub          *Hub
	conn         *websocket.Conn
	send         chan []byte
	DID          string
	Handle       string
	SessionID    string
	IsReconnect  bool
	ConnectedAt  time.Time
	sessionTimer *time.Timer
}

// Hub maintains the set of active clients and broadcasts messages.
//
// clients and didClients are guarded by mu. They are modified by the Run
// loop, which receives work on the register, unregister and broadcast
// channels.
type Hub struct {
	clients    map[*Client]bool
	didClients map[string]*Client // Map DID to client
	broadcast  chan []byte
	register   chan *Client
	unregister chan *Client
	mu         sync.RWMutex

	// Callbacks for handling messages
	onClientJoin  func(client *Client) error
	onClientLeave func(client *Client) error
	onHighFive    func(giver *Client, subjectDID string) error
}

// NewHub creates a new Hub with empty client maps and unbuffered channels.
func NewHub() *Hub {
	return &Hub{
		clients:    make(map[*Client]bool),
		didClients: make(map[string]*Client),
		broadcast:  make(chan []byte),
		register:   make(chan *Client),
		unregister: make(chan *Client),
	}
}

// SetCallbacks sets the message handler callbacks. Any of them may be nil,
// in which case the corresponding event is not reported.
//
// The fields are written without synchronization, so call this before the
// hub starts running.
func (h *Hub) SetCallbacks(
	onClientJoin func(client *Client) error,
	onClientLeave func(client *Client) error,
	onHighFive func(giver *Client, subjectDID string) error,
) {
	h.onClientJoin = onClientJoin
	h.onClientLeave = onClientLeave
	h.onHighFive = onHighFive
}

// Run starts the hub's main loop. It never returns and is meant to run in
// its own goroutine. It is the only place where clients are added to or
// removed from the hub and where client send channels are closed.
//
// The join and leave callbacks are invoked on this goroutine with the hub
// lock released; callback errors are logged and otherwise ignored.
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mu.Lock()
			h.clients[client] = true
			if client.DID != "" {
				h.didClients[client.DID] = client
			}
			h.mu.Unlock()

			log.Printf("Client registered: %s (%s)", client.DID, client.Handle)

			// Call the join callback
			if h.onClientJoin != nil {
				if err := h.onClientJoin(client); err != nil {
					log.Printf("error in client join callback: %v", err)
				}
			}

		case client := <-h.unregister:
			h.mu.Lock()
			// The client may already have been evicted by the broadcast
			// branch; its channel must not be closed a second time.
			if _, ok := h.clients[client]; ok {
				h.dropClientLocked(client)
			}
			h.mu.Unlock()

			log.Printf("Client unregistered: %s (%s)", client.DID, client.Handle)

			// Call the leave callback
			if h.onClientLeave != nil {
				if err := h.onClientLeave(client); err != nil {
					log.Printf("error in client leave callback: %v", err)
				}
			}

		case message := <-h.broadcast:
			// Deliver under the read lock and only remember the clients
			// whose queue is full; they are evicted afterwards under the
			// write lock instead of swapping locks in the middle of the
			// map iteration.
			var stalled []*Client
			h.mu.RLock()
			for client := range h.clients {
				select {
				case client.send <- message:
				default:
					stalled = append(stalled, client)
				}
			}
			h.mu.RUnlock()

			if len(stalled) > 0 {
				h.mu.Lock()
				for _, client := range stalled {
					if _, ok := h.clients[client]; ok {
						h.dropClientLocked(client)
					}
				}
				h.mu.Unlock()
			}
		}
	}
}

// dropClientLocked removes client from the hub and closes its send channel.
// The caller must hold h.mu for writing and must have checked that client is
// still present in h.clients.
func (h *Hub) dropClientLocked(client *Client) {
	delete(h.clients, client)
	// Only clear the DID index when it still points at this client. When
	// the same DID connects again, the newer client overwrites the entry,
	// and removing the older connection must not make the newer one
	// unreachable by DID.
	if client.DID != "" && h.didClients[client.DID] == client {
		delete(h.didClients, client.DID)
	}
	close(client.send)
}

// GetClientByDID returns the client for a given DID, or nil if no client is
// registered under that DID.
func (h *Hub) GetClientByDID(did string) *Client {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.didClients[did]
}

// IsClientConnected checks if a client with the given DID is connected.
func (h *Hub) IsClientConnected(did string) bool { h.mu.RLock() defer h.mu.RUnlock() _, ok := h.didClients[did] return ok } // SendToClient sends a message to a specific client by DID. func (h *Hub) SendToClient(did string, msg *Message) error { h.mu.RLock() client, ok := h.didClients[did] h.mu.RUnlock() if !ok { return nil // Client not connected, not an error } data, err := json.Marshal(msg) if err != nil { return err } select { case client.send <- data: return nil default: return nil // Channel full, skip } } // Register registers a new client. func (h *Hub) Register(client *Client) { h.register <- client } // Unregister unregisters a client. func (h *Hub) Unregister(client *Client) { h.unregister <- client } // NewClient creates a new client. func NewClient(hub *Hub, conn *websocket.Conn, did, handle, sessionID string, isReconnect bool) *Client { c := &Client{ hub: hub, conn: conn, send: make(chan []byte, 256), DID: did, Handle: handle, SessionID: sessionID, IsReconnect: isReconnect, ConnectedAt: time.Now(), } // Start session timeout timer c.sessionTimer = time.AfterFunc(sessionTimeout, func() { c.handleSessionTimeout() }) return c } // handleSessionTimeout sends a timeout message and closes the connection. func (c *Client) handleSessionTimeout() { log.Printf("Session timeout for client: %s (%s)", c.DID, c.Handle) // Send session timeout message c.SendMessage(&Message{ Type: MsgTypeSessionTimeout, Payload: map[string]interface{}{ "message": "Your session has expired after 25 minutes. Start a new session to continue giving and receiving high-fives!", }, }) // Give client time to receive the message before closing time.Sleep(100 * time.Millisecond) // Close the connection c.conn.Close() } // StopSessionTimer stops the session timeout timer. func (c *Client) StopSessionTimer() { if c.sessionTimer != nil { c.sessionTimer.Stop() } } // ReadPump pumps messages from the websocket connection to the hub. 
func (c *Client) ReadPump() { defer func() { c.StopSessionTimer() c.hub.Unregister(c) c.conn.Close() }() c.conn.SetReadLimit(maxMessageSize) c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for { _, message, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("websocket error: %v", err) } break } var msg Message if err := json.Unmarshal(message, &msg); err != nil { log.Printf("failed to unmarshal message: %v", err) continue } c.handleMessage(&msg) } } // handleMessage processes incoming messages. func (c *Client) handleMessage(msg *Message) { var err error switch msg.Type { case MsgTypeHighFive: subjectDID, _ := msg.Payload["subject"].(string) if c.hub.onHighFive != nil { err = c.hub.onHighFive(c, subjectDID) } default: log.Printf("unknown message type: %s", msg.Type) } if err != nil { log.Printf("error handling message %s: %v", msg.Type, err) c.SendError(err.Error()) } } // WritePump pumps messages from the hub to the websocket connection. func (c *Client) WritePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() c.conn.Close() }() for { select { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } w, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { return } w.Write(message) if err := w.Close(); err != nil { return } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } // SendMessage sends a message to this client. 
func (c *Client) SendMessage(msg *Message) error { data, err := json.Marshal(msg) if err != nil { return err } select { case c.send <- data: return nil default: return nil } } // SendError sends an error message to this client. func (c *Client) SendError(errMsg string) { c.SendMessage(&Message{ Type: MsgTypeError, Payload: map[string]interface{}{ "error": errMsg, }, }) } // BroadcastToAll sends a message to all connected clients. func (h *Hub) BroadcastToAll(msg *Message) { data, err := json.Marshal(msg) if err != nil { log.Printf("failed to marshal broadcast message: %v", err) return } h.mu.RLock() defer h.mu.RUnlock() for client := range h.clients { select { case client.send <- data: default: // Channel full, skip this client } } } // BroadcastToOthers sends a message to all clients except the one with the given DID. func (h *Hub) BroadcastToOthers(excludeDID string, msg *Message) { data, err := json.Marshal(msg) if err != nil { log.Printf("failed to marshal broadcast message: %v", err) return } h.mu.RLock() defer h.mu.RUnlock() for client := range h.clients { if client.DID == excludeDID { continue } select { case client.send <- data: default: // Channel full, skip this client } } } // GetAllClients returns information about all connected clients. func (h *Hub) GetAllClients() []ClientInfo { h.mu.RLock() defer h.mu.RUnlock() clients := make([]ClientInfo, 0, len(h.clients)) for client := range h.clients { clients = append(clients, ClientInfo{ DID: client.DID, Handle: client.Handle, }) } return clients } // GetOtherClients returns information about all connected clients except the given DID. func (h *Hub) GetOtherClients(excludeDID string) []ClientInfo { h.mu.RLock() defer h.mu.RUnlock() clients := make([]ClientInfo, 0, len(h.clients)) for client := range h.clients { if client.DID == excludeDID { continue } clients = append(clients, ClientInfo{ DID: client.DID, Handle: client.Handle, }) } return clients } // GetClientCount returns the number of connected clients. 
func (h *Hub) GetClientCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } // StartRoomPopulation sends populate_room messages to a new client in batches. func (h *Hub) StartRoomPopulation(newClient *Client) { go func() { // Get all other clients clients := h.GetOtherClients(newClient.DID) // Shuffle for random order for i := len(clients) - 1; i > 0; i-- { j := int(time.Now().UnixNano()) % (i + 1) clients[i], clients[j] = clients[j], clients[i] } // Send in batches of 5 every 3 seconds batchSize := 5 for i := 0; i < len(clients); i += batchSize { end := i + batchSize if end > len(clients) { end = len(clients) } batch := clients[i:end] // Convert to interface slice for payload identities := make([]map[string]string, len(batch)) for j, c := range batch { identities[j] = map[string]string{ "did": c.DID, "handle": c.Handle, } } newClient.SendMessage(&Message{ Type: MsgTypePopulateRoom, Payload: map[string]interface{}{ "identities": identities, }, }) if end < len(clients) { time.Sleep(3 * time.Second) } } }() }