A privacy-first, self-hosted, fully open-source personal knowledge management application, written in TypeScript and Go. (PERSONAL FORK)
1// SiYuan - Refactor your thinking
2// Copyright (c) 2020-present, b3log.org
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU Affero General Public License for more details.
13//
14// You should have received a copy of the GNU Affero General Public License
15// along with this program. If not, see <https://www.gnu.org/licenses/>.
16
17package api
18
import (
	"io"
	"mime/multipart"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/88250/gulu"
	"github.com/asaskevich/EventBus"
	"github.com/gin-contrib/sse"
	"github.com/gin-gonic/gin"
	"github.com/olahol/melody"
	"github.com/siyuan-note/logging"
	"github.com/siyuan-note/siyuan/kernel/util"
)
33
const (
	// MessageTypeString marks a text payload.
	MessageTypeString MessageType = "string"
	// MessageTypeBinary marks a binary payload.
	MessageTypeBinary MessageType = "binary"
	// MessageTypeClose is published when a broadcast channel is destroyed; it carries
	// only the channel name.
	MessageTypeClose MessageType = "close"

	// EvtBroadcastMessage is the EventBus topic on which every MessageEvent is published.
	EvtBroadcastMessage = "broadcast.message"
)
41
var (
	// BroadcastChannels is the registry of live broadcast channels.
	BroadcastChannels = sync.Map{} // [string (channel-name)] -> *BroadcastChannel
	// UnifiedSSE is the server-sent-events server shared by all broadcast channels.
	UnifiedSSE = &EventSourceServer{
		EventBus:  EventBus.New(),
		WaitGroup: &sync.WaitGroup{},
		// Subscriber counts the SSE clients subscribed to all channels at once (SubscribeAll).
		Subscriber: &EventSourceSubscriber{
			lock:  &sync.Mutex{},
			count: 0,
		},
	}
	// messageID generates the IDs that EventSourceServer.SendEvent assigns to events.
	messageID = &MessageID{
		lock: &sync.Mutex{},
		id:   0,
	}
)
57
// MessageType distinguishes the payload kinds carried by a MessageEvent.
type MessageType string

