this repo has no description
1package knotclient 2 3import ( 4 "context" 5 "encoding/json" 6 "log/slog" 7 "math/rand" 8 "net/url" 9 "sync" 10 "time" 11 12 "tangled.sh/tangled.sh/core/log" 13 14 "github.com/gorilla/websocket" 15) 16 17type ProcessFunc func(source EventSource, message Message) error 18 19type Message struct { 20 Rkey string 21 Nsid string 22 // do not full deserialize this portion of the message, processFunc can do that 23 EventJson json.RawMessage 24} 25 26type ConsumerConfig struct { 27 Sources map[EventSource]struct{} 28 ProcessFunc ProcessFunc 29 RetryInterval time.Duration 30 MaxRetryInterval time.Duration 31 ConnectionTimeout time.Duration 32 WorkerCount int 33 QueueSize int 34 Logger *slog.Logger 35 Dev bool 36} 37 38type EventSource struct { 39 Knot string 40} 41 42func NewEventSource(knot string) EventSource { 43 return EventSource{ 44 Knot: knot, 45 } 46} 47 48type EventConsumer struct { 49 cfg ConsumerConfig 50 wg sync.WaitGroup 51 dialer *websocket.Dialer 52 connMap sync.Map 53 jobQueue chan job 54 logger *slog.Logger 55 randSource *rand.Rand 56 57 // rw lock over edits to consumer config 58 mu sync.RWMutex 59} 60 61func (e *EventConsumer) buildUrl(s EventSource, cursor string) (*url.URL, error) { 62 scheme := "wss" 63 if e.cfg.Dev { 64 scheme = "ws" 65 } 66 67 u, err := url.Parse(scheme + "://" + s.Knot + "/events") 68 if err != nil { 69 return nil, err 70 } 71 72 if cursor != "" { 73 query := url.Values{} 74 query.Add("cursor", cursor) 75 u.RawQuery = query.Encode() 76 } 77 return u, nil 78} 79 80type job struct { 81 source EventSource 82 message []byte 83} 84 85func NewEventConsumer(cfg ConsumerConfig) *EventConsumer { 86 if cfg.RetryInterval == 0 { 87 cfg.RetryInterval = 15 * time.Minute 88 } 89 if cfg.ConnectionTimeout == 0 { 90 cfg.ConnectionTimeout = 10 * time.Second 91 } 92 if cfg.WorkerCount <= 0 { 93 cfg.WorkerCount = 5 94 } 95 if cfg.MaxRetryInterval == 0 { 96 cfg.MaxRetryInterval = 1 * time.Hour 97 } 98 if cfg.Logger == nil { 99 cfg.Logger = log.New("eventconsumer") 100 } 101 if cfg.QueueSize == 0 { 102 cfg.QueueSize = 100 103 } 104 return &EventConsumer{ 105 cfg: cfg, 106 dialer: websocket.DefaultDialer, 107 jobQueue: make(chan job, cfg.QueueSize), // buffered job queue 108 logger: cfg.Logger, 109 randSource: rand.New(rand.NewSource(time.Now().UnixNano())), 110 } 111} 112 113func (c *EventConsumer) Start(ctx context.Context) { 114 // start workers 115 for range c.cfg.WorkerCount { 116 c.wg.Add(1) 117 go c.worker(ctx) 118 } 119 120 // start streaming 121 for source := range c.cfg.Sources { 122 c.wg.Add(1) 123 go c.startConnectionLoop(ctx, source) 124 } 125} 126 127func (c *EventConsumer) Stop() { 128 c.connMap.Range(func(_, val any) bool { 129 if conn, ok := val.(*websocket.Conn); ok { 130 conn.Close() 131 } 132 return true 133 }) 134 c.wg.Wait() 135 close(c.jobQueue) 136} 137 138func (c *EventConsumer) AddSource(ctx context.Context, s EventSource) { 139 c.mu.Lock() 140 c.cfg.Sources[s] = struct{}{} 141 c.wg.Add(1) 142 go c.startConnectionLoop(ctx, s) 143 c.mu.Unlock() 144} 145 146func (c *EventConsumer) worker(ctx context.Context) { 147 defer c.wg.Done() 148 for { 149 select { 150 case <-ctx.Done(): 151 return 152 case j, ok := <-c.jobQueue: 153 if !ok { 154 return 155 } 156 157 var msg Message 158 err := json.Unmarshal(j.message, &msg) 159 if err != nil { 160 c.logger.Error("error deserializing message", "source", j.source.Knot, "err", err) 161 return 162 } 163 if err := c.cfg.ProcessFunc(j.source, msg); err != nil { 164 c.logger.Error("error processing message", "source", j.source, "err", err) 165 } 166 } 167 } 168} 169 170func (c *EventConsumer) startConnectionLoop(ctx context.Context, source EventSource) { 171 defer c.wg.Done() 172 retryInterval := c.cfg.RetryInterval 173 for { 174 select { 175 case <-ctx.Done(): 176 return 177 default: 178 err := c.runConnection(ctx, source) 179 if err != nil { 180 c.logger.Error("connection failed", "source", source, "err", err) 181 } 182 183 // apply jitter 184 jitter := time.Duration(c.randSource.Int63n(int64(retryInterval) / 5)) 185 delay := retryInterval + jitter 186 187 if retryInterval < c.cfg.MaxRetryInterval { 188 retryInterval *= 2 189 if retryInterval > c.cfg.MaxRetryInterval { 190 retryInterval = c.cfg.MaxRetryInterval 191 } 192 } 193 c.logger.Info("retrying connection", "source", source, "delay", delay) 194 select { 195 case <-time.After(delay): 196 case <-ctx.Done(): 197 return 198 } 199 } 200 } 201} 202 203func (c *EventConsumer) runConnection(ctx context.Context, source EventSource) error { 204 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout) 205 defer cancel() 206 207 u, err := url.Parse(source) 208 209 u, err := c.buildUrl(source, cursor) 210 if err != nil { 211 return err 212 } 213 214 c.logger.Info("connecting", "url", u.String()) 215 conn, _, err := c.dialer.DialContext(connCtx, u.String(), nil) 216 if err != nil { 217 return err 218 } 219 defer conn.Close() 220 c.connMap.Store(source, conn) 221 defer c.connMap.Delete(source) 222 223 c.logger.Info("connected", "source", source) 224 225 for { 226 select { 227 case <-ctx.Done(): 228 return nil 229 default: 230 msgType, msg, err := conn.ReadMessage() 231 if err != nil { 232 return err 233 } 234 if msgType != websocket.TextMessage { 235 continue 236 } 237 select { 238 case c.jobQueue <- job{source: source, message: msg}: 239 case <-ctx.Done(): 240 return nil 241 } 242 } 243 } 244}