An experimental pub/sub client and server project.
1package server
2
3import (
4 "encoding/binary"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "net"
11 "strings"
12 "sync"
13 "syscall"
14 "time"
15
16 "github.com/willdot/messagebroker/internal"
17)
18
// Action identifies the kind of request a peer sends to the server. It travels
// on the wire as a big-endian uint16 and its values start at 1.
type Action uint16

const (
	Subscribe   Action = iota + 1 // subscribe to one or more topics
	Unsubscribe                   // unsubscribe from one or more topics
	Publish                       // publish messages to a topic
	Ack
	Nack
)
29
// Status is the result code the server reports back to a peer for a request.
type Status uint16

const (
	Subscribed   Status = iota + 1 // the subscribe request succeeded
	Unsubscribed                   // the unsubscribe request succeeded
	Error                          // the request failed
)

// String returns the lower-case name of a known status, or "" for any value
// that is not one of the declared constants.
func (s Status) String() string {
	names := [...]string{
		Subscribed:   "subscribed",
		Unsubscribed: "unsubscribed",
		Error:        "error",
	}
	if int(s) < len(names) {
		return names[s]
	}
	return ""
}
51
// StartAtType represents where the subscriber wishes to start subscribing to a
// topic from. It is read from the subscriber as a big-endian uint16.
type StartAtType uint16