// MessageEventChannel is a channel of MessageEvent pointers used to hand events to SSE streams.
type MessageEventChannel chan *MessageEvent
60
// MessageID is a mutex-guarded counter that produces sequential event IDs.
type MessageID struct {
	lock *sync.Mutex // guards id
	id   uint64      // last ID handed out by Next
}
65
66func (m *MessageID) Next() uint64 {
67 m.lock.Lock()
68 defer m.lock.Unlock()
69
70 m.id++
71 return m.id
72}
73
// MessageEvent is a message published on the SSE event bus.
type MessageEvent struct {
	ID   string // event ID; EventSourceServer.SendEvent fills it in when empty (except for close events)
	Type MessageType
	Name string // channel name
	Data []byte // payload; not set for close events
}
80
// BroadcastSubscriber tracks the SSE clients subscribed to one broadcast channel by name.
type BroadcastSubscriber struct {
	// Count is the number of SSE subscribers of this channel.
	// NOTE(review): Count is a plain int touched from concurrent request handlers
	// without synchronization; consider a mutex or an atomic counter.
	Count int
}
84
// BroadcastChannel is a named channel whose messages reach both WebSocket
// sessions and SSE subscribers.
type BroadcastChannel struct {
	Name       string               // channel name
	WebSocket  *melody.Melody       // WebSocket hub serving this channel's sessions
	Subscriber *BroadcastSubscriber // SSE subscribers of this channel
}
90
91// SubscriberCount gets the total number of subscribers
92func (b *BroadcastChannel) SubscriberCount() int {
93 return b.WebSocket.Len() + b.Subscriber.Count + UnifiedSSE.Subscriber.Count()
94}
95
96// BroadcastString broadcast string message to all subscribers
97func (b *BroadcastChannel) BroadcastString(message string) (sent bool, err error) {
98 data := []byte(message)
99 sent = UnifiedSSE.SendEvent(&MessageEvent{
100 Type: MessageTypeString,
101 Name: b.Name,
102 Data: data,
103 })
104 err = b.WebSocket.Broadcast(data)
105 return
106}
107
108// BroadcastBinary broadcast binary message to all subscribers
109func (b *BroadcastChannel) BroadcastBinary(data []byte) (sent bool, err error) {
110 sent = UnifiedSSE.SendEvent(&MessageEvent{
111 Type: MessageTypeBinary,
112 Name: b.Name,
113 Data: data,
114 })
115 err = b.WebSocket.BroadcastBinary(data)
116 return
117}
118
119func (b *BroadcastChannel) HandleRequest(c *gin.Context) {
120 if err := b.WebSocket.HandleRequestWithKeys(
121 c.Writer,
122 c.Request,
123 map[string]interface{}{
124 "channel": b.Name,
125 },
126 ); err != nil {
127 logging.LogErrorf("create broadcast channel failed: %s", err)
128 return
129 }
130}
131
132func (b *BroadcastChannel) Subscribed() bool {
133 return b.SubscriberCount() > 0
134}
135
136func (b *BroadcastChannel) Destroy(force bool) bool {
137 if force || !b.Subscribed() {
138 b.WebSocket.Close()
139 UnifiedSSE.SendEvent(&MessageEvent{
140 Type: MessageTypeClose,
141 Name: b.Name,
142 })
143 logging.LogInfof("destroy broadcast channel [%s]", b.Name)
144 return true
145 }
146 return false
147}
148
// EventSourceSubscriber is a mutex-guarded count of SSE clients. UnifiedSSE uses
// it to count the clients subscribed to all channels (see SubscribeAll).
type EventSourceSubscriber struct {
	lock  *sync.Mutex // guards count
	count int
}
153
154func (s *EventSourceSubscriber) updateCount(delta int) {
155 s.lock.Lock()
156 defer s.lock.Unlock()
157
158 s.count += delta
159}
160
161func (s *EventSourceSubscriber) Count() int {
162 s.lock.Lock()
163 defer s.lock.Unlock()
164
165 return s.count
166}
167
// EventSourceServer distributes broadcast messages to SSE clients.
type EventSourceServer struct {
	EventBus   EventBus.Bus           // bus on which MessageEvents are published (topic EvtBroadcastMessage)
	WaitGroup  *sync.WaitGroup        // tracks in-flight SSE requests (see broadcastSubscribe)
	Subscriber *EventSourceSubscriber // SSE clients subscribed to all channels
}
173
// SendEvent publishes event on s.EventBus under the topic EvtBroadcastMessage.
//
// An event without an ID gets the next value of messageID; close events are
// left without one. It returns whether the topic has at least one handler
// registered, as checked right after publishing.
func (s *EventSourceServer) SendEvent(event *MessageEvent) bool {
	needsID := event.ID == "" && event.Type != MessageTypeClose
	if needsID {
		event.ID = strconv.FormatUint(messageID.Next(), 10)
	}

	s.EventBus.Publish(EvtBroadcastMessage, event)
	return s.EventBus.HasCallback(EvtBroadcastMessage)
}
187
188// Subscribe subscribes to specified broadcast channels
189func (s *EventSourceServer) Subscribe(c *gin.Context, retry uint, channels ...string) {
190 wg := sync.WaitGroup{}
191 wg.Add(len(channels))
192 for _, channel := range channels {
193 go func() {
194 defer wg.Done()
195
196 var broadcastChannel *BroadcastChannel
197 _broadcastChannel, exist := BroadcastChannels.Load(channel)
198 if exist { // channel exists, use it
199 broadcastChannel = _broadcastChannel.(*BroadcastChannel)
200 } else {
201 broadcastChannel = ConstructBroadcastChannel(channel)
202 }
203 broadcastChannel.Subscriber.Count++
204 }()
205 }
206 wg.Wait()
207
208 channelSet := make(map[string]bool)
209 for _, channel := range channels {
210 channelSet[channel] = true
211 }
212
213 c.Writer.Flush()
214 s.Stream(c, func(event *MessageEvent, ok bool) bool {
215 if ok {
216 if _, exists := channelSet[event.Name]; exists {
217 switch event.Type {
218 case MessageTypeClose:
219 return false
220 case MessageTypeString:
221 s.SSEvent(c, &sse.Event{
222 Id: event.ID,
223 Event: event.Name,
224 Retry: retry,
225 Data: string(event.Data),
226 })
227 default:
228 s.SSEvent(c, &sse.Event{
229 Id: event.ID,
230 Event: event.Name,
231 Retry: retry,
232 Data: event.Data,
233 })
234 }
235 c.Writer.Flush()
236 return true
237 }
238 return true
239 }
240 return false
241 })
242
243 wg.Add(len(channels))
244 for _, channel := range channels {
245 go func() {
246 defer wg.Done()
247 _broadcastChannel, exist := BroadcastChannels.Load(channel)
248 if exist {
249 broadcastChannel := _broadcastChannel.(*BroadcastChannel)
250 broadcastChannel.Subscriber.Count--
251 if !broadcastChannel.Subscribed() {
252 BroadcastChannels.Delete(channel)
253 broadcastChannel.Destroy(true)
254 }
255 }
256 }()
257 }
258 wg.Wait()
259}
260
261// SubscribeAll subscribes to all broadcast channels
262func (s *EventSourceServer) SubscribeAll(c *gin.Context, retry uint) {
263 s.Subscriber.updateCount(1)
264
265 c.Writer.Flush()
266 s.Stream(c, func(event *MessageEvent, ok bool) bool {
267 if ok {
268 switch event.Type {
269 case MessageTypeClose:
270 return true
271 case MessageTypeString:
272 s.SSEvent(c, &sse.Event{
273 Id: event.ID,
274 Event: event.Name,
275 Retry: retry,
276 Data: string(event.Data),
277 })
278 default:
279 s.SSEvent(c, &sse.Event{
280 Id: event.ID,
281 Event: event.Name,
282 Retry: retry,
283 Data: event.Data,
284 })
285 }
286 c.Writer.Flush()
287 return true
288 }
289 return false
290 })
291
292 s.Subscriber.updateCount(-1)
293 PruneBroadcastChannels()
294}
295
// GetRetry returns the SSE retry interval taken from the "retry" query parameter.
//
// It returns 0 when the parameter is absent or is not a valid unsigned integer.
func (s *EventSourceServer) GetRetry(c *gin.Context) uint {
	retry, err := strconv.ParseUint(c.DefaultQuery("retry", ""), 10, 0)
	if err != nil {
		return 0
	}
	return uint(retry)
}
307
// Stream hands every event published on s.EventBus (topic EvtBroadcastMessage)
// to step until the client in c disconnects or step returns false.
//
// It returns true if the client is gone, and false if step ended the stream.
// The mailbox is never closed, so step always receives ok == true; the
// parameter is kept for compatibility.
func (s *EventSourceServer) Stream(c *gin.Context, step func(event *MessageEvent, ok bool) bool) bool {
	hub := s.streamHub()
	stream := hub.join(s.EventBus)
	defer hub.leave(s.EventBus, stream)

	clientGone := c.Writer.CloseNotify()
	for {
		select {
		case <-clientGone:
			logging.LogInfof("event source connection is closed by client")
			return true
		case event, ok := <-stream.events:
			if step(event, ok) {
				continue
			}
			logging.LogInfof("event source connection is closed by server")
			return false
		}
	}
}

