// Based on lrcd at commit 57784a889eaf98dfc028c0a95bf7888cd32e6042 (604 lines, 13 kB).

// Package lrcd implements a websocket-based lrcproto server.
1package lrcd 2 3import ( 4 "context" 5 "errors" 6 "github.com/gorilla/websocket" 7 "github.com/rachel-mp4/lrcproto/gen/go" 8 "google.golang.org/protobuf/proto" 9 "log" 10 "net/http" 11 "sync" 12) 13 14type Server struct { 15 eventBus chan clientEvent 16 ctx context.Context 17 cancel context.CancelFunc 18 clients map[*client]bool 19 clientsMu sync.Mutex 20 idmapsMu sync.Mutex 21 clientToID map[*client]*uint32 22 idToClient map[uint32]*client 23 lastID uint32 24 logger *log.Logger 25 debugLogger *log.Logger 26 welcomeEvt []byte 27 pongEvt []byte 28 initChan chan lrcpb.Event_Init 29 pubChan chan PubEvent 30} 31 32type PubEvent struct { 33 ID uint32 34 Body string 35} 36 37type client struct { 38 conn *websocket.Conn 39 dataChan chan []byte 40 ctx context.Context 41 cancel context.CancelFunc 42 muteMap map[*client]bool 43 mutedBy map[*client]bool 44 myIDs []uint32 45 post *string 46 nick *string 47 externID *string 48 color *uint32 49} 50 51type clientEvent struct { 52 client *client 53 event *lrcpb.Event 54} 55 56func NewServer(opts ...Option) (*Server, error) { 57 var options options 58 for _, opt := range opts { 59 err := opt(&options) 60 if err != nil { 61 return nil, err 62 } 63 } 64 65 s := Server{} 66 67 welcomeString := "Welcome to my lrc server!" 
68 if options.welcome != nil { 69 welcomeString = *options.welcome 70 } 71 s.setDefaultEvents(welcomeString) 72 73 if options.writer != nil { 74 s.logger = log.New(*options.writer, "[log]", log.Ldate|log.Ltime) 75 if options.verbose { 76 s.debugLogger = log.New(*options.writer, "[debug]", log.Ldate|log.Ltime) 77 } 78 } 79 80 if options.initChan != nil { 81 s.initChan = options.initChan 82 } 83 if options.pubChan != nil { 84 s.pubChan = options.pubChan 85 } 86 87 s.clients = make(map[*client]bool) 88 s.clientsMu = sync.Mutex{} 89 s.idmapsMu = sync.Mutex{} 90 s.clientToID = make(map[*client]*uint32) 91 s.idToClient = make(map[uint32]*client) 92 s.lastID = 0 93 s.eventBus = make(chan clientEvent, 100) 94 return &s, nil 95} 96 97func (s *Server) setDefaultEvents(welcome string) { 98 evt := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Topic: &welcome}}} 99 we, _ := proto.Marshal(evt) 100 s.welcomeEvt = we 101 102 evt = &lrcpb.Event{Msg: &lrcpb.Event_Pong{Pong: &lrcpb.Pong{}}} 103 pe, _ := proto.Marshal(evt) 104 s.pongEvt = pe 105} 106 107func (s *Server) Start() error { 108 if s.ctx != nil { 109 return errors.New("cannot start already started server") 110 } 111 s.ctx, s.cancel = context.WithCancel(context.Background()) 112 go s.broadcaster() 113 s.logDebug("Hello, world!") 114 return nil 115} 116 117func (s *Server) Stop() error { 118 if s.ctx == nil { 119 return nil 120 } 121 select { 122 case <-s.ctx.Done(): 123 return errors.New("cannot stop already stopped server") 124 default: 125 s.cancel() 126 s.logDebug("Goodbye world :c") 127 return nil 128 } 129} 130 131func (s *Server) Connected() int { 132 return len(s.clients) 133} 134 135func (s *Server) StopIfEmpty() bool { 136 if len(s.clients) == 0 { 137 s.Stop() 138 return true 139 } 140 return false 141} 142 143func (s *Server) WSHandler() http.HandlerFunc { 144 return func(w http.ResponseWriter, r *http.Request) { 145 upgrader := &websocket.Upgrader{ 146 Subprotocols: []string{"lrcprotov1"}, 147 CheckOrigin: 
func(r *http.Request) bool { 148 return true 149 }, 150 } 151 conn, err := upgrader.Upgrade(w, r, nil) 152 if err != nil { 153 log.Println("Upgrade failed:", err) 154 return 155 } 156 defer conn.Close() 157 158 ctx, cancel := context.WithCancel(context.Background()) 159 client := &client{ 160 conn: conn, 161 dataChan: make(chan []byte, 100), 162 ctx: ctx, 163 cancel: cancel, 164 muteMap: make(map[*client]bool, 0), 165 mutedBy: make(map[*client]bool, 0), 166 myIDs: make([]uint32, 0), 167 } 168 169 s.clientsMu.Lock() 170 s.clients[client] = true 171 s.clientsMu.Unlock() 172 173 var wg sync.WaitGroup 174 wg.Add(2) 175 go func() { defer wg.Done(); s.wsWriter(client) }() 176 go func() { defer wg.Done(); s.listenToWS(client) }() 177 s.logDebug("new ws connection!") 178 wg.Wait() 179 180 s.clientsMu.Lock() 181 delete(s.clients, client) 182 close(client.dataChan) 183 s.clientsMu.Unlock() 184 s.idmapsMu.Lock() 185 for _, id := range client.myIDs { 186 delete(s.idToClient, id) 187 } 188 for mutedClient, _ := range client.muteMap { 189 delete(mutedClient.mutedBy, client) 190 } 191 for mutingClient, _ := range client.mutedBy { 192 delete(mutingClient.muteMap, client) 193 } 194 s.idmapsMu.Unlock() 195 conn.Close() 196 s.logDebug("closed ws connection") 197 } 198} 199 200func (s *Server) listenToWS(client *client) { 201 for { 202 select { 203 case <-client.ctx.Done(): 204 return 205 case <-s.ctx.Done(): 206 return 207 default: 208 _, data, err := client.conn.ReadMessage() 209 if err != nil { 210 client.cancel() 211 return 212 } 213 var event lrcpb.Event 214 err = proto.Unmarshal(data, &event) 215 if err != nil { 216 s.logDebug(err.Error()) 217 client.cancel() 218 return 219 } 220 s.eventBus <- clientEvent{client, &event} 221 } 222 } 223} 224 225func (s *Server) wsWriter(client *client) { 226 for { 227 select { 228 case <-client.ctx.Done(): 229 return 230 case <-s.ctx.Done(): 231 return 232 case data, ok := <-client.dataChan: 233 if !ok { 234 client.cancel() 235 return 236 } 237 
err := client.conn.WriteMessage(websocket.BinaryMessage, data) 238 if err != nil { 239 s.logDebug(err.Error()) 240 client.cancel() 241 return 242 } 243 } 244 } 245} 246 247// broadcaster takes an event from the events channel, and broadcasts it to all the connected clients individual event channels 248func (s *Server) broadcaster() { 249 for { 250 select { 251 case <-s.ctx.Done(): 252 return 253 case ce := <-s.eventBus: 254 client := ce.client 255 event := ce.event 256 switch msg := event.Msg.(type) { 257 case *lrcpb.Event_Ping: 258 client.dataChan <- s.pongEvt 259 case *lrcpb.Event_Pong: 260 continue 261 case *lrcpb.Event_Init: 262 s.handleInit(msg, client) 263 case *lrcpb.Event_Pub: 264 s.handlePub(msg, client) 265 case *lrcpb.Event_Insert: 266 s.handleInsert(msg, client) 267 case *lrcpb.Event_Delete: 268 s.handleDelete(msg, client) 269 case *lrcpb.Event_Mute: 270 s.handleMute(msg, client) 271 case *lrcpb.Event_Unmute: 272 s.handleUnmute(msg, client) 273 case *lrcpb.Event_Set: 274 s.handleSet(msg, client) 275 case *lrcpb.Event_Get: 276 s.handleGet(msg, client) 277 } 278 } 279 } 280} 281 282func (s *Server) handleInit(msg *lrcpb.Event_Init, client *client) { 283 curID := s.clientToID[client] 284 if curID != nil { 285 return 286 } 287 newID := s.lastID + 1 288 s.lastID = newID 289 s.idmapsMu.Lock() 290 s.clientToID[client] = &newID 291 s.idToClient[newID] = client 292 s.idmapsMu.Unlock() 293 client.myIDs = append(client.myIDs, newID) 294 newpost := "" 295 client.post = &newpost 296 msg.Init.Id = newID 297 nick := client.nick 298 if nick != nil { 299 msg.Init.Nick = *nick 300 } else { 301 msg.Init.Nick = "wanderer" 302 } 303 externID := client.externID 304 if externID != nil { 305 msg.Init.ExternalID = *externID 306 } else { 307 msg.Init.ExternalID = "" 308 } 309 color := client.color 310 if color != nil { 311 msg.Init.Color = *color 312 } else { 313 msg.Init.Color = 0xD90368 314 } 315 msg.Init.Echoed = false 316 317 if s.initChan != nil { 318 select { 319 case 
s.initChan <- *msg: 320 default: 321 s.log("initchan blocked, closing channel") 322 close(s.initChan) 323 s.initChan = nil 324 } 325 } 326 s.broadcastInit(msg, client) 327} 328 329func (s *Server) broadcastInit(msg *lrcpb.Event_Init, client *client) { 330 stdEvent := &lrcpb.Event{Msg: msg} 331 stdData, _ := proto.Marshal(stdEvent) 332 msg.Init.Echoed = true 333 echoEvent := &lrcpb.Event{Msg: msg} 334 echoData, _ := proto.Marshal(echoEvent) 335 muteEvent := &lrcpb.Event{Msg: &lrcpb.Event_Mute{Mute: &lrcpb.Mute{Id: msg.Init.GetId()}}} 336 muteData, _ := proto.Marshal(muteEvent) 337 s.clientsMu.Lock() 338 defer s.clientsMu.Unlock() 339 for c := range s.clients { 340 var dts []byte 341 if c == client { 342 dts = echoData 343 } else if client.mutedBy[c] { 344 dts = muteData 345 } else { 346 dts = stdData 347 } 348 select { 349 case c.dataChan <- dts: 350 s.logDebug("b init") 351 default: 352 s.log("kicked client") 353 client.cancel() 354 } 355 } 356} 357 358func (s *Server) handlePub(msg *lrcpb.Event_Pub, client *client) { 359 curID := s.clientToID[client] 360 if curID == nil { 361 return 362 } 363 s.idmapsMu.Lock() 364 s.clientToID[client] = nil 365 s.idmapsMu.Unlock() 366 msg.Pub.Id = *curID 367 event := &lrcpb.Event{Msg: msg} 368 if s.pubChan != nil { 369 select { 370 case s.pubChan <- PubEvent{ID: *curID, Body: *client.post}: 371 default: 372 s.log("pubchan blocked, closing channel") 373 close(s.pubChan) 374 s.pubChan = nil 375 } 376 } 377 client.post = nil 378 s.broadcast(event, client) 379} 380 381func (s *Server) handleInsert(msg *lrcpb.Event_Insert, client *client) { 382 curID := s.clientToID[client] 383 if curID == nil { 384 return 385 } 386 newpost, err := insertAtUTF16Index(*client.post, msg.Insert.GetByteIndex(), msg.Insert.GetBody()) 387 if err != nil { 388 return 389 } 390 client.post = &newpost 391 msg.Insert.Id = *curID 392 event := &lrcpb.Event{Msg: msg} 393 s.broadcast(event, client) 394} 395 396func insertAtUTF16Index(base string, index uint32, insert 
string) (string, error) { 397 runes := []rune(base) 398 399 unitCount := 0 400 var splitAt int 401 for i, r := range runes { 402 units := 1 403 if r > 0xFFFF { 404 units = 2 405 } 406 if uint32(unitCount+units) > index { 407 splitAt = i 408 break 409 } 410 unitCount += units 411 splitAt = i + 1 412 } 413 if index > uint32(unitCount) { 414 return "", errors.New("index out of range") 415 } 416 return string(runes[:splitAt]) + insert + string(runes[splitAt:]), nil 417} 418 419func (s *Server) handleDelete(msg *lrcpb.Event_Delete, client *client) { 420 curID := s.clientToID[client] 421 if curID == nil { 422 return 423 } 424 newPost, err := deleteBtwnUTF16Indices(*client.post, msg.Delete.GetByteStart(), msg.Delete.GetByteEnd()) 425 if err != nil { 426 return 427 } 428 client.post = &newPost 429 msg.Delete.Id = *curID 430 event := &lrcpb.Event{Msg: msg} 431 s.broadcast(event, client) 432} 433 434func deleteBtwnUTF16Indices(base string, start uint32, end uint32) (string, error) { 435 if end <= start { 436 return "", errors.New("end must come after start") 437 } 438 runes := []rune(base) 439 unitCount := 0 440 var startAt, endAt *int 441 for i, r := range runes { 442 units := 1 443 if r > 0xFFFF { 444 units = 2 445 } 446 if startAt == nil && uint32(unitCount+units) > start { 447 startAt = &i 448 } 449 if uint32(unitCount) > end { 450 endAt = &i 451 break 452 } 453 unitCount += units 454 } 455 if end > uint32(unitCount) { 456 return "", errors.New("index out of range") 457 } 458 return string(runes[:*startAt]) + string(runes[*endAt:]), nil 459 460} 461 462func (s *Server) broadcast(event *lrcpb.Event, client *client) { 463 data, _ := proto.Marshal(event) 464 s.clientsMu.Lock() 465 defer s.clientsMu.Unlock() 466 for c := range s.clients { 467 if client.mutedBy[c] { 468 continue 469 } 470 select { 471 case c.dataChan <- data: 472 s.logDebug("b") 473 default: 474 s.log("kicked client") 475 client.cancel() 476 } 477 } 478} 479 480func (s *Server) handleMute(msg 
*lrcpb.Event_Mute, client *client) { 481 toMute := msg.Mute.GetId() 482 s.idmapsMu.Lock() 483 clientToMute, ok := s.idToClient[toMute] 484 if !ok { 485 return 486 } 487 if clientToMute == client { 488 return 489 } 490 clientToMute.mutedBy[client] = true 491 client.muteMap[clientToMute] = true 492 s.idmapsMu.Unlock() 493 494} 495 496func (s *Server) handleUnmute(msg *lrcpb.Event_Unmute, client *client) { 497 toMute := msg.Unmute.GetId() 498 s.idmapsMu.Lock() 499 clientToMute, ok := s.idToClient[toMute] 500 if !ok { 501 return 502 } 503 if clientToMute == client { 504 return 505 } 506 delete(clientToMute.mutedBy, client) 507 delete(client.muteMap, clientToMute) 508 s.idmapsMu.Unlock() 509 510} 511 512func (s *Server) handleSet(msg *lrcpb.Event_Set, client *client) { 513 nick := msg.Set.Nick 514 if nick != nil { 515 nickname := *nick 516 if len(nickname) <= 16 { 517 client.nick = &nickname 518 } 519 } 520 externalId := msg.Set.ExternalID 521 if externalId != nil { 522 externid := *externalId 523 client.externID = &externid 524 } 525 color := msg.Set.Color 526 if color != nil { 527 c := *color 528 if c <= 0xffffff { 529 client.color = &c 530 } 531 } 532} 533 534func (s *Server) handleGet(msg *lrcpb.Event_Get, client *client) { 535 t := msg.Get.Topic 536 if t != nil { 537 client.dataChan <- s.welcomeEvt 538 } 539 c := msg.Get.Connected 540 if c != nil { 541 conncount := uint32(len(s.clients)) 542 e := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Connected: &conncount}}} 543 data, _ := proto.Marshal(e) 544 client.dataChan <- data 545 } 546} 547 548// func (s *Server) broadcastAll(evt []byte) { 549// s.clientsMu.Lock() 550// defer s.clientsMu.Unlock() 551// for client := range s.clients { 552// select { 553// case client.evtChan <- evt: 554// s.logDebug(fmt.Sprintf("b %x", evt)) 555// default: 556// s.log("kicked client") 557// if client.tcpconn != nil { 558// (*client.tcpconn).Close() 559// } 560// if client.wsconn != nil { 561// (*client.wsconn).Close() 562// } 
563// delete(s.clients, client) 564// } 565// } 566// } 567 568// func (s *Server) broadcastInit(evt []byte, c *client, id uint32) { 569// bevt, eevt := events.GenServerEvent(evt, id) 570// s.clientsMu.Lock() 571// defer s.clientsMu.Unlock() 572// for client := range s.clients { 573// evtToSend := bevt 574// if client == c { 575// evtToSend = eevt 576// } 577// select { 578// case client.evtChan <- evtToSend: 579// s.logDebug(fmt.Sprintf("b %x", bevt)) 580// default: 581// s.log("kicked client") 582// if client.tcpconn != nil { 583// (*client.tcpconn).Close() 584// } 585// if client.wsconn != nil { 586// (*client.wsconn).Close() 587// } 588// delete(s.clients, client) 589// } 590// } 591// } 592 593// logDebug debugs unless in production 594func (server *Server) logDebug(s string) { 595 if server.debugLogger != nil { 596 server.debugLogger.Println(s) 597 } 598} 599 600func (server *Server) log(s string) { 601 if server.logger != nil { 602 server.logger.Println(s) 603 } 604}