An experimental pub/sub client and server project.

Merge pull request #3 from willdot/connLock

Make sure using connections is synchronous

authored by willdot.net and committed by

GitHub 71cb555e a4a90b1b

+664 -445
+2 -1
.gitignore
··· 1 - .DS_STORE 1 + .DS_STORE 2 + example/example
+20
dockerfile.example-server
··· 1 + FROM golang:latest as builder 2 + 3 + WORKDIR /app 4 + 5 + COPY go.mod go.sum ./ 6 + COPY example/server/ ./ 7 + RUN go mod download 8 + 9 + COPY . . 10 + 11 + RUN CGO_ENABLED=0 go build -o message-broker-server . 12 + 13 + FROM alpine:latest 14 + 15 + RUN apk --no-cache add ca-certificates 16 + 17 + WORKDIR /root/ 18 + COPY --from=builder /app/message-broker-server . 19 + 20 + CMD ["./message-broker-server"]
+12 -9
example/main.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "flag" 5 6 "fmt" 6 7 "log/slog" 8 + "time" 7 9 8 - "github.com/willdot/messagebroker" 9 10 "github.com/willdot/messagebroker/pubsub" 10 - "github.com/willdot/messagebroker/server" 11 11 ) 12 12 13 + var consumeOnly *bool 14 + 13 15 func main() { 14 - server, err := server.New(context.Background(), ":3000") 15 - if err != nil { 16 - panic(err) 16 + consumeOnly = flag.Bool("consume-only", false, "just consumes (doesn't start server and doesn't publish)") 17 + flag.Parse() 18 + 19 + if *consumeOnly == false { 20 + go sendMessages() 17 21 } 18 - defer server.Shutdown() 19 - 20 - go sendMessages() 21 22 22 23 sub, err := pubsub.NewSubscriber(":3000") 23 24 if err != nil { ··· 49 50 i := 0 50 51 for { 51 52 i++ 52 - msg := messagebroker.Message{ 53 + msg := pubsub.Message{ 53 54 Topic: "topic a", 54 55 Data: []byte(fmt.Sprintf("message %d", i)), 55 56 } ··· 59 60 slog.Error("failed to publish message", "error", err) 60 61 continue 61 62 } 63 + 64 + time.Sleep(time.Millisecond * 500) 62 65 } 63 66 }
+23
example/server/main.go
··· 1 + package main 2 + 3 + import ( 4 + "log" 5 + "os" 6 + "os/signal" 7 + "syscall" 8 + 9 + "github.com/willdot/messagebroker/server" 10 + ) 11 + 12 + func main() { 13 + srv, err := server.New(":3000") 14 + if err != nil { 15 + log.Fatal(err) 16 + } 17 + defer srv.Shutdown() 18 + 19 + signals := make(chan os.Signal, 1) 20 + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) 21 + 22 + <-signals 23 + }
+5 -1
go.mod
··· 2 2 3 3 go 1.21.0 4 4 5 - require github.com/stretchr/testify v1.8.4 5 + require ( 6 + github.com/docker/distribution v2.8.3+incompatible 7 + github.com/google/uuid v1.4.0 8 + github.com/stretchr/testify v1.8.4 9 + ) 6 10 7 11 require ( 8 12 github.com/davecgh/go-spew v1.1.1 // indirect
+4
go.sum
··· 1 1 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 2 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 + github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= 4 + github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= 5 + github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= 6 + github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 3 7 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 4 8 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 9 github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+1 -1
message.go pubsub/message.go
··· 1 - package messagebroker 1 + package pubsub 2 2 3 3 // Message represents a message that can be published or consumed 4 4 type Message struct {
+33 -16
pubsub/publisher.go
··· 2 2 3 3 import ( 4 4 "encoding/binary" 5 - "encoding/json" 6 5 "fmt" 7 6 "net" 7 + "sync" 8 8 9 - "github.com/willdot/messagebroker" 10 9 "github.com/willdot/messagebroker/server" 11 10 ) 12 11 13 12 // Publisher allows messages to be published to a server 14 13 type Publisher struct { 15 - conn net.Conn 14 + conn net.Conn 15 + connMu sync.Mutex 16 16 } 17 17 18 18 // NewPublisher connects to the server at the given address and registers as a publisher ··· 39 39 } 40 40 41 41 // Publish will publish the given message to the server 42 - func (p *Publisher) PublishMessage(message messagebroker.Message) error { 43 - b, err := json.Marshal(message) 44 - if err != nil { 45 - return fmt.Errorf("failed to marshal message: %w", err) 42 + func (p *Publisher) PublishMessage(message Message) error { 43 + op := func(conn net.Conn) error { 44 + // send topic first 45 + topic := fmt.Sprintf("topic:%s", message.Topic) 46 + err := binary.Write(p.conn, binary.BigEndian, uint32(len(topic))) 47 + if err != nil { 48 + return fmt.Errorf("failed to write topic size to server") 49 + } 50 + 51 + _, err = p.conn.Write([]byte(topic)) 52 + if err != nil { 53 + return fmt.Errorf("failed to write topic to server") 54 + } 55 + 56 + err = binary.Write(p.conn, binary.BigEndian, uint32(len(message.Data))) 57 + if err != nil { 58 + return fmt.Errorf("failed to write message size to server") 59 + } 60 + 61 + _, err = p.conn.Write(message.Data) 62 + if err != nil { 63 + return fmt.Errorf("failed to publish data to server") 64 + } 65 + return nil 46 66 } 47 67 48 - err = binary.Write(p.conn, binary.BigEndian, uint32(len(b))) 49 - if err != nil { 50 - return fmt.Errorf("failed to write message size to server") 51 - } 68 + return p.connOperation(op) 69 + } 52 70 53 - _, err = p.conn.Write(b) 54 - if err != nil { 55 - return fmt.Errorf("failed to publish data to server") 56 - } 71 + func (p *Publisher) connOperation(op connOpp) error { 72 + p.connMu.Lock() 73 + defer p.connMu.Unlock() 57 74 58 - return nil 75 + return op(p.conn) 59 76 }
+141 -97
pubsub/subscriber.go
··· 4 4 "context" 5 5 "encoding/binary" 6 6 "encoding/json" 7 + "errors" 7 8 "fmt" 8 - "log/slog" 9 9 "net" 10 + "sync" 10 11 "time" 11 12 12 - "github.com/willdot/messagebroker" 13 13 "github.com/willdot/messagebroker/server" 14 14 ) 15 15 16 + type connOpp func(conn net.Conn) error 17 + 16 18 // Subscriber allows subscriptions to a server and the consumption of messages 17 19 type Subscriber struct { 18 - conn net.Conn 20 + conn net.Conn 21 + connMu sync.Mutex 19 22 } 20 23 21 24 // NewSubscriber will connect to the server at the given address ··· 37 40 38 41 // SubscribeToTopics will subscribe to the provided topics 39 42 func (s *Subscriber) SubscribeToTopics(topicNames []string) error { 40 - err := binary.Write(s.conn, binary.BigEndian, server.Subscribe) 41 - if err != nil { 42 - return fmt.Errorf("failed to subscribe: %w", err) 43 - } 43 + op := func(conn net.Conn) error { 44 + err := binary.Write(conn, binary.BigEndian, server.Subscribe) 45 + if err != nil { 46 + return fmt.Errorf("failed to subscribe: %w", err) 47 + } 44 48 45 - b, err := json.Marshal(topicNames) 46 - if err != nil { 47 - return fmt.Errorf("failed to marshal topic names: %w", err) 48 - } 49 + b, err := json.Marshal(topicNames) 50 + if err != nil { 51 + return fmt.Errorf("failed to marshal topic names: %w", err) 52 + } 49 53 50 - err = binary.Write(s.conn, binary.BigEndian, uint32(len(b))) 51 - if err != nil { 52 - return fmt.Errorf("failed to write topic data length: %w", err) 53 - } 54 + err = binary.Write(conn, binary.BigEndian, uint32(len(b))) 55 + if err != nil { 56 + return fmt.Errorf("failed to write topic data length: %w", err) 57 + } 54 58 55 - _, err = s.conn.Write(b) 56 - if err != nil { 57 - return fmt.Errorf("failed to subscribe to topics: %w", err) 58 - } 59 + _, err = conn.Write(b) 60 + if err != nil { 61 + return fmt.Errorf("failed to subscribe to topics: %w", err) 62 + } 59 63 60 - var resp server.Status 61 - err = binary.Read(s.conn, binary.BigEndian, &resp) 62 - if err != nil { 63 - return fmt.Errorf("failed to read confirmation of subscription: %w", err) 64 - } 64 + var resp server.Status 65 + err = binary.Read(conn, binary.BigEndian, &resp) 66 + if err != nil { 67 + return fmt.Errorf("failed to read confirmation of subscription: %w", err) 68 + } 69 + 70 + if resp == server.Subscribed { 71 + return nil 72 + } 65 73 66 - if resp == server.Subscribed { 67 - return nil 68 - } 74 + var dataLen uint32 75 + err = binary.Read(conn, binary.BigEndian, &dataLen) 76 + if err != nil { 77 + return fmt.Errorf("received status %s:", resp) 78 + } 69 79 70 - var dataLen uint32 71 - err = binary.Read(s.conn, binary.BigEndian, &dataLen) 72 - if err != nil { 73 - return fmt.Errorf("received status %s:", resp) 74 - } 80 + buf := make([]byte, dataLen) 81 + _, err = conn.Read(buf) 82 + if err != nil { 83 + return fmt.Errorf("received status %s:", resp) 84 + } 75 85 76 - buf := make([]byte, dataLen) 77 - _, err = s.conn.Read(buf) 78 - if err != nil { 79 - return fmt.Errorf("received status %s:", resp) 86 + return fmt.Errorf("received status %s - %s", resp, buf) 80 87 } 81 88 82 - return fmt.Errorf("received status %s - %s", resp, buf) 89 + return s.connOperation(op) 83 90 } 84 91 85 92 // UnsubscribeToTopics will unsubscribe to the provided topics 86 93 func (s *Subscriber) UnsubscribeToTopics(topicNames []string) error { 87 - err := binary.Write(s.conn, binary.BigEndian, server.Unsubscribe) 88 - if err != nil { 89 - return fmt.Errorf("failed to unsubscribe: %w", err) 90 - } 94 + op := func(conn net.Conn) error { 95 + err := binary.Write(conn, binary.BigEndian, server.Unsubscribe) 96 + if err != nil { 97 + return fmt.Errorf("failed to unsubscribe: %w", err) 98 + } 91 99 92 - b, err := json.Marshal(topicNames) 93 - if err != nil { 94 - return fmt.Errorf("failed to marshal topic names: %w", err) 95 - } 100 + b, err := json.Marshal(topicNames) 101 + if err != nil { 102 + return fmt.Errorf("failed to marshal topic names: %w", err) 103 + } 96 104 97 - err = binary.Write(s.conn, binary.BigEndian, uint32(len(b))) 98 - if err != nil { 99 - return fmt.Errorf("failed to write topic data length: %w", err) 100 - } 105 + err = binary.Write(conn, binary.BigEndian, uint32(len(b))) 106 + if err != nil { 107 + return fmt.Errorf("failed to write topic data length: %w", err) 108 + } 101 109 102 - _, err = s.conn.Write(b) 103 - if err != nil { 104 - return fmt.Errorf("failed to unsubscribe to topics: %w", err) 105 - } 110 + _, err = conn.Write(b) 111 + if err != nil { 112 + return fmt.Errorf("failed to unsubscribe to topics: %w", err) 113 + } 106 114 107 - var resp server.Status 108 - err = binary.Read(s.conn, binary.BigEndian, &resp) 109 - if err != nil { 110 - return fmt.Errorf("failed to read confirmation of unsubscription: %w", err) 111 - } 115 + var resp server.Status 116 + err = binary.Read(conn, binary.BigEndian, &resp) 117 + if err != nil { 118 + return fmt.Errorf("failed to read confirmation of unsubscription: %w", err) 119 + } 112 120 113 - if resp == server.Unsubscribed { 114 - return nil 115 - } 121 + if resp == server.Unsubscribed { 122 + return nil 123 + } 116 124 117 - var dataLen uint32 118 - err = binary.Read(s.conn, binary.BigEndian, &dataLen) 119 - if err != nil { 120 - return fmt.Errorf("received status %s:", resp) 121 - } 125 + var dataLen uint32 126 + err = binary.Read(conn, binary.BigEndian, &dataLen) 127 + if err != nil { 128 + return fmt.Errorf("received status %s:", resp) 129 + } 122 130 123 - buf := make([]byte, dataLen) 124 - _, err = s.conn.Read(buf) 125 - if err != nil { 126 - return fmt.Errorf("received status %s:", resp) 131 + buf := make([]byte, dataLen) 132 + _, err = conn.Read(buf) 133 + if err != nil { 134 + return fmt.Errorf("received status %s:", resp) 135 + } 136 + 137 + return fmt.Errorf("received status %s - %s", resp, buf) 127 138 } 128 139 129 - return fmt.Errorf("received status %s - %s", resp, buf) 140 + return s.connOperation(op) 130 141 } 131 142 132 143 // Consumer allows the consumption of messages. If during the consumer receiving messages from the 133 144 // server an error occurs, it will be stored in Err 134 145 type Consumer struct { 135 - msgs chan messagebroker.Message 146 + msgs chan Message 136 147 // TODO: better error handling? Maybe a channel of errors? 137 148 Err error 138 149 } 139 150 140 151 // Messages returns a channel in which this consumer will put messages onto. It is safe to range over the channel since it will be closed once 141 152 // the consumer has finished either due to an error or from being cancelled. 142 - func (c *Consumer) Messages() <-chan messagebroker.Message { 153 + func (c *Consumer) Messages() <-chan Message { 143 154 return c.msgs 144 155 } 145 156 ··· 147 158 // to read the messages 148 159 func (s *Subscriber) Consume(ctx context.Context) *Consumer { 149 160 consumer := &Consumer{ 150 - msgs: make(chan messagebroker.Message), 161 + msgs: make(chan Message), 151 162 } 152 163 153 164 go s.consume(ctx, consumer) ··· 174 185 } 175 186 } 176 187 177 - func (s *Subscriber) readMessage() (*messagebroker.Message, error) { 178 - err := s.conn.SetReadDeadline(time.Now().Add(time.Second)) 179 - if err != nil { 180 - return nil, err 188 + func (s *Subscriber) readMessage() (*Message, error) { 189 + var msg *Message 190 + op := func(conn net.Conn) error { 191 + err := s.conn.SetReadDeadline(time.Now().Add(time.Second)) 192 + if err != nil { 193 + return err 194 + } 195 + 196 + var topicLen uint64 197 + err = binary.Read(s.conn, binary.BigEndian, &topicLen) 198 + if err != nil { 199 + // TODO: check if this is needed elsewhere. I'm not sure where the read deadline resets.... 200 + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 201 + return nil 202 + } 203 + return err 204 + } 205 + 206 + topicBuf := make([]byte, topicLen) 207 + _, err = s.conn.Read(topicBuf) 208 + if err != nil { 209 + return err 210 + } 211 + 212 + var dataLen uint64 213 + err = binary.Read(s.conn, binary.BigEndian, &dataLen) 214 + if err != nil { 215 + return err 216 + } 217 + 218 + if dataLen <= 0 { 219 + return nil 220 + } 221 + 222 + dataBuf := make([]byte, dataLen) 223 + _, err = s.conn.Read(dataBuf) 224 + if err != nil { 225 + return err 226 + } 227 + 228 + msg = &Message{ 229 + Data: dataBuf, 230 + Topic: string(topicBuf), 231 + } 232 + 233 + return nil 234 + 181 235 } 182 236 183 - var dataLen uint64 184 - err = binary.Read(s.conn, binary.BigEndian, &dataLen) 237 + err := s.connOperation(op) 185 238 if err != nil { 186 - if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 239 + var neterr net.Error 240 + if errors.As(err, &neterr) && neterr.Timeout() { 187 241 return nil, nil 188 242 } 189 243 return nil, err 190 244 } 191 245 192 - if dataLen <= 0 { 193 - return nil, nil 194 - } 246 + return msg, err 247 + } 195 248 196 - buf := make([]byte, dataLen) 197 - _, err = s.conn.Read(buf) 198 - if err != nil { 199 - return nil, err 200 - } 249 + func (s *Subscriber) connOperation(op connOpp) error { 250 + s.connMu.Lock() 251 + defer s.connMu.Unlock() 201 252 202 - var msg messagebroker.Message 203 - err = json.Unmarshal(buf, &msg) 204 - if err != nil { 205 - slog.Error("failed to unmarshal message", "error", err) 206 - return nil, nil 207 - } 208 - 209 - return &msg, nil 253 + return op(s.conn) 210 254 }
+19 -22
pubsub/subscriber_test.go
··· 8 8 9 9 "github.com/stretchr/testify/assert" 10 10 "github.com/stretchr/testify/require" 11 - "github.com/willdot/messagebroker" 12 11 13 12 "github.com/willdot/messagebroker/server" 14 13 ) 15 14 16 15 const ( 17 - serverAddr = ":3000" 16 + serverAddr = ":9999" 17 + topicA = "topic a" 18 + topicB = "topic b" 18 19 ) 19 20 20 21 func createServer(t *testing.T) { 21 - server, err := server.New(context.Background(), serverAddr) 22 + server, err := server.New(serverAddr) 22 23 require.NoError(t, err) 23 24 24 25 t.Cleanup(func() { ··· 72 73 sub.Close() 73 74 }) 74 75 75 - topics := []string{"topic a", "topic b"} 76 + topics := []string{topicA, topicB} 76 77 77 78 err = sub.SubscribeToTopics(topics) 78 79 require.NoError(t, err) ··· 88 89 sub.Close() 89 90 }) 90 91 91 - topics := []string{"topic a", "topic b"} 92 + topics := []string{topicA, topicB} 92 93 93 94 err = sub.SubscribeToTopics(topics) 94 95 require.NoError(t, err) 95 96 96 - err = sub.UnsubscribeToTopics([]string{"topic a"}) 97 + err = sub.UnsubscribeToTopics([]string{topicA}) 97 98 require.NoError(t, err) 98 99 99 100 ctx, cancel := context.WithCancel(context.Background()) ··· 104 105 consumer := sub.Consume(ctx) 105 106 require.NoError(t, err) 106 107 107 - var receivedMessages []messagebroker.Message 108 + var receivedMessages []Message 108 109 consumerFinCh := make(chan struct{}) 109 110 go func() { 110 111 for msg := range consumer.Messages() { ··· 118 119 // publish a message to both topics and check the subscriber only gets the message from the 1 topic 119 120 // and not the unsubscribed topic 120 121 121 - publisher, err := NewPublisher("localhost:3000") 122 + publisher, err := NewPublisher("localhost:9999") 122 123 require.NoError(t, err) 123 124 t.Cleanup(func() { 124 125 publisher.Close() 125 126 }) 126 127 127 - msg := messagebroker.Message{ 128 - Topic: "topic a", 128 + msg := Message{ 129 + Topic: topicA, 129 130 Data: []byte("hello world"), 130 131 } 131 132 132 133 err = publisher.PublishMessage(msg) 133 134 require.NoError(t, err) 134 135 135 - msg.Topic = "topic b" 136 + msg.Topic = topicB 136 137 err = publisher.PublishMessage(msg) 137 138 require.NoError(t, err) 138 139 139 140 cancel() 140 141 141 - // give the consumer some time to read the messages -- TODO: make better! 142 - time.Sleep(time.Millisecond * 500) 143 - cancel() 144 - 145 142 select { 146 143 case <-consumerFinCh: 147 144 break ··· 150 147 } 151 148 152 149 assert.Len(t, receivedMessages, 1) 153 - assert.Equal(t, "topic b", receivedMessages[0].Topic) 150 + assert.Equal(t, topicB, receivedMessages[0].Topic) 154 151 } 155 152 156 153 func TestPublishAndSubscribe(t *testing.T) { ··· 163 160 sub.Close() 164 161 }) 165 162 166 - topics := []string{"topic a", "topic b"} 163 + topics := []string{topicA, topicB} 167 164 168 165 err = sub.SubscribeToTopics(topics) 169 166 require.NoError(t, err) ··· 176 173 consumer := sub.Consume(ctx) 177 174 require.NoError(t, err) 178 175 179 - var receivedMessages []messagebroker.Message 176 + var receivedMessages []Message 180 177 181 178 consumerFinCh := make(chan struct{}) 182 179 go func() { ··· 188 185 consumerFinCh <- struct{}{} 189 186 }() 190 187 191 - publisher, err := NewPublisher("localhost:3000") 188 + publisher, err := NewPublisher("localhost:9999") 192 189 require.NoError(t, err) 193 190 t.Cleanup(func() { 194 191 publisher.Close() 195 192 }) 196 193 197 194 // send some messages 198 - sentMessages := make([]messagebroker.Message, 0, 10) 195 + sentMessages := make([]Message, 0, 10) 199 196 for i := 0; i < 10; i++ { 200 - msg := messagebroker.Message{ 201 - Topic: "topic a", 197 + msg := Message{ 198 + Topic: topicA, 202 199 Data: []byte(fmt.Sprintf("message %d", i)), 203 200 } 204 201
-99
server/peer.go
··· 1 - package server 2 - 3 - import ( 4 - "encoding/binary" 5 - "fmt" 6 - "log/slog" 7 - "net" 8 - ) 9 - 10 - type peer struct { 11 - conn net.Conn 12 - } 13 - 14 - func newPeer(conn net.Conn) peer { 15 - return peer{ 16 - conn: conn, 17 - } 18 - } 19 - 20 - // Read wraps the peers underlying connections Read function to satisfy io.Reader 21 - func (p *peer) Read(b []byte) (n int, err error) { 22 - return p.conn.Read(b) 23 - } 24 - 25 - // Write wraps the peers underlying connections Write function to satisfy io.Writer 26 - func (p *peer) Write(b []byte) (n int, err error) { 27 - return p.conn.Write(b) 28 - } 29 - 30 - func (p *peer) addr() net.Addr { 31 - return p.conn.LocalAddr() 32 - } 33 - 34 - func (p *peer) readAction() (Action, error) { 35 - var action Action 36 - err := binary.Read(p.conn, binary.BigEndian, &action) 37 - if err != nil { 38 - return 0, fmt.Errorf("failed to read action from peer: %w", err) 39 - } 40 - 41 - return action, nil 42 - } 43 - 44 - func (p *peer) readDataLength() (uint32, error) { 45 - var dataLen uint32 46 - err := binary.Read(p.conn, binary.BigEndian, &dataLen) 47 - if err != nil { 48 - return 0, fmt.Errorf("failed to read data length from peer: %w", err) 49 - } 50 - 51 - return dataLen, nil 52 - } 53 - 54 - // Status represents the status of a request 55 - type Status uint8 56 - 57 - const ( 58 - Subscribed = 1 59 - Unsubscribed = 2 60 - Error = 3 61 - ) 62 - 63 - func (s Status) String() string { 64 - switch s { 65 - case Subscribed: 66 - return "subsribed" 67 - case Unsubscribed: 68 - return "unsubscribed" 69 - case Error: 70 - return "error" 71 - } 72 - 73 - return "" 74 - } 75 - 76 - func (p *peer) writeStatus(status Status, message string) { 77 - err := binary.Write(p.conn, binary.BigEndian, status) 78 - if err != nil { 79 - slog.Error("failed to write status to peers connection", "error", err, "peer", p.addr()) 80 - return 81 - } 82 - 83 - if message == "" { 84 - return 85 - } 86 - 87 - msgBytes := []byte(message) 88 - err = binary.Write(p.conn, binary.BigEndian, uint32(len(msgBytes))) 89 - if err != nil { 90 - slog.Error("failed to write message length to peers connection", "error", err, "peer", p.addr()) 91 - return 92 - } 93 - 94 - _, err = p.conn.Write(msgBytes) 95 - if err != nil { 96 - slog.Error("failed to write message to peers connection", "error", err, "peer", p.addr()) 97 - return 98 - } 99 - }
+36
server/peer/peer.go
··· 1 + package peer 2 + 3 + import ( 4 + "net" 5 + "sync" 6 + ) 7 + 8 + // Peer represents a remote connection to the server such as a publisher or subscriber 9 + type Peer struct { 10 + conn net.Conn 11 + connMu sync.Mutex 12 + } 13 + 14 + // New returns a new peer 15 + func New(conn net.Conn) *Peer { 16 + return &Peer{ 17 + conn: conn, 18 + } 19 + } 20 + 21 + // Addr returns the peers connections address 22 + func (p *Peer) Addr() net.Addr { 23 + return p.conn.RemoteAddr() 24 + } 25 + 26 + // ConnOpp represents a set of actions on a connection that can be used synchrnously 27 + type ConnOpp func(conn net.Conn) error 28 + 29 + // RunConnOperation will run the provided operation. It ensures that it is the only operation that is being 30 + // run on the connection to ensure any other operations don't get mixed up. 31 + func (p *Peer) RunConnOperation(op ConnOpp) error { 32 + p.connMu.Lock() 33 + defer p.connMu.Unlock() 34 + 35 + return op(p.conn) 36 + }
+248 -101
server/server.go
··· 1 1 package server 2 2 3 3 import ( 4 - "context" 4 + "encoding/binary" 5 5 "encoding/json" 6 6 "errors" 7 7 "fmt" 8 8 "log/slog" 9 9 "net" 10 + "strings" 10 11 "sync" 12 + "time" 11 13 12 - "github.com/willdot/messagebroker" 14 + "github.com/willdot/messagebroker/server/peer" 13 15 ) 14 16 15 17 // Action represents the type of action that a peer requests to do ··· 21 23 Publish Action = 3 22 24 ) 23 25 26 + // Status represents the status of a request 27 + type Status uint8 28 + 29 + const ( 30 + Subscribed = 1 31 + Unsubscribed = 2 32 + Error = 3 33 + ) 34 + 35 + func (s Status) String() string { 36 + switch s { 37 + case Subscribed: 38 + return "subsribed" 39 + case Unsubscribed: 40 + return "unsubscribed" 41 + case Error: 42 + return "error" 43 + } 44 + 45 + return "" 46 + } 47 + 24 48 // Server accepts subscribe and publish connections and passes messages around 25 49 type Server struct { 26 - addr string 50 + Addr string 27 51 lis net.Listener 28 52 29 53 mu sync.Mutex ··· 31 55 } 32 56 33 57 // New creates and starts a new server 34 - func New(ctx context.Context, addr string) (*Server, error) { 35 - lis, err := net.Listen("tcp", addr) 58 + func New(Addr string) (*Server, error) { 59 + lis, err := net.Listen("tcp", Addr) 36 60 if err != nil { 37 61 return nil, fmt.Errorf("failed to listen: %w", err) 38 62 } ··· 42 66 topics: map[string]topic{}, 43 67 } 44 68 45 - go srv.start(ctx) 69 + go srv.start() 46 70 47 71 return srv, nil 48 72 } ··· 52 76 return s.lis.Close() 53 77 } 54 78 55 - func (s *Server) start(ctx context.Context) { 79 + func (s *Server) start() { 56 80 for { 57 81 conn, err := s.lis.Accept() 58 82 if err != nil { ··· 69 93 } 70 94 71 95 func (s *Server) handleConn(conn net.Conn) { 72 - peer := newPeer(conn) 73 - action, err := peer.readAction() 96 + peer := peer.New(conn) 97 + 98 + action, err := readAction(peer, 0) 74 99 if err != nil { 75 - slog.Error("failed to read action from peer", "error", err, "peer", peer.addr()) 100 + slog.Error("failed to read action from peer", "error", err, "peer", peer.Addr()) 76 101 return 77 102 } 78 103 ··· 84 109 case Publish: 85 110 s.handlePublish(peer) 86 111 default: 87 - slog.Error("unknown action", "action", action, "peer", peer.addr()) 88 - peer.writeStatus(Error, "unknown action") 112 + slog.Error("unknown action", "action", action, "peer", peer.Addr()) 113 + writeInvalidAction(peer) 89 114 } 90 115 } 91 116 92 - func (s *Server) handleSubscribe(peer peer) { 117 + func (s *Server) handleSubscribe(peer *peer.Peer) { 93 118 // subscribe the peer to the topic 94 119 s.subscribePeerToTopic(peer) 95 120 96 121 // keep handling the peers connection, getting the action from the peer when it wishes to do something else. 97 122 // once the peers connection ends, it will be unsubscribed from all topics and returned 98 123 for { 99 - action, err := peer.readAction() 124 + action, err := readAction(peer, time.Millisecond*100) 100 125 if err != nil { 126 + var neterr net.Error 127 + if errors.As(err, &neterr) && neterr.Timeout() { 128 + time.Sleep(time.Second) 129 + continue 130 + } 101 131 // TODO: see if there's a way to check if the peers connection has been ended etc 102 - slog.Error("failed to read action from subscriber", "error", err, "peer", peer.addr()) 132 + slog.Error("failed to read action from subscriber", "error", err, "peer", peer.Addr()) 103 133 104 - s.unsubscribePeerFromAllTopics(peer) 134 + s.unsubscribePeerFromAllTopics(*peer) 105 135 106 136 return 107 137 } ··· 112 142 case Unsubscribe: 113 143 s.handleUnsubscribe(peer) 114 144 default: 115 - slog.Error("unknown action for subscriber", "action", action, "peer", peer.addr()) 116 - peer.writeStatus(Error, "unknown action") 145 + slog.Error("unknown action for subscriber", "action", action, "peer", peer.Addr()) 146 + writeInvalidAction(peer) 117 147 continue 118 148 } 119 149 } 120 150 } 121 151 122 - func (s *Server) subscribePeerToTopic(peer peer) { 123 - // get the topics the peer wishes to subscribe to 124 - dataLen, err := peer.readDataLength() 125 - if err != nil { 126 - slog.Error(err.Error(), "peer", peer.addr()) 127 - peer.writeStatus(Error, "invalid data length of topics provided") 128 - return 129 - } 130 - if dataLen == 0 { 131 - peer.writeStatus(Error, "data length of topics is 0") 132 - return 133 - } 152 + func (s *Server) subscribePeerToTopic(peer *peer.Peer) { 153 + op := func(conn net.Conn) error { 154 + // get the topics the peer wishes to subscribe to 155 + dataLen, err := dataLength(conn) 156 + if err != nil { 157 + slog.Error(err.Error(), "peer", peer.Addr()) 158 + writeStatus(Error, "invalid data length of topics provided", conn) 159 + return nil 160 + } 161 + if dataLen == 0 { 162 + writeStatus(Error, "data length of topics is 0", conn) 163 + return nil 164 + } 134 165 135 - buf := make([]byte, dataLen) 136 - _, err = peer.Read(buf) 137 - if err != nil { 138 - slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.addr()) 139 - peer.writeStatus(Error, "failed to read topic data") 140 - return 141 - } 166 + buf := make([]byte, dataLen) 167 + _, err = conn.Read(buf) 168 + if err != nil { 169 + slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 170 + writeStatus(Error, "failed to read topic data", conn) 171 + return nil 172 + } 142 173 143 - var topics []string 144 - err = json.Unmarshal(buf, &topics) 145 - if err != nil { 146 - slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.addr()) 147 - peer.writeStatus(Error, "invalid topic data provided") 148 - return 149 - } 174 + var topics []string 175 + err = json.Unmarshal(buf, &topics) 176 + if err != nil { 177 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 178 + writeStatus(Error, "invalid topic data provided", conn) 179 + return nil 180 + } 150 181 151 - s.subscribeToTopics(peer, topics) 152 - peer.writeStatus(Subscribed, "") 153 - } 182 + s.subscribeToTopics(peer, topics) 183 + writeStatus(Subscribed, "", conn) 154 184 155 - func (s *Server) handleUnsubscribe(peer peer) { 156 - // get the topics the peer wishes to unsubscribe from 157 - dataLen, err := peer.readDataLength() 158 - if err != nil { 159 - slog.Error(err.Error(), "peer", peer.addr()) 160 - peer.writeStatus(Error, "invalid data length of topics provided") 161 - return 162 - } 163 - if dataLen == 0 { 164 - peer.writeStatus(Error, "data length of topics is 0") 165 - return 185 + return nil 166 186 } 167 187 168 - buf := make([]byte, dataLen) 169 - _, err = peer.Read(buf) 170 - if err != nil { 171 - slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.addr()) 172 - peer.writeStatus(Error, "failed to read topic data") 173 - return 174 - } 175 - 176 - var topics []string 177 - err = json.Unmarshal(buf, &topics) 178 - if err != nil { 179 - slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.addr()) 180 - peer.writeStatus(Error, "invalid topic data provided") 181 - return 182 - } 183 - 184 - s.unsubscribeToTopics(peer, topics) 185 - peer.writeStatus(Unsubscribed, "") 188 + _ = peer.RunConnOperation(op) 186 189 } 187 190 188 - func (s *Server) handlePublish(peer peer) { 189 - for { 190 - dataLen, err := peer.readDataLength() 191 + func (s *Server) handleUnsubscribe(peer *peer.Peer) { 192 + op := func(conn net.Conn) error { 193 + // get the topics the peer wishes to unsubscribe from 194 + dataLen, err := dataLength(conn) 191 195 if err != nil { 192 - slog.Error(err.Error(), "peer", peer.addr()) 193 - peer.writeStatus(Error, "invalid data length of data provided") 194 - return 196 + slog.Error(err.Error(), "peer", peer.Addr()) 197 + writeStatus(Error, "invalid data length of topics provided", conn) 198 + return nil 195 199 } 196 200 if dataLen == 0 { 197 - continue 201 + writeStatus(Error, "data length of topics is 0", conn) 202 + return nil 198 203 } 199 204 200 205 buf := make([]byte, dataLen) 201 - _, err = peer.Read(buf) 206 + _, err = conn.Read(buf) 202 207 if err != nil { 203 - slog.Error("failed to read data from peer", "error", err, "peer", peer.addr()) 204 - peer.writeStatus(Error, "failed to read data") 205 - return 208 + slog.Error("failed to read subscibers topic data", "error", err, "peer", peer.Addr()) 209 + writeStatus(Error, "failed to read topic data", conn) 210 + return nil 206 211 } 207 212 208 - var msg messagebroker.Message 209 - err = json.Unmarshal(buf, &msg) 213 + var topics []string 214 + err = json.Unmarshal(buf, &topics) 210 215 if err != nil { 211 - slog.Error("failed to unmarshal data to message", "error", err, "peer", peer.addr()) 212 - peer.writeStatus(Error, "invalid message") 216 + slog.Error("failed to unmarshal subscibers topic data", "error", err, "peer", peer.Addr()) 217 + writeStatus(Error, "invalid topic data provided", conn) 218 + return nil 219 + } 220 + 221 + s.unsubscribeToTopics(*peer, topics) 222 + writeStatus(Unsubscribed, "", conn) 223 + 224 + return nil 225 + } 226 + 227 + _ = peer.RunConnOperation(op) 228 + } 229 + 230 + type messageToSend struct { 231 + topic string 232 + data []byte 233 + } 234 + 235 + func (s *Server) handlePublish(peer *peer.Peer) { 236 + for { 237 + var message *messageToSend 238 + 239 + op := func(conn net.Conn) error { 240 + dataLen, err := dataLength(conn) 241 + if err != nil { 242 + slog.Error("failed to read data length", "error", err, "peer", peer.Addr()) 243 + writeStatus(Error, "invalid data length of data provided", conn) 244 + return nil 245 + } 246 + if dataLen == 0 { 247 + return nil 248 + } 249 + topicBuf := make([]byte, dataLen) 250 + _, err = conn.Read(topicBuf) 251 + if err != nil { 252 + slog.Error("failed to read topic from peer", "error", err, "peer", peer.Addr()) 253 + writeStatus(Error, "failed to read topic", conn) 254 + return nil 255 + } 256 + 257 + topicStr := string(topicBuf) 258 + if !strings.HasPrefix(topicStr, "topic:") { 259 + slog.Error("topic data does not contain topic prefix", "peer", peer.Addr()) 260 + writeStatus(Error, "topic data does not contain 'topic:' prefix", conn) 261 + return nil 262 + } 263 + topicStr = strings.TrimPrefix(topicStr, "topic:") 264 + 265 + dataLen, err = dataLength(conn) 266 + if err != nil { 267 + slog.Error(err.Error(), "peer", peer.Addr()) 268 + writeStatus(Error, "invalid data length of data provided", conn) 269 + return nil 270 + } 271 + if dataLen == 0 { 272 + return nil 273 + } 274 + 275 + dataBuf := make([]byte, dataLen) 276 + _, err = conn.Read(dataBuf) 277 + if err != nil { 278 + slog.Error("failed to read data from peer", "error", err, "peer", peer.Addr()) 279 + writeStatus(Error, "failed to read data", conn) 280 + return nil 281 + } 282 + 283 + message = &messageToSend{ 284 + topic: topicStr, 285 + data: dataBuf, 286 + } 287 + return nil 288 + } 289 + 290 + _ = peer.RunConnOperation(op) 291 + 292 + if message == nil { 213 293 continue 214 294 } 295 + // TODO: this can be done in a go routine because once we've got the message from the publisher, the publisher 296 + // doesn't need to wait for us to send the message to all peers 215 297 216 - topic := s.getTopic(msg.Topic) 298 + topic := s.getTopic(message.topic) 217 299 if topic != nil { 218 - topic.sendMessageToSubscribers(msg) 300 + topic.sendMessageToSubscribers(message.data) 219 301 } 220 302 } 221 303 } 222 304 223 - func (s *Server) subscribeToTopics(peer peer, topics []string) { 305 + func (s *Server) subscribeToTopics(peer *peer.Peer, topics []string) { 224 306 for _, topic := range topics { 225 307 s.addSubsciberToTopic(topic, peer) 226 308 } 227 309 } 228 310 229 - func (s *Server) addSubsciberToTopic(topicName string, peer peer) { 311 + func (s *Server) addSubsciberToTopic(topicName string, peer *peer.Peer) { 230 312 s.mu.Lock() 231 313 defer s.mu.Unlock() 232 314 ··· 235 317 t = newTopic(topicName) 236 318 } 237 319 238 - t.subscriptions[peer.addr()] = subscriber{ 320 + t.subscriptions[peer.Addr()] = subscriber{ 239 321 peer: peer, 240 322 currentOffset: 0, 241 323 } ··· 243 325 s.topics[topicName] = t 244 326 } 245 327 246 - func (s *Server) unsubscribeToTopics(peer peer, topics []string) { 328 + func (s *Server) unsubscribeToTopics(peer peer.Peer, topics []string) { 247 329 for _, topic := range topics { 248 330 s.removeSubsciberFromTopic(topic, peer) 249 331 } 250 332 } 251 333 252 - func (s *Server) removeSubsciberFromTopic(topicName string, peer peer) { 334 + func (s *Server) removeSubsciberFromTopic(topicName string, peer peer.Peer) { 253 335 s.mu.Lock() 254 336 defer s.mu.Unlock() 255 337 ··· 258 340 return 259 341 } 260 342 261 - delete(t.subscriptions, peer.addr()) 343 + delete(t.subscriptions, peer.Addr()) 262 344 } 263 345 264 - func (s *Server) unsubscribePeerFromAllTopics(peer peer) { 346 + func (s *Server) unsubscribePeerFromAllTopics(peer peer.Peer) { 265 347 s.mu.Lock() 266 348 defer s.mu.Unlock() 267 349 268 350 for _, topic := range s.topics { 269 - delete(topic.subscriptions, peer.addr()) 351 + delete(topic.subscriptions, peer.Addr()) 270 352 } 271 353 } 272 354 ··· 280 362 281 363 return nil 282 364 } 365 + 366 + func readAction(peer *peer.Peer, timeout time.Duration) (Action, error) { 367 + var action Action 368 + op := func(conn net.Conn) error { 369 + if timeout > 0 { 370 + conn.SetReadDeadline(time.Now().Add(timeout)) 371 + } 372 + 373 + err := binary.Read(conn, binary.BigEndian, &action) 374 + if err != nil { 375 + return err 376 + } 377 + return nil 378 + } 379 + 380 + err := peer.RunConnOperation(op) 381 + if err != nil { 382 + return 0, fmt.Errorf("failed to read action from peer: %w", err) 383 + } 384 + 385 + return action, nil 386 + } 387 + 388 + func writeInvalidAction(peer *peer.Peer) { 389 + op := func(conn net.Conn) error { 390 + writeStatus(Error, "unknown action", conn) 391 + return nil 392 + } 393 + 394 + _ = peer.RunConnOperation(op) 395 + } 396 + 397 + func dataLength(conn net.Conn) (uint32, error) { 398 + var dataLen uint32 399 + err := binary.Read(conn, binary.BigEndian, &dataLen) 400 + if err != nil { 401 + return 0, err 402 + } 403 + return dataLen, nil 404 + } 405 + 406 + func writeStatus(status Status, message string, conn net.Conn) { 407 + err := binary.Write(conn, binary.BigEndian, status) 408 + if err != nil { 409 + slog.Error("failed to write status to peers connection", "error", err, "peer", conn.RemoteAddr()) 410 + return 411 + } 412 + 413 + if message == "" { 414 + return 415 + } 416 + 417 + msgBytes := []byte(message) 418 + err = binary.Write(conn, binary.BigEndian, uint32(len(msgBytes))) 419 + if err != nil { 420 + slog.Error("failed to write message length to peers connection", "error", err, "peer", conn.RemoteAddr()) 421 + return 422 + } 423 + 424 + _, err = conn.Write(msgBytes) 425 + if err != nil { 426 + slog.Error("failed to write message to peers connection", "error", err, "peer", conn.RemoteAddr()) 427 + return 428 + } 429 + }
+82 -62
server/server_test.go
··· 1 1 package server 2 2 3 3 import ( 4 - "context" 5 4 "encoding/binary" 6 5 "encoding/json" 7 6 "fmt" ··· 11 10 12 11 "github.com/stretchr/testify/assert" 13 12 "github.com/stretchr/testify/require" 14 - "github.com/willdot/messagebroker" 13 + ) 14 + 15 + const ( 16 + topicA = "topic a" 17 + topicB = "topic b" 18 + topicC = "topic c" 19 + 20 + serverAddr = ":6666" 15 21 ) 16 22 17 23 func createServer(t *testing.T) *Server { 18 - srv, err := New(context.Background(), ":3000") 24 + srv, err := New(serverAddr) 19 25 require.NoError(t, err) 20 26 21 27 t.Cleanup(func() { ··· 36 42 } 37 43 38 44 func createConnectionAndSubscribe(t *testing.T, topics []string) net.Conn { 39 - conn, err := net.Dial("tcp", "localhost:3000") 45 + conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 40 46 require.NoError(t, err) 41 47 42 48 err = binary.Write(conn, binary.BigEndian, Subscribe) ··· 64 70 func TestSubscribeToTopics(t *testing.T) { 65 71 // create a server with an existing topic so we can test subscribing to a new and 66 72 // existing topic 67 - srv := createServerWithExistingTopic(t, "topic a") 73 + srv := createServerWithExistingTopic(t, topicA) 68 74 69 - _ = createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 75 + _ = createConnectionAndSubscribe(t, []string{topicA, topicB}) 70 76 71 77 assert.Len(t, srv.topics, 2) 72 - assert.Len(t, srv.topics["topic a"].subscriptions, 1) 73 - assert.Len(t, srv.topics["topic b"].subscriptions, 1) 78 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 79 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 74 80 } 75 81 76 82 func TestUnsubscribesFromTopic(t *testing.T) { 77 - srv := createServerWithExistingTopic(t, "topic a") 83 + srv := createServerWithExistingTopic(t, topicA) 78 84 79 - conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b", "topic c"}) 85 + conn := createConnectionAndSubscribe(t, []string{topicA, topicB, topicC}) 80 86 81 87 assert.Len(t, srv.topics, 3) 82 - assert.Len(t, srv.topics["topic a"].subscriptions, 1) 83 - assert.Len(t, srv.topics["topic b"].subscriptions, 1) 84 - assert.Len(t, srv.topics["topic c"].subscriptions, 1) 88 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 89 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 90 + assert.Len(t, srv.topics[topicC].subscriptions, 1) 85 91 86 92 err := binary.Write(conn, binary.BigEndian, Unsubscribe) 87 93 require.NoError(t, err) 88 94 89 - topics := []string{"topic a", "topic b"} 95 + topics := []string{topicA, topicB} 90 96 rawTopics, err := json.Marshal(topics) 91 97 require.NoError(t, err) 92 98 ··· 104 110 assert.Equal(t, expectedRes, int(resp)) 105 111 106 112 assert.Len(t, srv.topics, 3) 107 - assert.Len(t, srv.topics["topic a"].subscriptions, 0) 108 - assert.Len(t, srv.topics["topic b"].subscriptions, 0) 109 - assert.Len(t, srv.topics["topic c"].subscriptions, 1) 113 + assert.Len(t, srv.topics[topicA].subscriptions, 0) 114 + assert.Len(t, srv.topics[topicB].subscriptions, 0) 115 + assert.Len(t, srv.topics[topicC].subscriptions, 1) 110 116 } 111 117 112 118 func TestSubscriberClosesWithoutUnsubscribing(t *testing.T) { 113 119 srv := createServer(t) 114 120 115 - conn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 121 + conn := createConnectionAndSubscribe(t, []string{topicA, topicB}) 116 122 117 123 assert.Len(t, srv.topics, 2) 118 - assert.Len(t, srv.topics["topic a"].subscriptions, 1) 119 - assert.Len(t, srv.topics["topic b"].subscriptions, 1) 124 + assert.Len(t, srv.topics[topicA].subscriptions, 1) 125 + assert.Len(t, srv.topics[topicB].subscriptions, 1) 120 126 121 127 // close the conn 122 128 err := conn.Close() 123 129 require.NoError(t, err) 124 130 125 - publisherConn, err := net.Dial("tcp", "localhost:3000") 131 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 126 132 require.NoError(t, err) 127 133 128 134 err = binary.Write(publisherConn, binary.BigEndian, Publish) ··· 137 143 require.Equal(t, len(data), n) 138 144 139 145 assert.Len(t, srv.topics, 2) 140 - assert.Len(t, srv.topics["topic a"].subscriptions, 0) 141 - assert.Len(t, srv.topics["topic b"].subscriptions, 0) 146 + assert.Len(t, srv.topics[topicA].subscriptions, 0) 147 + assert.Len(t, srv.topics[topicB].subscriptions, 0) 142 148 } 143 149 144 150 func TestInvalidAction(t *testing.T) { 145 151 _ = createServer(t) 146 152 147 - conn, err := net.Dial("tcp", "localhost:3000") 153 + conn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 148 154 require.NoError(t, err) 149 155 150 156 err = binary.Write(conn, binary.BigEndian, uint8(99)) ··· 170 176 assert.Equal(t, expectedMessage, string(buf)) 171 177 } 172 178 173 - func TestInvalidMessagePublished(t *testing.T) { 179 + func TestInvalidTopicDataPublished(t *testing.T) { 174 180 _ = createServer(t) 175 181 176 - publisherConn, err := net.Dial("tcp", "localhost:3000") 182 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 177 183 require.NoError(t, err) 178 184 179 185 err = binary.Write(publisherConn, binary.BigEndian, Publish) 180 186 require.NoError(t, err) 181 187 182 - // send some data 183 - data := []byte("this isn't wrapped in a message type") 184 - 185 - // send data length first 186 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(data))) 188 + // send topic 189 + topic := topicA 190 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 187 191 require.NoError(t, err) 188 - n, err := publisherConn.Write(data) 192 + _, err = publisherConn.Write([]byte(topic)) 189 193 require.NoError(t, err) 190 - require.Equal(t, len(data), n) 191 194 192 195 expectedRes := Error 193 196 ··· 196 199 197 200 assert.Equal(t, expectedRes, int(resp)) 198 201 199 - expectedMessage := "invalid message" 202 + expectedMessage := "topic data does not contain 'topic:' prefix" 200 203 201 204 var dataLen uint32 202 205 err = binary.Read(publisherConn, binary.BigEndian, &dataLen) ··· 212 215 func TestSendsDataToTopicSubscribers(t *testing.T) { 213 216 _ = createServer(t) 214 217 215 - subscribers := make([]net.Conn, 0, 5) 216 - for i := 0; i < 5; i++ { 217 - subscriberConn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 218 + subscribers := make([]net.Conn, 0, 10) 219 + for i := 0; i < 10; i++ { 220 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}) 218 221 219 222 subscribers = append(subscribers, subscriberConn) 220 223 } 221 224 222 - publisherConn, err := net.Dial("tcp", "localhost:3000") 225 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 223 226 require.NoError(t, err) 224 227 225 228 err = binary.Write(publisherConn, binary.BigEndian, Publish) 226 229 require.NoError(t, err) 227 230 228 - // send a message 229 - msg := messagebroker.Message{ 230 - Topic: "topic a", 231 - Data: []byte("hello world"), 232 - } 231 + topic := fmt.Sprintf("topic:%s", topicA) 232 + messageData := "hello world" 233 233 234 - rawMsg, err := json.Marshal(msg) 234 + // send topic first 235 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 236 + require.NoError(t, err) 237 + _, err = publisherConn.Write([]byte(topic)) 235 238 require.NoError(t, err) 236 239 237 - // send data length first 238 - err = binary.Write(publisherConn, binary.BigEndian, uint32(len(rawMsg))) 240 + // now send the data 241 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(messageData))) 239 242 require.NoError(t, err) 240 - n, err := publisherConn.Write(rawMsg) 243 + n, err := publisherConn.Write([]byte(messageData)) 241 244 require.NoError(t, err) 242 - require.Equal(t, len(rawMsg), n) 245 + require.Equal(t, len(messageData), n) 243 246 244 247 // check the subsribers got the data 245 248 for _, conn := range subscribers { 249 + var topicLen uint64 250 + err = binary.Read(conn, binary.BigEndian, &topicLen) 251 + require.NoError(t, err) 252 + 253 + topicBuf := make([]byte, topicLen) 254 + _, err = conn.Read(topicBuf) 255 + require.NoError(t, err) 256 + assert.Equal(t, topicA, string(topicBuf)) 246 257 247 258 var dataLen uint64 248 259 err = binary.Read(conn, binary.BigEndian, &dataLen) ··· 253 264 require.NoError(t, err) 254 265 require.Equal(t, int(dataLen), n) 255 266 256 - assert.Equal(t, rawMsg, buf) 267 + assert.Equal(t, messageData, string(buf)) 257 268 } 258 269 } 259 270 260 271 func TestPublishMultipleTimes(t *testing.T) { 261 272 _ = createServer(t) 262 273 263 - publisherConn, err := net.Dial("tcp", "localhost:3000") 274 + publisherConn, err := net.Dial("tcp", fmt.Sprintf("localhost%s", serverAddr)) 264 275 require.NoError(t, err) 265 276 266 277 err = binary.Write(publisherConn, binary.BigEndian, Publish) ··· 268 279 269 280 messages := make([][]byte, 0, 10) 270 281 for i := 0; i < 10; i++ { 271 - msg := messagebroker.Message{ 272 - Topic: "topic a", 273 - Data: []byte(fmt.Sprintf("message %d", i)), 274 - } 275 - 276 - rawMsg, err := json.Marshal(msg) 277 - require.NoError(t, err) 278 - 279 - messages = append(messages, rawMsg) 282 + messages = append(messages, []byte(fmt.Sprintf("message %d", i))) 280 283 } 281 284 282 285 subscribeFinCh := make(chan struct{}) 283 286 // create a subscriber that will read messages 284 - subscriberConn := createConnectionAndSubscribe(t, []string{"topic a", "topic b"}) 287 + subscriberConn := createConnectionAndSubscribe(t, []string{topicA, topicB}) 285 288 go func() { 286 289 // check subscriber got all messages 287 290 for _, msg := range messages { 291 + var topicLen uint64 292 + err = binary.Read(subscriberConn, binary.BigEndian, &topicLen) 293 + require.NoError(t, err) 294 + 295 + topicBuf := make([]byte, topicLen) 296 + _, err = subscriberConn.Read(topicBuf) 297 + require.NoError(t, err) 298 + assert.Equal(t, topicA, string(topicBuf)) 299 + 288 300 var dataLen uint64 289 301 err = binary.Read(subscriberConn, binary.BigEndian, &dataLen) 290 302 require.NoError(t, err) ··· 300 312 subscribeFinCh <- struct{}{} 301 313 }() 302 314 315 + topic := fmt.Sprintf("topic:%s", topicA) 316 + 303 317 // send multiple messages 304 318 for _, msg := range messages { 305 - // send data length first 319 + // send topic first 320 + err = binary.Write(publisherConn, binary.BigEndian, uint32(len(topic))) 321 + require.NoError(t, err) 322 + _, err = publisherConn.Write([]byte(topic)) 323 + require.NoError(t, err) 324 + 325 + // now send the data 306 326 err = binary.Write(publisherConn, binary.BigEndian, uint32(len(msg))) 307 327 require.NoError(t, err) 308 - n, err := publisherConn.Write(msg) 328 + n, err := publisherConn.Write([]byte(msg)) 309 329 require.NoError(t, err) 310 330 require.Equal(t, len(msg), n) 311 331 }
-26
server/subscriber.go
··· 1 - package server 2 - 3 - import ( 4 - "encoding/binary" 5 - "fmt" 6 - ) 7 - 8 - type subscriber struct { 9 - peer peer 10 - currentOffset int 11 - } 12 - 13 - func (s *subscriber) sendMessage(msg []byte) error { 14 - dataLen := uint64(len(msg)) 15 - 16 - err := binary.Write(&s.peer, binary.BigEndian, dataLen) 17 - if err != nil { 18 - return fmt.Errorf("failed to send data length: %w", err) 19 - } 20 - 21 - _, err = s.peer.Write(msg) 22 - if err != nil { 23 - return fmt.Errorf("failed to write to peer: %w", err) 24 - } 25 - return nil 26 - }
+38 -10
server/topic.go
··· 1 1 package server 2 2 3 3 import ( 4 - "encoding/json" 4 + "encoding/binary" 5 + "fmt" 5 6 "log/slog" 6 7 "net" 7 8 "sync" 8 9 9 - "github.com/willdot/messagebroker" 10 + "github.com/willdot/messagebroker/server/peer" 10 11 ) 11 12 12 13 type topic struct { ··· 15 16 mu sync.Mutex 16 17 } 17 18 19 + type subscriber struct { 20 + peer *peer.Peer 21 + currentOffset int 22 + } 23 + 18 24 func newTopic(name string) topic { 19 25 return topic{ 20 26 name: name, ··· 30 36 delete(t.subscriptions, addr) 31 37 } 32 38 33 - func (t *topic) sendMessageToSubscribers(msg messagebroker.Message) { 39 + func (t *topic) sendMessageToSubscribers(msgData []byte) { 34 40 t.mu.Lock() 35 41 subscribers := t.subscriptions 36 42 t.mu.Unlock() 37 43 38 - msgData, err := json.Marshal(msg) 39 - if err != nil { 40 - slog.Error("failed to marshal message for subscribers", "error", err) 41 - } 42 - 43 44 for addr, subscriber := range subscribers { 44 - err := subscriber.sendMessage(msgData) 45 + err := subscriber.peer.RunConnOperation(sendMessageOp(t.name, msgData)) 45 46 if err != nil { 46 47 slog.Error("failed to send to message", "error", err, "peer", addr) 47 - continue 48 + return 48 49 } 49 50 } 50 51 } 52 + 53 + func sendMessageOp(topic string, data []byte) peer.ConnOpp { 54 + return func(conn net.Conn) error { 55 + topicLen := uint64(len(topic)) 56 + err := binary.Write(conn, binary.BigEndian, topicLen) 57 + if err != nil { 58 + return fmt.Errorf("failed to send topic length: %w", err) 59 + } 60 + _, err = conn.Write([]byte(topic)) 61 + if err != nil { 62 + return fmt.Errorf("failed to send topic: %w", err) 63 + } 64 + 65 + dataLen := uint64(len(data)) 66 + 67 + err = binary.Write(conn, binary.BigEndian, dataLen) 68 + if err != nil { 69 + return fmt.Errorf("failed to send data length: %w", err) 70 + } 71 + 72 + _, err = conn.Write(data) 73 + if err != nil { 74 + return fmt.Errorf("failed to write to peer: %w", err) 75 + } 76 + return nil 77 + } 78 + }