// eventSourceStream is the mailbox of one running Stream call.
type eventSourceStream struct {
	events MessageEventChannel // unbuffered; received only by the owning Stream
	done   chan struct{}       // closed when the owning Stream returns
}

// eventSourceStreamHub fans the events of one EventSourceServer out to all of
// its running Stream calls through a single EventBus handler.
//
// EventBus.Unsubscribe matches handlers by function type and code pointer, and
// all closures created from one function literal share a code pointer. With one
// handler per connection, unsubscribing a finished connection could remove the
// handler of another live connection. The finished connection's handler would
// then stay registered and send on its closed channel (panic). A publisher
// blocked on an exiting stream's unbuffered channel also deadlocked Unsubscribe.
// One shared handler, plus a done channel per stream, avoids both.
type eventSourceStreamHub struct {
	streams sync.Map // *eventSourceStream -> struct{}; read lock-free by dispatch
	handler func(event *MessageEvent)

	lock   sync.Mutex // guards active and the EventBus (un)subscription of handler
	active int
}

// eventSourceStreamHubs maps *EventSourceServer -> *eventSourceStreamHub.
var eventSourceStreamHubs sync.Map

// streamHub returns the stream hub of s, creating it on first use.
func (s *EventSourceServer) streamHub() *eventSourceStreamHub {
	if hub, ok := eventSourceStreamHubs.Load(s); ok {
		return hub.(*eventSourceStreamHub)
	}
	hub := &eventSourceStreamHub{}
	hub.handler = hub.dispatch
	actual, _ := eventSourceStreamHubs.LoadOrStore(s, hub)
	return actual.(*eventSourceStreamHub)
}