const (
	Beginning StartAtType = iota // subscription starts at position 0
	Current                      // subscription starts at -1 (presumably only new messages — confirm)
	From                         // a big-endian uint16 start position follows
)
60
// Server accepts subscribe and publish connections and passes messages around
// between publishers and the subscribers of each topic.
type Server struct {
	// Addr is exported for callers to read. NOTE(review): presumably the
	// address the server listens on — confirm where it is populated.
	Addr string
	// lis is the TCP listener that start accepts connections from and that
	// Shutdown closes.
	lis net.Listener

	// mu protects topics; lock it before reading or writing the map.
	mu sync.Mutex
	// topics maps a topic name to its topic.
	topics map[string]*topic

	// ackDelay and ackTimeout are handed to every subscriber the server creates.
	ackDelay   time.Duration
	ackTimeout time.Duration
}
72
73// New creates and starts a new server
74func New(Addr string, ackDelay, ackTimeout time.Duration) (*Server, error) {
75 lis, err := net.Listen("tcp", Addr)
76 if err != nil {
77 return nil, fmt.Errorf("failed to listen: %w", err)
78 }
79
80 srv := &Server{
81 lis: lis,
82 topics: map[string]*topic{},
83 ackDelay: ackDelay,
84 ackTimeout: ackTimeout,
85 }
86
87 go srv.start()
88
89 return srv, nil
90}
91
92// Shutdown will cleanly shutdown the server
93func (s *Server) Shutdown() error {
94 return s.lis.Close()
95}
96
97func (s *Server) start() {
98 for {
99 conn, err := s.lis.Accept()
100 if err != nil {
101 if errors.Is(err, net.ErrClosed) {
102 slog.Info("listener closed")
103 return
104 }
105 slog.Error("listener failed to accept", "error", err)
106 continue
107 }
108
109 go s.handleConn(conn)
110 }
111}
112
113func (s *Server) handleConn(conn net.Conn) {
114 peer := NewPeer(conn)
115
116 slog.Info("handling connection", "peer", peer.Addr())
117 defer slog.Info("ending connection", "peer", peer.Addr())
118
119 action, err := readAction(peer, 0)
120 if err != nil {
121 if !errors.Is(err, io.EOF) {
122 slog.Error("failed to read action from peer", "error", err, "peer", peer.Addr())
123 }
124 return
125 }
126
127 switch action {
128 case Subscribe:
129 s.handleSubscribe(peer)
130 case Unsubscribe:
131 s.handleUnsubscribe(peer)
132 case Publish:
133 s.handlePublish(peer)
134 default:
135 slog.Error("unknown action", "action", action, "peer", peer.Addr())
136 writeInvalidAction(peer)
137 }
138}
139
140func (s *Server) handleSubscribe(peer *Peer) {
141 slog.Info("handling subscriber", "peer", peer.Addr())
142 // subscribe the peer to the topic
143 s.subscribePeerToTopic(peer)
144
145 s.waitForPeerAction(peer)
146}
147
148func (s *Server) waitForPeerAction(peer *Peer) {
149 // keep handling the peers connection, getting the action from the peer when it wishes to do something else.
150 // once the peers connection ends, it will be unsubscribed from all topics and returned
151 for {
152 action, err := readAction(peer, time.Millisecond*100)
153 if err != nil {
154 // if the error is a timeout, it means the peer hasn't sent an action indicating it wishes to do something so sleep
155 // for a little bit to allow for other actions to happen on the connection
156 var neterr net.Error
157 if errors.As(err, &neterr) && neterr.Timeout() {
158 time.Sleep(time.Millisecond * 500)
159 continue
160 }
161
162 if !errors.Is(err, io.EOF) {
163 slog.Error("failed to read action from subscriber", "error", err, "peer", peer.Addr())
164 }
165
166 s.unsubscribePeerFromAllTopics(peer)
167
168 return
169 }
170
171 switch action {
172 case Subscribe:
173 s.subscribePeerToTopic(peer)
174 case Unsubscribe:
175 s.handleUnsubscribe(peer)
176 default:
177 slog.Error("unknown action for subscriber", "action", action, "peer", peer.Addr())
178 writeInvalidAction(peer)
179 continue
180 }
181 }
182}
183
184func (s *Server) subscribePeerToTopic(peer *Peer) {
185 op := func(conn net.Conn) error {
186 // get the topics the peer wishes to subscribe to
187 dataLen, err := dataLengthUint32(conn)
188 if err != nil {
189 slog.Error(err.Error(), "peer", peer.Addr())
190 writeStatus(Error, "invalid data length of topics provided", conn)
191 return nil
192 }
193 if dataLen == 0 {
194 writeStatus(Error, "data length of topics is 0", conn)
195 return nil
196 }
197
198 buf := make([]byte, dataLen)
199 _, err = conn.Read(buf)
200 if err != nil {
201 slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr())
202 writeStatus(Error, "failed to read topic data", conn)
203 return nil
204 }
205
206 var topics []string
207 err = json.Unmarshal(buf, &topics)
208 if err != nil {
209 slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr())
210 writeStatus(Error, "invalid topic data provided", conn)
211 return nil
212 }
213
214 var startAtType StartAtType
215 err = binary.Read(conn, binary.BigEndian, &startAtType)
216 if err != nil {
217 slog.Error(err.Error(), "peer", peer.Addr())
218 writeStatus(Error, "invalid start at type provided", conn)
219 return nil
220 }
221 var startAt int
222 switch startAtType {
223 case From:
224 var s uint16
225 err = binary.Read(conn, binary.BigEndian, &s)
226 if err != nil {
227 slog.Error(err.Error(), "peer", peer.Addr())
228 writeStatus(Error, "invalid start at value provided", conn)
229 return nil
230 }
231 startAt = int(s)
232 case Beginning:
233 startAt = 0
234 case Current:
235 startAt = -1
236 default:
237 slog.Error("invalid start up type provided", "start up type", startAtType)
238 writeStatus(Error, "invalid start up type provided", conn)
239 return nil
240 }
241
242 s.subscribeToTopics(peer, topics, startAt)
243 writeStatus(Subscribed, "", conn)
244
245 return nil
246 }
247
248 _ = peer.RunConnOperation(op)
249}
250
251func (s *Server) handleUnsubscribe(peer *Peer) {
252 slog.Info("handling unsubscriber", "peer", peer.Addr())
253 op := func(conn net.Conn) error {
254 // get the topics the peer wishes to unsubscribe from
255 dataLen, err := dataLengthUint32(conn)
256 if err != nil {
257 slog.Error(err.Error(), "peer", peer.Addr())
258 writeStatus(Error, "invalid data length of topics provided", conn)
259 return nil
260 }
261 if dataLen == 0 {
262 writeStatus(Error, "data length of topics is 0", conn)
263 return nil
264 }
265
266 buf := make([]byte, dataLen)
267 _, err = conn.Read(buf)
268 if err != nil {
269 slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr())
270 writeStatus(Error, "failed to read topic data", conn)
271 return nil
272 }
273
274 var topics []string
275 err = json.Unmarshal(buf, &topics)
276 if err != nil {
277 slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr())
278 writeStatus(Error, "invalid topic data provided", conn)
279 return nil
280 }
281
282 s.unsubscribeToTopics(peer, topics)
283 writeStatus(Unsubscribed, "", conn)
284
285 return nil
286 }
287
288 _ = peer.RunConnOperation(op)
289}
290
291func (s *Server) handlePublish(peer *Peer) {
292 slog.Info("handling publisher", "peer", peer.Addr())
293 for {
294 op := func(conn net.Conn) error {
295 topicDataLen, err := dataLengthUint16(conn)
296 if err != nil {
297 if errors.Is(err, io.EOF) {
298 return nil
299 }
300 slog.Error("failed to read data length", "error", err, "peer", peer.Addr())
301 writeStatus(Error, "invalid data length of data provided", conn)
302 return nil
303 }
304 if topicDataLen == 0 {
305 return nil
306 }
307 topicBuf := make([]byte, topicDataLen)
308 _, err = conn.Read(topicBuf)
309 if err != nil {
310 slog.Error("failed to read topic from peer", "error", err, "peer", peer.Addr())
311 writeStatus(Error, "failed to read topic", conn)
312 return nil
313 }
314
315 topicStr := string(topicBuf)
316 if !strings.HasPrefix(topicStr, "topic:") {
317 slog.Error("topic data does not contain topic prefix", "peer", peer.Addr())
318 writeStatus(Error, "topic data does not contain 'topic:' prefix", conn)
319 return nil
320 }
321 topicStr = strings.TrimPrefix(topicStr, "topic:")
322
323 msgDataLen, err := dataLengthUint32(conn)
324 if err != nil {
325 slog.Error(err.Error(), "peer", peer.Addr())
326 writeStatus(Error, "invalid data length of data provided", conn)
327 return nil
328 }
329 if msgDataLen == 0 {
330 return nil
331 }
332
333 dataBuf := make([]byte, msgDataLen)
334 _, err = conn.Read(dataBuf)
335 if err != nil {
336 slog.Error("failed to read data from peer", "error", err, "peer", peer.Addr())
337 writeStatus(Error, "failed to read data", conn)
338 return nil
339 }
340
341 topic := s.getTopic(topicStr)
342 if topic == nil {
343 topic = newTopic(topicStr)
344 s.topics[topicStr] = topic
345 }
346
347 message := internal.NewMessage(dataBuf)
348
349 err = topic.sendMessageToSubscribers(message)
350 if err != nil {
351 slog.Error("failed to send message to subscribers", "error", err, "peer", peer.Addr())
352 writeStatus(Error, "failed to send message to subscribers", conn)
353 return nil
354 }
355
356 return nil
357 }
358
359 _ = peer.RunConnOperation(op)
360 }
361}
362
363func (s *Server) subscribeToTopics(peer *Peer, topics []string, startAt int) {
364 slog.Info("subscribing peer to topics", "topics", topics, "peer", peer.Addr())
365 for _, topic := range topics {
366 s.addSubsciberToTopic(topic, peer, startAt)
367 }
368}
369
370func (s *Server) addSubsciberToTopic(topicName string, peer *Peer, startAt int) {
371 s.mu.Lock()
372 defer s.mu.Unlock()
373
374 t, ok := s.topics[topicName]
375 if !ok {
376 t = newTopic(topicName)
377 }
378
379 t.mu.Lock()
380 t.subscriptions[peer.Addr()] = newSubscriber(peer, t, s.ackDelay, s.ackTimeout, startAt)
381 t.mu.Unlock()
382
383 s.topics[topicName] = t
384}
385
386func (s *Server) unsubscribeToTopics(peer *Peer, topics []string) {
387 slog.Info("unsubscribing peer from topics", "topics", topics, "peer", peer.Addr())
388 for _, topic := range topics {
389 s.removeSubsciberFromTopic(topic, peer)
390 }
391}
392
393func (s *Server) removeSubsciberFromTopic(topicName string, peer *Peer) {
394 s.mu.Lock()
395 defer s.mu.Unlock()
396
397 t, ok := s.topics[topicName]
398 if !ok {
399 return
400 }
401
402 sub := t.findSubscription(peer.Addr())
403 if sub == nil {
404 return
405 }
406
407 sub.unsubscribe()
408 t.removeSubscription(peer.Addr())
409}
410
411func (s *Server) unsubscribePeerFromAllTopics(peer *Peer) {
412 s.mu.Lock()
413 defer s.mu.Unlock()
414
415 for _, t := range s.topics {
416 sub := t.findSubscription(peer.Addr())
417 if sub == nil {
418 return
419 }
420
421 sub.unsubscribe()
422 t.removeSubscription(peer.Addr())
423 }
424}
425
426func (s *Server) getTopic(topicName string) *topic {
427 s.mu.Lock()
428 defer s.mu.Unlock()
429
430 if topic, ok := s.topics[topicName]; ok {
431 return topic
432 }
433
434 return nil
435}
436
437func readAction(peer *Peer, timeout time.Duration) (Action, error) {
438 var action Action
439 op := func(conn net.Conn) error {
440 if timeout > 0 {
441 if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
442 slog.Error("failed to set connection read deadline", "error", err, "peer", peer.Addr())
443 }
444 defer func() {
445 if err := conn.SetReadDeadline(time.Time{}); err != nil {
446 slog.Error("failed to reset connection read deadline", "error", err, "peer", peer.Addr())
447 }
448 }()
449 }
450
451 err := binary.Read(conn, binary.BigEndian, &action)
452 if err != nil {
453 return err
454 }
455 return nil
456 }
457
458 err := peer.RunConnOperation(op)
459 if err != nil {
460 return 0, fmt.Errorf("failed to read action from peer: %w", err)
461 }
462
463 return action, nil
464}
465
466func writeInvalidAction(peer *Peer) {
467 op := func(conn net.Conn) error {
468 writeStatus(Error, "unknown action", conn)
469 return nil
470 }
471
472 _ = peer.RunConnOperation(op)
473}
474
// dataLengthUint32 reads a 4-byte big-endian length prefix from conn.
//
// On failure it returns 0 and the read error:
//   - io.EOF if no bytes were available
//   - io.ErrUnexpectedEOF if only part of the prefix arrived
//   - the connection's own error otherwise
func dataLengthUint32(conn net.Conn) (uint32, error) {
	var raw [4]byte
	if _, err := io.ReadFull(conn, raw[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint32(raw[:]), nil
}
483
// dataLengthUint16 reads a 2-byte big-endian length prefix from conn.
//
// On failure it returns 0 and the read error:
//   - io.EOF if no bytes were available
//   - io.ErrUnexpectedEOF if only one byte arrived
//   - the connection's own error otherwise
func dataLengthUint16(conn net.Conn) (uint16, error) {
	var raw [2]byte
	if _, err := io.ReadFull(conn, raw[:]); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint16(raw[:]), nil
}
492
493func writeStatus(status Status, message string, conn net.Conn) {
494 statusB := make([]byte, 2)
495 binary.BigEndian.PutUint16(statusB, uint16(status))
496
497 headers := statusB
498
499 if len(message) > 0 {
500 sizeB := make([]byte, 2)
501 binary.BigEndian.PutUint16(sizeB, uint16(len(message)))
502 headers = append(headers, sizeB...)
503 }
504
505 msgBytes := []byte(message)
506 _, err := conn.Write(append(headers, msgBytes...))
507 if err != nil {
508 if !errors.Is(err, syscall.EPIPE) {
509 slog.Error("failed to write status to peers connection", "error", err, "peer", conn.RemoteAddr())
510 }
511 return
512 }
513}