package main import ( "encoding/json" "flag" "fmt" "io" "log" "os" "runtime" "sync" "sync/atomic" "time" "github.com/hashicorp/serf/serf" ) type ThroughputResult struct { Implementation string `json:"implementation"` NumNodes int `json:"num_nodes"` DurationNs int64 `json:"duration_ns"` MsgRate int `json:"msg_rate"` BroadcastsSent int64 `json:"broadcasts_sent"` BroadcastsReceived int64 `json:"broadcasts_received"` MsgsPerSec float64 `json:"msgs_per_sec"` CPUCores int `json:"cpu_cores"` } type serfThroughputHandler struct { received atomic.Int64 memberCh chan serf.MemberEvent } func (h *serfThroughputHandler) HandleEvent(e serf.Event) { switch evt := e.(type) { case serf.MemberEvent: select { case h.memberCh <- evt: default: } case serf.UserEvent: if evt.Name == "bench" { h.received.Add(1) } } } func createSerfNode(name string, bindPort, rpcPort int, handler *serfThroughputHandler) (*serf.Serf, error) { cfg := serf.DefaultConfig() cfg.NodeName = name cfg.MemberlistConfig.BindAddr = "127.0.0.1" cfg.MemberlistConfig.BindPort = bindPort cfg.MemberlistConfig.AdvertisePort = bindPort cfg.MemberlistConfig.GossipInterval = 50 * time.Millisecond cfg.MemberlistConfig.ProbeInterval = 200 * time.Millisecond cfg.MemberlistConfig.PushPullInterval = 30 * time.Second cfg.MemberlistConfig.GossipNodes = 3 cfg.LogOutput = io.Discard eventCh := make(chan serf.Event, 256) cfg.EventCh = eventCh s, err := serf.Create(cfg) if err != nil { return nil, err } go func() { for e := range eventCh { handler.HandleEvent(e) } }() return s, nil } func waitForMembers(s *serf.Serf, expected int, timeout time.Duration) bool { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { if len(s.Members()) >= expected { return true } time.Sleep(50 * time.Millisecond) } return false } func runThroughputBenchmark(numNodes int, duration time.Duration, msgRate int) (*ThroughputResult, error) { nodes := make([]*serf.Serf, numNodes) handlers := make([]*serfThroughputHandler, numNodes) baseBindPort := 28946 for i := 0; i < numNodes; i++ { handlers[i] = &serfThroughputHandler{ memberCh: make(chan serf.MemberEvent, 100), } var err error nodes[i], err = createSerfNode( fmt.Sprintf("node-%d", i), baseBindPort+i, 0, handlers[i], ) if err != nil { for j := 0; j < i; j++ { nodes[j].Shutdown() } return nil, fmt.Errorf("failed to create node %d: %w", i, err) } } for i := 1; i < numNodes; i++ { addr := fmt.Sprintf("127.0.0.1:%d", baseBindPort) _, err := nodes[i].Join([]string{addr}, false) if err != nil { log.Printf("Warning: node %d failed to join: %v", i, err) } } for i := 0; i < numNodes; i++ { if !waitForMembers(nodes[i], numNodes, 10*time.Second) { log.Printf("Warning: Node %d did not see all %d nodes", i, numNodes) } } time.Sleep(500 * time.Millisecond) var totalSent atomic.Int64 stopCh := make(chan struct{}) var wg sync.WaitGroup msgInterval := time.Duration(float64(time.Second) / float64(msgRate)) payload := make([]byte, 64) for i := 0; i < 64; i++ { payload[i] = 'x' } startTime := time.Now() for i, n := range nodes { wg.Add(1) go func(node *serf.Serf, idx int) { defer wg.Done() ticker := time.NewTicker(msgInterval) defer ticker.Stop() for { select { case <-ticker.C: err := node.UserEvent("bench", payload, false) if err == nil { totalSent.Add(1) } case <-stopCh: return } } }(n, i) } time.Sleep(duration) close(stopCh) wg.Wait() elapsed := time.Since(startTime) var totalReceived int64 for _, h := range handlers { totalReceived += h.received.Load() } for _, n := range nodes { n.Shutdown() } msgsPerSec := float64(totalReceived) / elapsed.Seconds() return &ThroughputResult{ Implementation: "serf", NumNodes: numNodes, DurationNs: duration.Nanoseconds(), MsgRate: msgRate, BroadcastsSent: totalSent.Load(), BroadcastsReceived: totalReceived, MsgsPerSec: msgsPerSec, CPUCores: runtime.NumCPU(), }, nil } func main() { numNodes := flag.Int("nodes", 5, "number of nodes") durationSec := flag.Int("duration", 10, "benchmark duration in seconds") msgRate := flag.Int("rate", 100, "messages per second per node") outputJSON := flag.Bool("json", false, "output as JSON") flag.Parse() result, err := runThroughputBenchmark(*numNodes, time.Duration(*durationSec)*time.Second, *msgRate) if err != nil { log.Fatalf("Benchmark failed: %v", err) } if *outputJSON { enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") enc.Encode(result) } else { fmt.Printf("=== Serf Throughput Results ===\n") fmt.Printf("Nodes: %d\n", result.NumNodes) fmt.Printf("Duration: %s\n", time.Duration(result.DurationNs)) fmt.Printf("Target Rate: %d msg/s per node\n", result.MsgRate) fmt.Printf("Broadcasts Sent: %d\n", result.BroadcastsSent) fmt.Printf("Broadcasts Recv: %d\n", result.BroadcastsReceived) fmt.Printf("Throughput: %.1f msg/s\n", result.MsgsPerSec) fmt.Printf("CPU Cores: %d\n", result.CPUCores) } }