// dispatch hands event to every registered stream in turn. A stream whose Stream
// call has returned is skipped instead of blocking the publisher forever.
func (h *eventSourceStreamHub) dispatch(event *MessageEvent) {
	h.streams.Range(func(key, _ any) bool {
		stream := key.(*eventSourceStream)
		select {
		case stream.events <- event:
		case <-stream.done:
		}
		return true
	})
}

// join subscribes the hub handler to bus if this is the first active stream,
// then registers and returns a new stream.
func (h *eventSourceStreamHub) join(bus EventBus.Bus) *eventSourceStream {
	h.lock.Lock()
	h.active++
	if h.active == 1 {
		if err := bus.Subscribe(EvtBroadcastMessage, h.handler); err != nil {
			logging.LogErrorf("subscribe event source failed: %s", err)
		}
	}
	h.lock.Unlock()

	// Register only after releasing h.lock. A publisher must never be blocked on a
	// stream whose goroutine is itself waiting for h.lock.
	stream := &eventSourceStream{
		events: make(MessageEventChannel),
		done:   make(chan struct{}),
	}
	h.streams.Store(stream, struct{}{})
	return stream
}

// leave unregisters stream and unsubscribes the hub handler from bus once no
// stream is left.
func (h *eventSourceStreamHub) leave(bus EventBus.Bus, stream *eventSourceStream) {
	close(stream.done) // release a publisher that may be blocked on this stream
	h.streams.Delete(stream)

	h.lock.Lock()
	defer h.lock.Unlock()
	h.active--
	if h.active == 0 {
		if err := bus.Unsubscribe(EvtBroadcastMessage, h.handler); err != nil {
			logging.LogErrorf("unsubscribe event source failed: %s", err)
		}
	}
}
336
337// SSEvent writes a Server-Sent Event into the body stream.
338func (s *EventSourceServer) SSEvent(c *gin.Context, event *sse.Event) {
339 c.Render(-1, event)
340}
341
342// Subscribed checks whether the SSE server is subscribed
343func (s *EventSourceServer) Subscribed() bool {
344 return s.Subscriber.Count() > 0
345}
346
// ChannelInfo describes a broadcast channel in API responses.
type ChannelInfo struct {
	Name  string `json:"name"`  // channel name
	Count int    `json:"count"` // subscriber count (see BroadcastChannel.SubscriberCount); 0 if the channel does not exist
}
351
// PublishMessage describes one message handled by broadcastPublish.
type PublishMessage struct {
	Type     MessageType `json:"type"`     // "string" | "binary"
	Size     int         `json:"size"`     // message size (bytes)
	Filename string      `json:"filename"` // empty string for string-message
}
357
// PublishResult is the per-message outcome reported by broadcastPublish.
type PublishResult struct {
	// Code is 0 on success; broadcastPublish reports -2 when broadcasting fails,
	// -3 when an uploaded file cannot be read, and -4 when it cannot be opened.
	Code int    `json:"code"`
	Msg  string `json:"msg"` // error message

	Channel ChannelInfo    `json:"channel"`
	Message PublishMessage `json:"message"`
}
365
366// broadcast create a broadcast channel WebSocket connection
367//
368// @param
369//
370// {
371// channel: string, // channel name
372// }
373//
374// @example
375//
376// "ws://localhost:6806/ws/broadcast?channel=test"
377func broadcast(c *gin.Context) {
378 var (
379 channel string = c.Query("channel")
380 broadcastChannel *BroadcastChannel
381 )
382
383 _broadcastChannel, exist := BroadcastChannels.Load(channel)
384 if exist { // channel exists, use it
385 broadcastChannel = _broadcastChannel.(*BroadcastChannel)
386 if broadcastChannel.WebSocket.IsClosed() { // channel is closed
387 // delete channel before creating a new one
388 DestroyBroadcastChannel(channel, true)
389 } else { // channel is open
390 // connect to the existing channel
391 broadcastChannel.HandleRequest(c)
392 return
393 }
394 }
395
396 // create a new channel
397 broadcastChannel = ConstructBroadcastChannel(channel)
398 broadcastChannel.HandleRequest(c)
399}
400
401// GetBroadcastChannel gets a broadcast channel
402//
403// If the channel does not exist but the SSE server is subscribed, it will create a new broadcast channel.
404// If the SSE server is not subscribed, it will return nil.
405func GetBroadcastChannel(channel string) *BroadcastChannel {
406 _broadcastChannel, exist := BroadcastChannels.Load(channel)
407 if exist {
408 return _broadcastChannel.(*BroadcastChannel)
409 }
410 if UnifiedSSE.Subscribed() {
411 return ConstructBroadcastChannel(channel)
412 }
413 return nil
414}
415
416// ConstructBroadcastChannel creates a broadcast channel
417func ConstructBroadcastChannel(channel string) *BroadcastChannel {
418 websocket := melody.New()
419 websocket.Config.MaxMessageSize = 1024 * 1024 * 128 // 128 MiB
420
421 // broadcast string message to other session
422 websocket.HandleMessage(func(s *melody.Session, msg []byte) {
423 UnifiedSSE.SendEvent(&MessageEvent{
424 Type: MessageTypeString,
425 Name: channel,
426 Data: msg,
427 })
428 websocket.BroadcastOthers(msg, s)
429 })
430
431 // broadcast binary message to other session
432 websocket.HandleMessageBinary(func(s *melody.Session, msg []byte) {
433 UnifiedSSE.SendEvent(&MessageEvent{
434 Type: MessageTypeBinary,
435 Name: channel,
436 Data: msg,
437 })
438 websocket.BroadcastBinaryOthers(msg, s)
439 })
440
441 // client close the connection
442 websocket.HandleClose(func(s *melody.Session, status int, reason string) error {
443 channel := s.Keys["channel"].(string)
444 logging.LogInfof("close broadcast session in channel [%s] with status code %d: %s", channel, status, reason)
445
446 DestroyBroadcastChannel(channel, false)
447 return nil
448 })
449
450 var broadcastChannel *BroadcastChannel
451 for {
452 // Melody Initialization is an asynchronous process, so we need to wait for it to complete
453 if websocket.IsClosed() {
454 time.Sleep(1 * time.Nanosecond)
455 } else {
456 newBroadcastChannel := &BroadcastChannel{
457 Name: channel,
458 WebSocket: websocket,
459 Subscriber: &BroadcastSubscriber{
460 Count: 0,
461 },
462 }
463 _broadcastChannel, loaded := BroadcastChannels.LoadOrStore(channel, newBroadcastChannel)
464 broadcastChannel = _broadcastChannel.(*BroadcastChannel)
465 if loaded { // channel exists
466 if broadcastChannel.WebSocket.IsClosed() { // channel is closed, replace it
467 BroadcastChannels.Store(channel, newBroadcastChannel)
468 broadcastChannel = newBroadcastChannel
469 } else { // channel is open, destroy the new one
470 newBroadcastChannel.Destroy(true)
471 }
472 }
473 break
474 }
475 }
476 return broadcastChannel
477}
478
479// DestroyBroadcastChannel tries to destroy a broadcast channel
480//
481// Return true if the channel destroy successfully, otherwise false
482func DestroyBroadcastChannel(channel string, force bool) bool {
483 _broadcastChannel, exist := BroadcastChannels.Load(channel)
484 if !exist {
485 return true
486 }
487
488 broadcastChannel := _broadcastChannel.(*BroadcastChannel)
489 if force || !broadcastChannel.Subscribed() {
490 BroadcastChannels.Delete(channel)
491 broadcastChannel.Destroy(true)
492 return true
493 }
494
495 return false
496}
497
498// PruneBroadcastChannels prunes all broadcast channels without subscribers
499func PruneBroadcastChannels() []string {
500 channels := []string{}
501 BroadcastChannels.Range(func(key, value any) bool {
502 channel := key.(string)
503 broadcastChannel := value.(*BroadcastChannel)
504 if !broadcastChannel.Subscribed() {
505 BroadcastChannels.Delete(channel)
506 broadcastChannel.Destroy(true)
507 channels = append(channels, channel)
508 }
509 return true
510 })
511 return channels
512}
513
514// broadcastSubscribe subscribe to a broadcast channel by SSE
515//
516// If the channel-name does not specified, the client will subscribe to all broadcast channels.
517//
518// @param
519//
520// {
521// retry: string, // retry interval (ms) (optional)
522// channel: string, // channel name (optional, multiple)
523// }
524//
525// @example
526//
527// "http://localhost:6806/es/broadcast/subscribe?retry=1000&channel=test1&channel=test2"
528func broadcastSubscribe(c *gin.Context) {
529 // REF: https://github.com/gin-gonic/examples/blob/master/server-sent-event/main.go
530 c.Writer.Header().Set("Content-Type", "text/event-stream")
531 c.Writer.Header().Set("Cache-Control", "no-cache")
532 c.Writer.Header().Set("Connection", "keep-alive")
533 c.Writer.Header().Set("Transfer-Encoding", "chunked")
534
535 defer UnifiedSSE.WaitGroup.Done()
536 UnifiedSSE.WaitGroup.Add(1)
537
538 retry := UnifiedSSE.GetRetry(c)
539 channels, ok := c.GetQueryArray("channel")
540 if ok { // subscribe specified broadcast channels
541 UnifiedSSE.Subscribe(c, retry, channels...)
542 } else { // subscribe all broadcast channels
543 UnifiedSSE.SubscribeAll(c, retry)
544 }
545}
546
// broadcastPublish publishes string and binary messages to one or more broadcast channels
//
// Every multipart form field names a channel. Each string value of the field is
// broadcast as a string message, and each uploaded file as a binary message. A
// message is broadcast only when GetBroadcastChannel yields a channel. Otherwise
// it is still listed in the results, with code 0.
//
// @param
//
// MultipartForm: [name] -> [values]
// - name: string // channel name
// - values:
//   - string[] // string-messages to the same channel
//   - File[] // binary-messages to the same channel
//   - filename: string // message key
//
// @returns
//
//	{
//		code: int, // 0: success, -2: the multipart form could not be parsed
//		msg: string,
//		data: {
//			results: {
//				code: int, // 0: success, -2: broadcast failed, -3: file read failed, -4: file open failed
//				msg: string, // error message
//				channel: {
//					name: string, // channel name
//					count: int, // subscriber count
//				},
//				message: {
//					type: string, // "string" | "binary"
//					size: int, // message size (Bytes)
//					filename: string, // empty string for string-message
//				},
//			}[],
//		},
//	}
func broadcastPublish(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	results := []*PublishResult{}

	// Multipart form
	form, err := c.MultipartForm()
	if err != nil {
		ret.Code = -2
		ret.Msg = err.Error()
		return
	}

	// Broadcast string messages
	for name, values := range form.Value {
		broadcastChannel, channel := lookupPublishChannel(name)

		// Broadcast each string message to the same channel
		for _, value := range values {
			result := &PublishResult{
				Code:    0,
				Msg:     "",
				Channel: channel,
				Message: PublishMessage{
					Type:     MessageTypeString,
					Size:     len(value),
					Filename: "",
				},
			}
			results = append(results, result)

			if broadcastChannel == nil {
				continue
			}
			if _, err := broadcastChannel.BroadcastString(value); err != nil {
				logging.LogErrorf("broadcast message failed: %s", err)
				result.Code = -2
				result.Msg = err.Error()
			}
		}
	}

	// Broadcast binary messages
	for name, files := range form.File {
		broadcastChannel, channel := lookupPublishChannel(name)

		// Broadcast each binary message to the same channel
		for _, file := range files {
			result := &PublishResult{
				Code:    0,
				Msg:     "",
				Channel: channel,
				Message: PublishMessage{
					Type:     MessageTypeBinary,
					Size:     int(file.Size),
					Filename: file.Filename,
				},
			}
			results = append(results, result)

			if broadcastChannel == nil {
				continue
			}

			content, code, err := readPublishFile(file)
			if err != nil {
				result.Code = code
				result.Msg = err.Error()
				continue
			}

			if _, err := broadcastChannel.BroadcastBinary(content); err != nil {
				logging.LogErrorf("broadcast binary message failed: %s", err)
				result.Code = -2
				result.Msg = err.Error()
			}
		}
	}

	ret.Data = map[string]interface{}{
		"results": results,
	}
}

