A privacy-first, self-hosted, fully open-source personal knowledge management application, written in TypeScript and Go. (PERSONAL FORK)
at upstream/main 825 lines 20 kB view raw
1// SiYuan - Refactor your thinking 2// Copyright (c) 2020-present, b3log.org 3// 4// This program is free software: you can redistribute it and/or modify 5// it under the terms of the GNU Affero General Public License as published by 6// the Free Software Foundation, either version 3 of the License, or 7// (at your option) any later version. 8// 9// This program is distributed in the hope that it will be useful, 10// but WITHOUT ANY WARRANTY; without even the implied warranty of 11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12// GNU Affero General Public License for more details. 13// 14// You should have received a copy of the GNU Affero General Public License 15// along with this program. If not, see <https://www.gnu.org/licenses/>. 16 17package api 18 19import ( 20 "net/http" 21 "strconv" 22 "sync" 23 "time" 24 25 "github.com/88250/gulu" 26 "github.com/asaskevich/EventBus" 27 "github.com/gin-contrib/sse" 28 "github.com/gin-gonic/gin" 29 "github.com/olahol/melody" 30 "github.com/siyuan-note/logging" 31 "github.com/siyuan-note/siyuan/kernel/util" 32) 33 34const ( 35 MessageTypeString MessageType = "string" 36 MessageTypeBinary MessageType = "binary" 37 MessageTypeClose MessageType = "close" 38 39 EvtBroadcastMessage = "broadcast.message" 40) 41 42var ( 43 BroadcastChannels = sync.Map{} // [string (channel-name)] -> *BroadcastChannel 44 UnifiedSSE = &EventSourceServer{ 45 EventBus: EventBus.New(), 46 WaitGroup: &sync.WaitGroup{}, 47 Subscriber: &EventSourceSubscriber{ 48 lock: &sync.Mutex{}, 49 count: 0, 50 }, 51 } 52 messageID = &MessageID{ 53 lock: &sync.Mutex{}, 54 id: 0, 55 } 56) 57 58type MessageType string 59type MessageEventChannel chan *MessageEvent 60 61type MessageID struct { 62 lock *sync.Mutex 63 id uint64 64} 65 66func (m *MessageID) Next() uint64 { 67 m.lock.Lock() 68 defer m.lock.Unlock() 69 70 m.id++ 71 return m.id 72} 73 74type MessageEvent struct { 75 ID string // event ID 76 Type MessageType 77 Name string // channel name 78 
Data []byte 79} 80 81type BroadcastSubscriber struct { 82 Count int // SEE subscriber count 83} 84 85type BroadcastChannel struct { 86 Name string // channel name 87 WebSocket *melody.Melody 88 Subscriber *BroadcastSubscriber // SEE subscriber 89} 90 91// SubscriberCount gets the total number of subscribers 92func (b *BroadcastChannel) SubscriberCount() int { 93 return b.WebSocket.Len() + b.Subscriber.Count + UnifiedSSE.Subscriber.Count() 94} 95 96// BroadcastString broadcast string message to all subscribers 97func (b *BroadcastChannel) BroadcastString(message string) (sent bool, err error) { 98 data := []byte(message) 99 sent = UnifiedSSE.SendEvent(&MessageEvent{ 100 Type: MessageTypeString, 101 Name: b.Name, 102 Data: data, 103 }) 104 err = b.WebSocket.Broadcast(data) 105 return 106} 107 108// BroadcastBinary broadcast binary message to all subscribers 109func (b *BroadcastChannel) BroadcastBinary(data []byte) (sent bool, err error) { 110 sent = UnifiedSSE.SendEvent(&MessageEvent{ 111 Type: MessageTypeBinary, 112 Name: b.Name, 113 Data: data, 114 }) 115 err = b.WebSocket.BroadcastBinary(data) 116 return 117} 118 119func (b *BroadcastChannel) HandleRequest(c *gin.Context) { 120 if err := b.WebSocket.HandleRequestWithKeys( 121 c.Writer, 122 c.Request, 123 map[string]interface{}{ 124 "channel": b.Name, 125 }, 126 ); err != nil { 127 logging.LogErrorf("create broadcast channel failed: %s", err) 128 return 129 } 130} 131 132func (b *BroadcastChannel) Subscribed() bool { 133 return b.SubscriberCount() > 0 134} 135 136func (b *BroadcastChannel) Destroy(force bool) bool { 137 if force || !b.Subscribed() { 138 b.WebSocket.Close() 139 UnifiedSSE.SendEvent(&MessageEvent{ 140 Type: MessageTypeClose, 141 Name: b.Name, 142 }) 143 logging.LogInfof("destroy broadcast channel [%s]", b.Name) 144 return true 145 } 146 return false 147} 148 149type EventSourceSubscriber struct { 150 lock *sync.Mutex 151 count int 152} 153 154func (s *EventSourceSubscriber) updateCount(delta int) { 
155 s.lock.Lock() 156 defer s.lock.Unlock() 157 158 s.count += delta 159} 160 161func (s *EventSourceSubscriber) Count() int { 162 s.lock.Lock() 163 defer s.lock.Unlock() 164 165 return s.count 166} 167 168type EventSourceServer struct { 169 EventBus EventBus.Bus 170 WaitGroup *sync.WaitGroup 171 Subscriber *EventSourceSubscriber 172} 173 174// SendEvent sends a message to all subscribers 175func (s *EventSourceServer) SendEvent(event *MessageEvent) bool { 176 if event.ID == "" { 177 switch event.Type { 178 case MessageTypeClose: 179 default: 180 event.ID = strconv.FormatUint(messageID.Next(), 10) 181 } 182 } 183 184 s.EventBus.Publish(EvtBroadcastMessage, event) 185 return s.EventBus.HasCallback(EvtBroadcastMessage) 186} 187 188// Subscribe subscribes to specified broadcast channels 189func (s *EventSourceServer) Subscribe(c *gin.Context, retry uint, channels ...string) { 190 wg := sync.WaitGroup{} 191 wg.Add(len(channels)) 192 for _, channel := range channels { 193 go func() { 194 defer wg.Done() 195 196 var broadcastChannel *BroadcastChannel 197 _broadcastChannel, exist := BroadcastChannels.Load(channel) 198 if exist { // channel exists, use it 199 broadcastChannel = _broadcastChannel.(*BroadcastChannel) 200 } else { 201 broadcastChannel = ConstructBroadcastChannel(channel) 202 } 203 broadcastChannel.Subscriber.Count++ 204 }() 205 } 206 wg.Wait() 207 208 channelSet := make(map[string]bool) 209 for _, channel := range channels { 210 channelSet[channel] = true 211 } 212 213 c.Writer.Flush() 214 s.Stream(c, func(event *MessageEvent, ok bool) bool { 215 if ok { 216 if _, exists := channelSet[event.Name]; exists { 217 switch event.Type { 218 case MessageTypeClose: 219 return false 220 case MessageTypeString: 221 s.SSEvent(c, &sse.Event{ 222 Id: event.ID, 223 Event: event.Name, 224 Retry: retry, 225 Data: string(event.Data), 226 }) 227 default: 228 s.SSEvent(c, &sse.Event{ 229 Id: event.ID, 230 Event: event.Name, 231 Retry: retry, 232 Data: event.Data, 233 }) 234 } 
235 c.Writer.Flush() 236 return true 237 } 238 return true 239 } 240 return false 241 }) 242 243 wg.Add(len(channels)) 244 for _, channel := range channels { 245 go func() { 246 defer wg.Done() 247 _broadcastChannel, exist := BroadcastChannels.Load(channel) 248 if exist { 249 broadcastChannel := _broadcastChannel.(*BroadcastChannel) 250 broadcastChannel.Subscriber.Count-- 251 if !broadcastChannel.Subscribed() { 252 BroadcastChannels.Delete(channel) 253 broadcastChannel.Destroy(true) 254 } 255 } 256 }() 257 } 258 wg.Wait() 259} 260 261// SubscribeAll subscribes to all broadcast channels 262func (s *EventSourceServer) SubscribeAll(c *gin.Context, retry uint) { 263 s.Subscriber.updateCount(1) 264 265 c.Writer.Flush() 266 s.Stream(c, func(event *MessageEvent, ok bool) bool { 267 if ok { 268 switch event.Type { 269 case MessageTypeClose: 270 return true 271 case MessageTypeString: 272 s.SSEvent(c, &sse.Event{ 273 Id: event.ID, 274 Event: event.Name, 275 Retry: retry, 276 Data: string(event.Data), 277 }) 278 default: 279 s.SSEvent(c, &sse.Event{ 280 Id: event.ID, 281 Event: event.Name, 282 Retry: retry, 283 Data: event.Data, 284 }) 285 } 286 c.Writer.Flush() 287 return true 288 } 289 return false 290 }) 291 292 s.Subscriber.updateCount(-1) 293 PruneBroadcastChannels() 294} 295 296// GetRetry gets the retry interval 297// 298// If the retry interval is not specified, it will return 0 299func (s *EventSourceServer) GetRetry(c *gin.Context) uint { 300 value := c.DefaultQuery("retry", "") 301 retry, err := strconv.ParseUint(value, 10, 0) 302 if err == nil { 303 return uint(retry) 304 } 305 return 0 306} 307 308// Stream streams message to client 309// 310// If the client is gone, it will return true 311func (s *EventSourceServer) Stream(c *gin.Context, step func(event *MessageEvent, ok bool) bool) bool { 312 channel := make(MessageEventChannel) 313 defer close(channel) 314 315 subscriber := func(event *MessageEvent) { 316 channel <- event 317 } 318 
s.EventBus.Subscribe(EvtBroadcastMessage, subscriber) 319 defer s.EventBus.Unsubscribe(EvtBroadcastMessage, subscriber) 320 321 clientGone := c.Writer.CloseNotify() 322 for { 323 select { 324 case <-clientGone: 325 logging.LogInfof("event source connection is closed by client") 326 return true 327 case event, ok := <-channel: 328 if step(event, ok) { 329 continue 330 } 331 logging.LogInfof("event source connection is closed by server") 332 return false 333 } 334 } 335} 336 337// SSEvent writes a Server-Sent Event into the body stream. 338func (s *EventSourceServer) SSEvent(c *gin.Context, event *sse.Event) { 339 c.Render(-1, event) 340} 341 342// Subscribed checks whether the SSE server is subscribed 343func (s *EventSourceServer) Subscribed() bool { 344 return s.Subscriber.Count() > 0 345} 346 347type ChannelInfo struct { 348 Name string `json:"name"` 349 Count int `json:"count"` 350} 351 352type PublishMessage struct { 353 Type MessageType `json:"type"` // "string" | "binary" 354 Size int `json:"size"` // message size 355 Filename string `json:"filename"` // empty string for string-message 356} 357 358type PublishResult struct { 359 Code int `json:"code"` // 0: success 360 Msg string `json:"msg"` // error message 361 362 Channel ChannelInfo `json:"channel"` 363 Message PublishMessage `json:"message"` 364} 365 366// broadcast create a broadcast channel WebSocket connection 367// 368// @param 369// 370// { 371// channel: string, // channel name 372// } 373// 374// @example 375// 376// "ws://localhost:6806/ws/broadcast?channel=test" 377func broadcast(c *gin.Context) { 378 var ( 379 channel string = c.Query("channel") 380 broadcastChannel *BroadcastChannel 381 ) 382 383 _broadcastChannel, exist := BroadcastChannels.Load(channel) 384 if exist { // channel exists, use it 385 broadcastChannel = _broadcastChannel.(*BroadcastChannel) 386 if broadcastChannel.WebSocket.IsClosed() { // channel is closed 387 // delete channel before creating a new one 388 
DestroyBroadcastChannel(channel, true) 389 } else { // channel is open 390 // connect to the existing channel 391 broadcastChannel.HandleRequest(c) 392 return 393 } 394 } 395 396 // create a new channel 397 broadcastChannel = ConstructBroadcastChannel(channel) 398 broadcastChannel.HandleRequest(c) 399} 400 401// GetBroadcastChannel gets a broadcast channel 402// 403// If the channel does not exist but the SSE server is subscribed, it will create a new broadcast channel. 404// If the SSE server is not subscribed, it will return nil. 405func GetBroadcastChannel(channel string) *BroadcastChannel { 406 _broadcastChannel, exist := BroadcastChannels.Load(channel) 407 if exist { 408 return _broadcastChannel.(*BroadcastChannel) 409 } 410 if UnifiedSSE.Subscribed() { 411 return ConstructBroadcastChannel(channel) 412 } 413 return nil 414} 415 416// ConstructBroadcastChannel creates a broadcast channel 417func ConstructBroadcastChannel(channel string) *BroadcastChannel { 418 websocket := melody.New() 419 websocket.Config.MaxMessageSize = 1024 * 1024 * 128 // 128 MiB 420 421 // broadcast string message to other session 422 websocket.HandleMessage(func(s *melody.Session, msg []byte) { 423 UnifiedSSE.SendEvent(&MessageEvent{ 424 Type: MessageTypeString, 425 Name: channel, 426 Data: msg, 427 }) 428 websocket.BroadcastOthers(msg, s) 429 }) 430 431 // broadcast binary message to other session 432 websocket.HandleMessageBinary(func(s *melody.Session, msg []byte) { 433 UnifiedSSE.SendEvent(&MessageEvent{ 434 Type: MessageTypeBinary, 435 Name: channel, 436 Data: msg, 437 }) 438 websocket.BroadcastBinaryOthers(msg, s) 439 }) 440 441 // client close the connection 442 websocket.HandleClose(func(s *melody.Session, status int, reason string) error { 443 channel := s.Keys["channel"].(string) 444 logging.LogInfof("close broadcast session in channel [%s] with status code %d: %s", channel, status, reason) 445 446 DestroyBroadcastChannel(channel, false) 447 return nil 448 }) 449 450 var 
broadcastChannel *BroadcastChannel 451 for { 452 // Melody Initialization is an asynchronous process, so we need to wait for it to complete 453 if websocket.IsClosed() { 454 time.Sleep(1 * time.Nanosecond) 455 } else { 456 newBroadcastChannel := &BroadcastChannel{ 457 Name: channel, 458 WebSocket: websocket, 459 Subscriber: &BroadcastSubscriber{ 460 Count: 0, 461 }, 462 } 463 _broadcastChannel, loaded := BroadcastChannels.LoadOrStore(channel, newBroadcastChannel) 464 broadcastChannel = _broadcastChannel.(*BroadcastChannel) 465 if loaded { // channel exists 466 if broadcastChannel.WebSocket.IsClosed() { // channel is closed, replace it 467 BroadcastChannels.Store(channel, newBroadcastChannel) 468 broadcastChannel = newBroadcastChannel 469 } else { // channel is open, destroy the new one 470 newBroadcastChannel.Destroy(true) 471 } 472 } 473 break 474 } 475 } 476 return broadcastChannel 477} 478 479// DestroyBroadcastChannel tries to destroy a broadcast channel 480// 481// Return true if the channel destroy successfully, otherwise false 482func DestroyBroadcastChannel(channel string, force bool) bool { 483 _broadcastChannel, exist := BroadcastChannels.Load(channel) 484 if !exist { 485 return true 486 } 487 488 broadcastChannel := _broadcastChannel.(*BroadcastChannel) 489 if force || !broadcastChannel.Subscribed() { 490 BroadcastChannels.Delete(channel) 491 broadcastChannel.Destroy(true) 492 return true 493 } 494 495 return false 496} 497 498// PruneBroadcastChannels prunes all broadcast channels without subscribers 499func PruneBroadcastChannels() []string { 500 channels := []string{} 501 BroadcastChannels.Range(func(key, value any) bool { 502 channel := key.(string) 503 broadcastChannel := value.(*BroadcastChannel) 504 if !broadcastChannel.Subscribed() { 505 BroadcastChannels.Delete(channel) 506 broadcastChannel.Destroy(true) 507 channels = append(channels, channel) 508 } 509 return true 510 }) 511 return channels 512} 513 514// broadcastSubscribe subscribe to a 
broadcast channel by SSE 515// 516// If the channel-name does not specified, the client will subscribe to all broadcast channels. 517// 518// @param 519// 520// { 521// retry: string, // retry interval (ms) (optional) 522// channel: string, // channel name (optional, multiple) 523// } 524// 525// @example 526// 527// "http://localhost:6806/es/broadcast/subscribe?retry=1000&channel=test1&channel=test2" 528func broadcastSubscribe(c *gin.Context) { 529 // REF: https://github.com/gin-gonic/examples/blob/master/server-sent-event/main.go 530 c.Writer.Header().Set("Content-Type", "text/event-stream") 531 c.Writer.Header().Set("Cache-Control", "no-cache") 532 c.Writer.Header().Set("Connection", "keep-alive") 533 c.Writer.Header().Set("Transfer-Encoding", "chunked") 534 535 defer UnifiedSSE.WaitGroup.Done() 536 UnifiedSSE.WaitGroup.Add(1) 537 538 retry := UnifiedSSE.GetRetry(c) 539 channels, ok := c.GetQueryArray("channel") 540 if ok { // subscribe specified broadcast channels 541 UnifiedSSE.Subscribe(c, retry, channels...) 
542 } else { // subscribe all broadcast channels 543 UnifiedSSE.SubscribeAll(c, retry) 544 } 545} 546 547// broadcastPublish push multiple binary messages to multiple broadcast channels 548// 549// @param 550// 551// MultipartForm: [name] -> [values] 552// - name: string // channel name 553// - values: 554// - string[] // string-messages to the same channel 555// - File[] // binary-messages to the same channel 556// - filename: string // message key 557// 558// @returns 559// 560// { 561// code: int, 562// msg: string, 563// data: { 564// results: { 565// code: int, // 0: success 566// msg: string, // error message 567// channel: { 568// name: string, // channel name 569// count: string, // subscriber count 570// }, 571// message: { 572// type: string, // "string" | "binary" 573// size: int, // message size (Bytes) 574// filename: string, // empty string for string-message 575// }, 576// }[], 577// }, 578// } 579func broadcastPublish(c *gin.Context) { 580 ret := gulu.Ret.NewResult() 581 defer c.JSON(http.StatusOK, ret) 582 583 results := []*PublishResult{} 584 585 // Multipart form 586 form, err := c.MultipartForm() 587 if err != nil { 588 ret.Code = -2 589 ret.Msg = err.Error() 590 return 591 } 592 593 // Broadcast string messages 594 for name, values := range form.Value { 595 channel := ChannelInfo{ 596 Name: name, 597 Count: 0, 598 } 599 600 // Get broadcast channel 601 broadcastChannel := GetBroadcastChannel(channel.Name) 602 if broadcastChannel == nil { 603 channel.Count = 0 604 } else { 605 channel.Count = broadcastChannel.SubscriberCount() 606 } 607 608 // Broadcast each string message to the same channel 609 for _, value := range values { 610 result := &PublishResult{ 611 Code: 0, 612 Msg: "", 613 Channel: channel, 614 Message: PublishMessage{ 615 Type: MessageTypeString, 616 Size: len(value), 617 Filename: "", 618 }, 619 } 620 results = append(results, result) 621 622 if broadcastChannel != nil { 623 _, err := broadcastChannel.BroadcastString(value) 624 if 
err != nil { 625 logging.LogErrorf("broadcast message failed: %s", err) 626 result.Code = -2 627 result.Msg = err.Error() 628 continue 629 } 630 } 631 } 632 } 633 634 // Broadcast binary message 635 for name, files := range form.File { 636 channel := ChannelInfo{ 637 Name: name, 638 Count: 0, 639 } 640 641 // Get broadcast channel 642 broadcastChannel := GetBroadcastChannel(channel.Name) 643 if broadcastChannel == nil { 644 channel.Count = 0 645 } else { 646 channel.Count = broadcastChannel.SubscriberCount() 647 } 648 649 // Broadcast each binary message to the same channel 650 for _, file := range files { 651 result := &PublishResult{ 652 Code: 0, 653 Msg: "", 654 Channel: channel, 655 Message: PublishMessage{ 656 Type: MessageTypeBinary, 657 Size: int(file.Size), 658 Filename: file.Filename, 659 }, 660 } 661 results = append(results, result) 662 663 if broadcastChannel != nil { 664 value, err := file.Open() 665 if err != nil { 666 logging.LogErrorf("open multipart form file [%s] failed: %s", file.Filename, err) 667 result.Code = -4 668 result.Msg = err.Error() 669 continue 670 } 671 672 content := make([]byte, file.Size) 673 if _, err := value.Read(content); err != nil { 674 logging.LogErrorf("read multipart form file [%s] failed: %s", file.Filename, err) 675 result.Code = -3 676 result.Msg = err.Error() 677 continue 678 } 679 680 if _, err := broadcastChannel.BroadcastBinary(content); err != nil { 681 logging.LogErrorf("broadcast binary message failed: %s", err) 682 result.Code = -2 683 result.Msg = err.Error() 684 continue 685 } 686 } 687 } 688 } 689 690 ret.Data = map[string]interface{}{ 691 "results": results, 692 } 693} 694 695// postMessage send string message to a broadcast channel 696// 697// @param 698// 699// { 700// channel: string // channel name 701// message: string // message payload 702// } 703// 704// @returns 705// 706// { 707// code: int, 708// msg: string, 709// data: { 710// channel: { 711// name: string, //channel name 712// count: string, 
//listener count 713// }, 714// }, 715// } 716func postMessage(c *gin.Context) { 717 ret := gulu.Ret.NewResult() 718 defer c.JSON(http.StatusOK, ret) 719 720 arg, ok := util.JsonArg(c, ret) 721 if !ok { 722 return 723 } 724 725 message := arg["message"].(string) 726 channel := &ChannelInfo{ 727 Name: arg["channel"].(string), 728 Count: 0, 729 } 730 731 broadcastChannel := GetBroadcastChannel(channel.Name) 732 if broadcastChannel == nil { 733 channel.Count = 0 734 } else { 735 channel.Count = broadcastChannel.SubscriberCount() 736 if _, err := broadcastChannel.BroadcastString(message); err != nil { 737 logging.LogErrorf("broadcast message failed: %s", err) 738 739 ret.Code = -2 740 ret.Msg = err.Error() 741 return 742 } 743 } 744 ret.Data = map[string]interface{}{ 745 "channel": channel, 746 } 747} 748 749// getChannelInfo gets the information of a broadcast channel 750// 751// @param 752// 753// { 754// name: string, // channel name 755// } 756// 757// @returns 758// 759// { 760// code: int, 761// msg: string, 762// data: { 763// channel: { 764// name: string, //channel name 765// count: string, //listener count 766// }, 767// }, 768// } 769func getChannelInfo(c *gin.Context) { 770 ret := gulu.Ret.NewResult() 771 defer c.JSON(http.StatusOK, ret) 772 773 arg, ok := util.JsonArg(c, ret) 774 if !ok { 775 return 776 } 777 778 channel := &ChannelInfo{ 779 Name: arg["name"].(string), 780 Count: 0, 781 } 782 783 if _broadcastChannel, ok := BroadcastChannels.Load(channel.Name); !ok { 784 channel.Count = 0 785 } else { 786 var broadcastChannel = _broadcastChannel.(*BroadcastChannel) 787 channel.Count = broadcastChannel.SubscriberCount() 788 } 789 790 ret.Data = map[string]interface{}{ 791 "channel": channel, 792 } 793} 794 795// getChannels gets the channel name and lintener number of all broadcast chanel 796// 797// @returns 798// 799// { 800// code: int, 801// msg: string, 802// data: { 803// channels: { 804// name: string, //channel name 805// count: string, //listener 
count 806// }[], 807// }, 808// } 809func getChannels(c *gin.Context) { 810 ret := gulu.Ret.NewResult() 811 defer c.JSON(http.StatusOK, ret) 812 813 channels := []*ChannelInfo{} 814 BroadcastChannels.Range(func(key, value any) bool { 815 broadcastChannel := value.(*BroadcastChannel) 816 channels = append(channels, &ChannelInfo{ 817 Name: key.(string), 818 Count: broadcastChannel.SubscriberCount(), 819 }) 820 return true 821 }) 822 ret.Data = map[string]interface{}{ 823 "channels": channels, 824 } 825}