// Package lrcd implements a websocket-based lrcproto server.
1package lrcd
2
3import (
4 "context"
5 "errors"
6 "github.com/gorilla/websocket"
7 "github.com/rachel-mp4/lrcproto/gen/go"
8 "google.golang.org/protobuf/proto"
9 "log"
10 "net/http"
11 "sync"
12)
13
14type Server struct {
15 eventBus chan clientEvent
16 ctx context.Context
17 cancel context.CancelFunc
18 clients map[*client]bool
19 clientsMu sync.Mutex
20 idmapsMu sync.Mutex
21 clientToID map[*client]*uint32
22 idToClient map[uint32]*client
23 lastID uint32
24 logger *log.Logger
25 debugLogger *log.Logger
26 welcomeEvt []byte
27 pongEvt []byte
28 initChan chan lrcpb.Event_Init
29 pubChan chan PubEvent
30}
31
// PubEvent is offered on the server's optional publish channel when a client
// publishes: ID is the post ID and Body is the post's final text.
type PubEvent struct {
	ID   uint32
	Body string
}
36
37type client struct {
38 conn *websocket.Conn
39 dataChan chan []byte
40 ctx context.Context
41 cancel context.CancelFunc
42 muteMap map[*client]bool
43 mutedBy map[*client]bool
44 myIDs []uint32
45 post *string
46 nick *string
47 externID *string
48 color *uint32
49}
50
51type clientEvent struct {
52 client *client
53 event *lrcpb.Event
54}
55
56func NewServer(opts ...Option) (*Server, error) {
57 var options options
58 for _, opt := range opts {
59 err := opt(&options)
60 if err != nil {
61 return nil, err
62 }
63 }
64
65 s := Server{}
66
67 welcomeString := "Welcome to my lrc server!"
68 if options.welcome != nil {
69 welcomeString = *options.welcome
70 }
71 s.setDefaultEvents(welcomeString)
72
73 if options.writer != nil {
74 s.logger = log.New(*options.writer, "[log]", log.Ldate|log.Ltime)
75 if options.verbose {
76 s.debugLogger = log.New(*options.writer, "[debug]", log.Ldate|log.Ltime)
77 }
78 }
79
80 if options.initChan != nil {
81 s.initChan = options.initChan
82 }
83 if options.pubChan != nil {
84 s.pubChan = options.pubChan
85 }
86
87 s.clients = make(map[*client]bool)
88 s.clientsMu = sync.Mutex{}
89 s.idmapsMu = sync.Mutex{}
90 s.clientToID = make(map[*client]*uint32)
91 s.idToClient = make(map[uint32]*client)
92 s.lastID = 0
93 s.eventBus = make(chan clientEvent, 100)
94 return &s, nil
95}
96
97func (s *Server) setDefaultEvents(welcome string) {
98 evt := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Topic: &welcome}}}
99 we, _ := proto.Marshal(evt)
100 s.welcomeEvt = we
101
102 evt = &lrcpb.Event{Msg: &lrcpb.Event_Pong{Pong: &lrcpb.Pong{}}}
103 pe, _ := proto.Marshal(evt)
104 s.pongEvt = pe
105}
106
107func (s *Server) Start() error {
108 if s.ctx != nil {
109 return errors.New("cannot start already started server")
110 }
111 s.ctx, s.cancel = context.WithCancel(context.Background())
112 go s.broadcaster()
113 s.logDebug("Hello, world!")
114 return nil
115}
116
117func (s *Server) Stop() error {
118 if s.ctx == nil {
119 return nil
120 }
121 select {
122 case <-s.ctx.Done():
123 return errors.New("cannot stop already stopped server")
124 default:
125 s.cancel()
126 s.logDebug("Goodbye world :c")
127 return nil
128 }
129}
130
131func (s *Server) Connected() int {
132 return len(s.clients)
133}
134
135func (s *Server) StopIfEmpty() bool {
136 if len(s.clients) == 0 {
137 s.Stop()
138 return true
139 }
140 return false
141}
142
143func (s *Server) WSHandler() http.HandlerFunc {
144 return func(w http.ResponseWriter, r *http.Request) {
145 upgrader := &websocket.Upgrader{
146 Subprotocols: []string{"lrcprotov1"},
147 CheckOrigin: func(r *http.Request) bool {
148 return true
149 },
150 }
151 conn, err := upgrader.Upgrade(w, r, nil)
152 if err != nil {
153 log.Println("Upgrade failed:", err)
154 return
155 }
156 defer conn.Close()
157
158 ctx, cancel := context.WithCancel(context.Background())
159 client := &client{
160 conn: conn,
161 dataChan: make(chan []byte, 100),
162 ctx: ctx,
163 cancel: cancel,
164 muteMap: make(map[*client]bool, 0),
165 mutedBy: make(map[*client]bool, 0),
166 myIDs: make([]uint32, 0),
167 }
168
169 s.clientsMu.Lock()
170 s.clients[client] = true
171 s.clientsMu.Unlock()
172
173 var wg sync.WaitGroup
174 wg.Add(2)
175 go func() { defer wg.Done(); s.wsWriter(client) }()
176 go func() { defer wg.Done(); s.listenToWS(client) }()
177 s.logDebug("new ws connection!")
178 wg.Wait()
179
180 s.clientsMu.Lock()
181 delete(s.clients, client)
182 close(client.dataChan)
183 s.clientsMu.Unlock()
184 s.idmapsMu.Lock()
185 for _, id := range client.myIDs {
186 delete(s.idToClient, id)
187 }
188 for mutedClient, _ := range client.muteMap {
189 delete(mutedClient.mutedBy, client)
190 }
191 for mutingClient, _ := range client.mutedBy {
192 delete(mutingClient.muteMap, client)
193 }
194 s.idmapsMu.Unlock()
195 conn.Close()
196 s.logDebug("closed ws connection")
197 }
198}
199
200func (s *Server) listenToWS(client *client) {
201 for {
202 select {
203 case <-client.ctx.Done():
204 return
205 case <-s.ctx.Done():
206 return
207 default:
208 _, data, err := client.conn.ReadMessage()
209 if err != nil {
210 client.cancel()
211 return
212 }
213 var event lrcpb.Event
214 err = proto.Unmarshal(data, &event)
215 if err != nil {
216 s.logDebug(err.Error())
217 client.cancel()
218 return
219 }
220 s.eventBus <- clientEvent{client, &event}
221 }
222 }
223}
224
225func (s *Server) wsWriter(client *client) {
226 for {
227 select {
228 case <-client.ctx.Done():
229 return
230 case <-s.ctx.Done():
231 return
232 case data, ok := <-client.dataChan:
233 if !ok {
234 client.cancel()
235 return
236 }
237 err := client.conn.WriteMessage(websocket.BinaryMessage, data)
238 if err != nil {
239 s.logDebug(err.Error())
240 client.cancel()
241 return
242 }
243 }
244 }
245}
246
247// broadcaster takes an event from the events channel, and broadcasts it to all the connected clients individual event channels
248func (s *Server) broadcaster() {
249 for {
250 select {
251 case <-s.ctx.Done():
252 return
253 case ce := <-s.eventBus:
254 client := ce.client
255 event := ce.event
256 switch msg := event.Msg.(type) {
257 case *lrcpb.Event_Ping:
258 client.dataChan <- s.pongEvt
259 case *lrcpb.Event_Pong:
260 continue
261 case *lrcpb.Event_Init:
262 s.handleInit(msg, client)
263 case *lrcpb.Event_Pub:
264 s.handlePub(msg, client)
265 case *lrcpb.Event_Insert:
266 s.handleInsert(msg, client)
267 case *lrcpb.Event_Delete:
268 s.handleDelete(msg, client)
269 case *lrcpb.Event_Mute:
270 s.handleMute(msg, client)
271 case *lrcpb.Event_Unmute:
272 s.handleUnmute(msg, client)
273 case *lrcpb.Event_Set:
274 s.handleSet(msg, client)
275 case *lrcpb.Event_Get:
276 s.handleGet(msg, client)
277 }
278 }
279 }
280}
281
282func (s *Server) handleInit(msg *lrcpb.Event_Init, client *client) {
283 curID := s.clientToID[client]
284 if curID != nil {
285 return
286 }
287 newID := s.lastID + 1
288 s.lastID = newID
289 s.idmapsMu.Lock()
290 s.clientToID[client] = &newID
291 s.idToClient[newID] = client
292 s.idmapsMu.Unlock()
293 client.myIDs = append(client.myIDs, newID)
294 newpost := ""
295 client.post = &newpost
296 msg.Init.Id = newID
297 nick := client.nick
298 if nick != nil {
299 msg.Init.Nick = *nick
300 } else {
301 msg.Init.Nick = "wanderer"
302 }
303 externID := client.externID
304 if externID != nil {
305 msg.Init.ExternalID = *externID
306 } else {
307 msg.Init.ExternalID = ""
308 }
309 color := client.color
310 if color != nil {
311 msg.Init.Color = *color
312 } else {
313 msg.Init.Color = 0xD90368
314 }
315 msg.Init.Echoed = false
316
317 if s.initChan != nil {
318 select {
319 case s.initChan <- *msg:
320 default:
321 s.log("initchan blocked, closing channel")
322 close(s.initChan)
323 s.initChan = nil
324 }
325 }
326 s.broadcastInit(msg, client)
327}
328
329func (s *Server) broadcastInit(msg *lrcpb.Event_Init, client *client) {
330 stdEvent := &lrcpb.Event{Msg: msg}
331 stdData, _ := proto.Marshal(stdEvent)
332 msg.Init.Echoed = true
333 echoEvent := &lrcpb.Event{Msg: msg}
334 echoData, _ := proto.Marshal(echoEvent)
335 muteEvent := &lrcpb.Event{Msg: &lrcpb.Event_Mute{Mute: &lrcpb.Mute{Id: msg.Init.GetId()}}}
336 muteData, _ := proto.Marshal(muteEvent)
337 s.clientsMu.Lock()
338 defer s.clientsMu.Unlock()
339 for c := range s.clients {
340 var dts []byte
341 if c == client {
342 dts = echoData
343 } else if client.mutedBy[c] {
344 dts = muteData
345 } else {
346 dts = stdData
347 }
348 select {
349 case c.dataChan <- dts:
350 s.logDebug("b init")
351 default:
352 s.log("kicked client")
353 client.cancel()
354 }
355 }
356}
357
358func (s *Server) handlePub(msg *lrcpb.Event_Pub, client *client) {
359 curID := s.clientToID[client]
360 if curID == nil {
361 return
362 }
363 s.idmapsMu.Lock()
364 s.clientToID[client] = nil
365 s.idmapsMu.Unlock()
366 msg.Pub.Id = *curID
367 event := &lrcpb.Event{Msg: msg}
368 if s.pubChan != nil {
369 select {
370 case s.pubChan <- PubEvent{ID: *curID, Body: *client.post}:
371 default:
372 s.log("pubchan blocked, closing channel")
373 close(s.pubChan)
374 s.pubChan = nil
375 }
376 }
377 client.post = nil
378 s.broadcast(event, client)
379}
380
381func (s *Server) handleInsert(msg *lrcpb.Event_Insert, client *client) {
382 curID := s.clientToID[client]
383 if curID == nil {
384 return
385 }
386 newpost, err := insertAtUTF16Index(*client.post, msg.Insert.GetByteIndex(), msg.Insert.GetBody())
387 if err != nil {
388 return
389 }
390 client.post = &newpost
391 msg.Insert.Id = *curID
392 event := &lrcpb.Event{Msg: msg}
393 s.broadcast(event, client)
394}
395
// insertAtUTF16Index returns base with insert spliced in at index, where
// index counts UTF-16 code units (runes above U+FFFF count as two). It
// fails if index is past the end of base or lands between the two halves of
// a surrogate pair.
func insertAtUTF16Index(base string, index uint32, insert string) (string, error) {
	runes := []rune(base)

	// Advance rune by rune until the UTF-16 offset reaches index.
	var offset uint32
	cut := 0
	for cut < len(runes) && offset < index {
		if runes[cut] > 0xFFFF {
			offset += 2
		} else {
			offset++
		}
		cut++
	}
	// Overshooting means index split a surrogate pair; running out means it
	// lies beyond the end of the text.
	if offset != index {
		return "", errors.New("index out of range")
	}
	return string(runes[:cut]) + insert + string(runes[cut:]), nil
}
418
419func (s *Server) handleDelete(msg *lrcpb.Event_Delete, client *client) {
420 curID := s.clientToID[client]
421 if curID == nil {
422 return
423 }
424 newPost, err := deleteBtwnUTF16Indices(*client.post, msg.Delete.GetByteStart(), msg.Delete.GetByteEnd())
425 if err != nil {
426 return
427 }
428 client.post = &newPost
429 msg.Delete.Id = *curID
430 event := &lrcpb.Event{Msg: msg}
431 s.broadcast(event, client)
432}
433
// deleteBtwnUTF16Indices returns base with the half-open range [start, end)
// removed. Both bounds count UTF-16 code units (runes above U+FFFF count as
// two). A surrogate pair that the range only partly covers is removed
// whole.
//
// It returns an error if end <= start or end is past the end of base.
// Previously, a range reaching the end of the string dereferenced a nil
// pointer, and the end bound removed one rune too many.
func deleteBtwnUTF16Indices(base string, start uint32, end uint32) (string, error) {
	if end <= start {
		return "", errors.New("end must come after start")
	}
	runes := []rune(base)

	// startAt: first rune whose span extends past start.
	// endAt: first rune beginning at or after end (len(runes) if none).
	startAt, endAt := -1, len(runes)
	var pos uint32 // UTF-16 offset at which runes[i] begins
	for i, r := range runes {
		if pos >= end {
			endAt = i
			break
		}
		width := uint32(1)
		if r > 0xFFFF {
			width = 2
		}
		if startAt < 0 && pos+width > start {
			startAt = i
		}
		pos += width
	}
	// If end is reachable then start < end <= pos, so startAt was set above.
	if pos < end {
		return "", errors.New("index out of range")
	}
	return string(runes[:startAt]) + string(runes[endAt:]), nil
}
461
462func (s *Server) broadcast(event *lrcpb.Event, client *client) {
463 data, _ := proto.Marshal(event)
464 s.clientsMu.Lock()
465 defer s.clientsMu.Unlock()
466 for c := range s.clients {
467 if client.mutedBy[c] {
468 continue
469 }
470 select {
471 case c.dataChan <- data:
472 s.logDebug("b")
473 default:
474 s.log("kicked client")
475 client.cancel()
476 }
477 }
478}
479
480func (s *Server) handleMute(msg *lrcpb.Event_Mute, client *client) {
481 toMute := msg.Mute.GetId()
482 s.idmapsMu.Lock()
483 clientToMute, ok := s.idToClient[toMute]
484 if !ok {
485 return
486 }
487 if clientToMute == client {
488 return
489 }
490 clientToMute.mutedBy[client] = true
491 client.muteMap[clientToMute] = true
492 s.idmapsMu.Unlock()
493
494}
495
496func (s *Server) handleUnmute(msg *lrcpb.Event_Unmute, client *client) {
497 toMute := msg.Unmute.GetId()
498 s.idmapsMu.Lock()
499 clientToMute, ok := s.idToClient[toMute]
500 if !ok {
501 return
502 }
503 if clientToMute == client {
504 return
505 }
506 delete(clientToMute.mutedBy, client)
507 delete(client.muteMap, clientToMute)
508 s.idmapsMu.Unlock()
509
510}
511
512func (s *Server) handleSet(msg *lrcpb.Event_Set, client *client) {
513 nick := msg.Set.Nick
514 if nick != nil {
515 nickname := *nick
516 if len(nickname) <= 16 {
517 client.nick = &nickname
518 }
519 }
520 externalId := msg.Set.ExternalID
521 if externalId != nil {
522 externid := *externalId
523 client.externID = &externid
524 }
525 color := msg.Set.Color
526 if color != nil {
527 c := *color
528 if c <= 0xffffff {
529 client.color = &c
530 }
531 }
532}
533
534func (s *Server) handleGet(msg *lrcpb.Event_Get, client *client) {
535 t := msg.Get.Topic
536 if t != nil {
537 client.dataChan <- s.welcomeEvt
538 }
539 c := msg.Get.Connected
540 if c != nil {
541 conncount := uint32(len(s.clients))
542 e := &lrcpb.Event{Msg: &lrcpb.Event_Get{Get: &lrcpb.Get{Connected: &conncount}}}
543 data, _ := proto.Marshal(e)
544 client.dataChan <- data
545 }
546}
547
548// func (s *Server) broadcastAll(evt []byte) {
549// s.clientsMu.Lock()
550// defer s.clientsMu.Unlock()
551// for client := range s.clients {
552// select {
553// case client.evtChan <- evt:
554// s.logDebug(fmt.Sprintf("b %x", evt))
555// default:
556// s.log("kicked client")
557// if client.tcpconn != nil {
558// (*client.tcpconn).Close()
559// }
560// if client.wsconn != nil {
561// (*client.wsconn).Close()
562// }
563// delete(s.clients, client)
564// }
565// }
566// }
567
568// func (s *Server) broadcastInit(evt []byte, c *client, id uint32) {
569// bevt, eevt := events.GenServerEvent(evt, id)
570// s.clientsMu.Lock()
571// defer s.clientsMu.Unlock()
572// for client := range s.clients {
573// evtToSend := bevt
574// if client == c {
575// evtToSend = eevt
576// }
577// select {
578// case client.evtChan <- evtToSend:
579// s.logDebug(fmt.Sprintf("b %x", bevt))
580// default:
581// s.log("kicked client")
582// if client.tcpconn != nil {
583// (*client.tcpconn).Close()
584// }
585// if client.wsconn != nil {
586// (*client.wsconn).Close()
587// }
588// delete(s.clients, client)
589// }
590// }
591// }
592
593// logDebug debugs unless in production
594func (server *Server) logDebug(s string) {
595 if server.debugLogger != nil {
596 server.debugLogger.Println(s)
597 }
598}
599
600func (server *Server) log(s string) {
601 if server.logger != nil {
602 server.logger.Println(s)
603 }
604}