// lookupPublishChannel returns the channel that messages for name are sent to,
// together with the ChannelInfo reported in publish results. The channel comes
// from GetBroadcastChannel and may be nil.
func lookupPublishChannel(name string) (*BroadcastChannel, ChannelInfo) {
	info := ChannelInfo{Name: name, Count: 0}
	broadcastChannel := GetBroadcastChannel(name)
	if broadcastChannel != nil {
		info.Count = broadcastChannel.SubscriberCount()
	}
	return broadcastChannel, info
}

// readPublishFile reads an uploaded multipart file completely and always closes it.
//
// On failure it also returns the PublishResult code to report: -4 when the file
// cannot be opened, -3 when it cannot be read in full. A single Read may return
// fewer bytes than requested (and returns io.EOF for empty files), hence io.ReadFull.
func readPublishFile(file *multipart.FileHeader) ([]byte, int, error) {
	value, err := file.Open()
	if err != nil {
		logging.LogErrorf("open multipart form file [%s] failed: %s", file.Filename, err)
		return nil, -4, err
	}
	defer value.Close() // read-only; a Close error is not actionable here

	content := make([]byte, file.Size)
	if _, err := io.ReadFull(value, content); err != nil {
		logging.LogErrorf("read multipart form file [%s] failed: %s", file.Filename, err)
		return nil, -3, err
	}
	return content, 0, nil
}
694
// postMessage sends a string message to a broadcast channel
//
// The channel is resolved with GetBroadcastChannel. When it yields nil, nothing
// is sent and count is 0. Missing or non-string arguments are rejected with
// code -1 instead of panicking.
//
// @param
//
//	{
//		channel: string // channel name
//		message: string // message payload
//	}
//
// @returns
//
//	{
//		code: int,
//		msg: string,
//		data: {
//			channel: {
//				name: string, // channel name
//				count: int, // subscriber count
//			},
//		},
//	}
func postMessage(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	arg, ok := util.JsonArg(c, ret)
	if !ok {
		return
	}

	message, ok := arg["message"].(string)
	if !ok {
		ret.Code = -1
		ret.Msg = "message must be a string"
		return
	}
	name, ok := arg["channel"].(string)
	if !ok {
		ret.Code = -1
		ret.Msg = "channel must be a string"
		return
	}

	channel := &ChannelInfo{
		Name:  name,
		Count: 0,
	}

	if broadcastChannel := GetBroadcastChannel(channel.Name); broadcastChannel != nil {
		channel.Count = broadcastChannel.SubscriberCount()
		if _, err := broadcastChannel.BroadcastString(message); err != nil {
			logging.LogErrorf("broadcast message failed: %s", err)

			ret.Code = -2
			ret.Msg = err.Error()
			return
		}
	}

	ret.Data = map[string]interface{}{
		"channel": channel,
	}
}
748
749// getChannelInfo gets the information of a broadcast channel
750//
751// @param
752//
753// {
754// name: string, // channel name
755// }
756//
757// @returns
758//
759// {
760// code: int,
761// msg: string,
762// data: {
763// channel: {
764// name: string, //channel name
765// count: string, //listener count
766// },
767// },
768// }
769func getChannelInfo(c *gin.Context) {
770 ret := gulu.Ret.NewResult()
771 defer c.JSON(http.StatusOK, ret)
772
773 arg, ok := util.JsonArg(c, ret)
774 if !ok {
775 return
776 }
777
778 channel := &ChannelInfo{
779 Name: arg["name"].(string),
780 Count: 0,
781 }
782
783 if _broadcastChannel, ok := BroadcastChannels.Load(channel.Name); !ok {
784 channel.Count = 0
785 } else {
786 var broadcastChannel = _broadcastChannel.(*BroadcastChannel)
787 channel.Count = broadcastChannel.SubscriberCount()
788 }
789
790 ret.Data = map[string]interface{}{
791 "channel": channel,
792 }
793}
794
// getChannels lists every registered broadcast channel with its subscriber count
//
// @returns
//
//	{
//		code: int,
//		msg: string,
//		data: {
//			channels: {
//				name: string, // channel name
//				count: int, // subscriber count
//			}[],
//		},
//	}
func getChannels(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	// Non-nil so an empty registry is encoded as [] rather than null.
	infos := []*ChannelInfo{}
	BroadcastChannels.Range(func(key, value any) bool {
		broadcastChannel := value.(*BroadcastChannel)
		info := &ChannelInfo{
			Name:  key.(string),
			Count: broadcastChannel.SubscriberCount(),
		}
		infos = append(infos, info)
		return true
	})

	ret.Data = map[string]interface{}{"channels": infos}
}