this repo has no description
1package main
2
3import (
4 "encoding/json"
5 "flag"
6 "fmt"
7 "io"
8 "log"
9 "os"
10 "runtime"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 "github.com/hashicorp/serf/serf"
16)
17
// ThroughputResult is the outcome of one throughput benchmark run. It is
// printed either as a text summary or encoded as JSON using the tags below.
type ThroughputResult struct {
	// Implementation names the gossip library under test (e.g. "serf").
	Implementation string `json:"implementation"`
	// NumNodes is the number of cluster members started for the run.
	NumNodes int `json:"num_nodes"`
	// DurationNs is the configured load-generation duration, in nanoseconds.
	DurationNs int64 `json:"duration_ns"`
	// MsgRate is the target number of broadcasts per second, per node.
	MsgRate int `json:"msg_rate"`
	// BroadcastsSent counts broadcasts that were accepted without error.
	BroadcastsSent int64 `json:"broadcasts_sent"`
	// BroadcastsReceived sums the broadcasts observed by all nodes' handlers.
	BroadcastsReceived int64 `json:"broadcasts_received"`
	// MsgsPerSec is BroadcastsReceived divided by the measured wall time.
	MsgsPerSec float64 `json:"msgs_per_sec"`
	// CPUCores records runtime.NumCPU() on the benchmarking host.
	CPUCores int `json:"cpu_cores"`
}
28
29type serfThroughputHandler struct {
30 received atomic.Int64
31 memberCh chan serf.MemberEvent
32}
33
34func (h *serfThroughputHandler) HandleEvent(e serf.Event) {
35 switch evt := e.(type) {
36 case serf.MemberEvent:
37 select {
38 case h.memberCh <- evt:
39 default:
40 }
41 case serf.UserEvent:
42 if evt.Name == "bench" {
43 h.received.Add(1)
44 }
45 }
46}
47
48func createSerfNode(name string, bindPort, rpcPort int, handler *serfThroughputHandler) (*serf.Serf, error) {
49 cfg := serf.DefaultConfig()
50 cfg.NodeName = name
51 cfg.MemberlistConfig.BindAddr = "127.0.0.1"
52 cfg.MemberlistConfig.BindPort = bindPort
53 cfg.MemberlistConfig.AdvertisePort = bindPort
54 cfg.MemberlistConfig.GossipInterval = 50 * time.Millisecond
55 cfg.MemberlistConfig.ProbeInterval = 200 * time.Millisecond
56 cfg.MemberlistConfig.PushPullInterval = 30 * time.Second
57 cfg.MemberlistConfig.GossipNodes = 3
58 cfg.LogOutput = io.Discard
59
60 eventCh := make(chan serf.Event, 256)
61 cfg.EventCh = eventCh
62
63 s, err := serf.Create(cfg)
64 if err != nil {
65 return nil, err
66 }
67
68 go func() {
69 for e := range eventCh {
70 handler.HandleEvent(e)
71 }
72 }()
73
74 return s, nil
75}
76
77func waitForMembers(s *serf.Serf, expected int, timeout time.Duration) bool {
78 deadline := time.Now().Add(timeout)
79 for time.Now().Before(deadline) {
80 if len(s.Members()) >= expected {
81 return true
82 }
83 time.Sleep(50 * time.Millisecond)
84 }
85 return false
86}
87
88func runThroughputBenchmark(numNodes int, duration time.Duration, msgRate int) (*ThroughputResult, error) {
89 nodes := make([]*serf.Serf, numNodes)
90 handlers := make([]*serfThroughputHandler, numNodes)
91
92 baseBindPort := 28946
93
94 for i := 0; i < numNodes; i++ {
95 handlers[i] = &serfThroughputHandler{
96 memberCh: make(chan serf.MemberEvent, 100),
97 }
98
99 var err error
100 nodes[i], err = createSerfNode(
101 fmt.Sprintf("node-%d", i),
102 baseBindPort+i,
103 0,
104 handlers[i],
105 )
106 if err != nil {
107 for j := 0; j < i; j++ {
108 nodes[j].Shutdown()
109 }
110 return nil, fmt.Errorf("failed to create node %d: %w", i, err)
111 }
112 }
113
114 for i := 1; i < numNodes; i++ {
115 addr := fmt.Sprintf("127.0.0.1:%d", baseBindPort)
116 _, err := nodes[i].Join([]string{addr}, false)
117 if err != nil {
118 log.Printf("Warning: node %d failed to join: %v", i, err)
119 }
120 }
121
122 for i := 0; i < numNodes; i++ {
123 if !waitForMembers(nodes[i], numNodes, 10*time.Second) {
124 log.Printf("Warning: Node %d did not see all %d nodes", i, numNodes)
125 }
126 }
127
128 time.Sleep(500 * time.Millisecond)
129
130 var totalSent atomic.Int64
131 stopCh := make(chan struct{})
132 var wg sync.WaitGroup
133
134 msgInterval := time.Duration(float64(time.Second) / float64(msgRate))
135 payload := make([]byte, 64)
136 for i := 0; i < 64; i++ {
137 payload[i] = 'x'
138 }
139
140 startTime := time.Now()
141
142 for i, n := range nodes {
143 wg.Add(1)
144 go func(node *serf.Serf, idx int) {
145 defer wg.Done()
146 ticker := time.NewTicker(msgInterval)
147 defer ticker.Stop()
148 for {
149 select {
150 case <-ticker.C:
151 err := node.UserEvent("bench", payload, false)
152 if err == nil {
153 totalSent.Add(1)
154 }
155 case <-stopCh:
156 return
157 }
158 }
159 }(n, i)
160 }
161
162 time.Sleep(duration)
163 close(stopCh)
164 wg.Wait()
165
166 elapsed := time.Since(startTime)
167
168 var totalReceived int64
169 for _, h := range handlers {
170 totalReceived += h.received.Load()
171 }
172
173 for _, n := range nodes {
174 n.Shutdown()
175 }
176
177 msgsPerSec := float64(totalReceived) / elapsed.Seconds()
178
179 return &ThroughputResult{
180 Implementation: "serf",
181 NumNodes: numNodes,
182 DurationNs: duration.Nanoseconds(),
183 MsgRate: msgRate,
184 BroadcastsSent: totalSent.Load(),
185 BroadcastsReceived: totalReceived,
186 MsgsPerSec: msgsPerSec,
187 CPUCores: runtime.NumCPU(),
188 }, nil
189}
190
191func main() {
192 numNodes := flag.Int("nodes", 5, "number of nodes")
193 durationSec := flag.Int("duration", 10, "benchmark duration in seconds")
194 msgRate := flag.Int("rate", 100, "messages per second per node")
195 outputJSON := flag.Bool("json", false, "output as JSON")
196 flag.Parse()
197
198 result, err := runThroughputBenchmark(*numNodes, time.Duration(*durationSec)*time.Second, *msgRate)
199 if err != nil {
200 log.Fatalf("Benchmark failed: %v", err)
201 }
202
203 if *outputJSON {
204 enc := json.NewEncoder(os.Stdout)
205 enc.SetIndent("", " ")
206 enc.Encode(result)
207 } else {
208 fmt.Printf("=== Serf Throughput Results ===\n")
209 fmt.Printf("Nodes: %d\n", result.NumNodes)
210 fmt.Printf("Duration: %s\n", time.Duration(result.DurationNs))
211 fmt.Printf("Target Rate: %d msg/s per node\n", result.MsgRate)
212 fmt.Printf("Broadcasts Sent: %d\n", result.BroadcastsSent)
213 fmt.Printf("Broadcasts Recv: %d\n", result.BroadcastsReceived)
214 fmt.Printf("Throughput: %.1f msg/s\n", result.MsgsPerSec)
215 fmt.Printf("CPU Cores: %d\n", result.CPUCores)
216 }
217}