this repo has no description
1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::sync::RwLock;
5
/// The position a [`CircuitBreaker`] is currently in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are admitted; failures are being counted.
    Closed,
    /// Calls are rejected until the open timeout elapses.
    Open,
    /// Trial calls are admitted; enough successes close the breaker, a failure re-opens it.
    HalfOpen,
}
12
13pub struct CircuitBreaker {
14 name: String,
15 failure_threshold: u32,
16 success_threshold: u32,
17 timeout: Duration,
18 state: Arc<RwLock<CircuitState>>,
19 failure_count: AtomicU32,
20 success_count: AtomicU32,
21 last_failure_time: AtomicU64,
22}
23
24impl CircuitBreaker {
25 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
26 Self {
27 name: name.to_string(),
28 failure_threshold,
29 success_threshold,
30 timeout: Duration::from_secs(timeout_secs),
31 state: Arc::new(RwLock::new(CircuitState::Closed)),
32 failure_count: AtomicU32::new(0),
33 success_count: AtomicU32::new(0),
34 last_failure_time: AtomicU64::new(0),
35 }
36 }
37
38 pub async fn can_execute(&self) -> bool {
39 let state = self.state.read().await;
40
41 match *state {
42 CircuitState::Closed => true,
43 CircuitState::Open => {
44 let last_failure = self.last_failure_time.load(Ordering::SeqCst);
45 let now = std::time::SystemTime::now()
46 .duration_since(std::time::UNIX_EPOCH)
47 .unwrap()
48 .as_secs();
49
50 if now - last_failure >= self.timeout.as_secs() {
51 drop(state);
52 let mut state = self.state.write().await;
53 if *state == CircuitState::Open {
54 *state = CircuitState::HalfOpen;
55 self.success_count.store(0, Ordering::SeqCst);
56 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open");
57 return true;
58 }
59 }
60 false
61 }
62 CircuitState::HalfOpen => true,
63 }
64 }
65
66 pub async fn record_success(&self) {
67 let state = *self.state.read().await;
68
69 match state {
70 CircuitState::Closed => {
71 self.failure_count.store(0, Ordering::SeqCst);
72 }
73 CircuitState::HalfOpen => {
74 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
75 if count >= self.success_threshold {
76 let mut state = self.state.write().await;
77 *state = CircuitState::Closed;
78 self.failure_count.store(0, Ordering::SeqCst);
79 self.success_count.store(0, Ordering::SeqCst);
80 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
81 }
82 }
83 CircuitState::Open => {}
84 }
85 }
86
87 pub async fn record_failure(&self) {
88 let state = *self.state.read().await;
89
90 match state {
91 CircuitState::Closed => {
92 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
93 if count >= self.failure_threshold {
94 let mut state = self.state.write().await;
95 *state = CircuitState::Open;
96 let now = std::time::SystemTime::now()
97 .duration_since(std::time::UNIX_EPOCH)
98 .unwrap()
99 .as_secs();
100 self.last_failure_time.store(now, Ordering::SeqCst);
101 tracing::warn!(
102 circuit = %self.name,
103 failures = count,
104 "Circuit breaker opened after {} failures",
105 count
106 );
107 }
108 }
109 CircuitState::HalfOpen => {
110 let mut state = self.state.write().await;
111 *state = CircuitState::Open;
112 let now = std::time::SystemTime::now()
113 .duration_since(std::time::UNIX_EPOCH)
114 .unwrap()
115 .as_secs();
116 self.last_failure_time.store(now, Ordering::SeqCst);
117 self.success_count.store(0, Ordering::SeqCst);
118 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
119 }
120 CircuitState::Open => {}
121 }
122 }
123
124 pub async fn state(&self) -> CircuitState {
125 *self.state.read().await
126 }
127
128 pub fn name(&self) -> &str {
129 &self.name
130 }
131}
132
133#[derive(Clone)]
134pub struct CircuitBreakers {
135 pub plc_directory: Arc<CircuitBreaker>,
136 pub relay_notification: Arc<CircuitBreaker>,
137}
138
139impl Default for CircuitBreakers {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl CircuitBreakers {
146 pub fn new() -> Self {
147 Self {
148 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
149 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
150 }
151 }
152}
153
/// Error describing a call that was rejected because a circuit breaker was open.
#[derive(Debug)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the call.
    pub circuit_name: String,
}

// No underlying cause: the rejection itself is the whole error.
impl std::error::Error for CircuitOpenError {}

impl std::fmt::Display for CircuitOpenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { circuit_name } = self;
        write!(f, "Circuit breaker '{}' is open", circuit_name)
    }
}
166
167pub async fn with_circuit_breaker<T, E, F, Fut>(
168 circuit: &CircuitBreaker,
169 operation: F,
170) -> Result<T, CircuitBreakerError<E>>
171where
172 F: FnOnce() -> Fut,
173 Fut: std::future::Future<Output = Result<T, E>>,
174{
175 if !circuit.can_execute().await {
176 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError {
177 circuit_name: circuit.name().to_string(),
178 }));
179 }
180
181 match operation().await {
182 Ok(result) => {
183 circuit.record_success().await;
184 Ok(result)
185 }
186 Err(e) => {
187 circuit.record_failure().await;
188 Err(CircuitBreakerError::OperationFailed(e))
189 }
190 }
191}
192
193#[derive(Debug)]
194pub enum CircuitBreakerError<E> {
195 CircuitOpen(CircuitOpenError),
196 OperationFailed(E),
197}
198
199impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 match self {
202 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
203 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
204 }
205 }
206}
207
208impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
209 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
210 match self {
211 CircuitBreakerError::CircuitOpen(e) => Some(e),
212 CircuitBreakerError::OperationFailed(e) => Some(e),
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.can_execute().await);
        assert_eq!(cb.name(), "test");
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);

        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);

        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);

        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
        assert!(!cb.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_stays_open_until_timeout() {
        let cb = CircuitBreaker::new("test", 1, 1, 3600);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);

        // With a long timeout, repeated checks keep rejecting without transitioning.
        assert!(!cb.can_execute().await);
        assert!(!cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);

        cb.record_failure().await;
        cb.record_failure().await;
        cb.record_success().await;

        cb.record_failure().await;
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);

        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_ignores_success_while_open() {
        let cb = CircuitBreaker::new("test", 1, 1, 3600);
        cb.record_failure().await;
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_closes_after_successes() {
        // A zero timeout lets the breaker leave `Open` immediately, so no sleep is needed.
        let cb = CircuitBreaker::new("test", 3, 2, 0);

        for _ in 0..3 {
            cb.record_failure().await;
        }
        assert_eq!(cb.state().await, CircuitState::Open);

        assert!(cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        // Further callers are admitted while half-open.
        assert!(cb.can_execute().await);

        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_reopens_on_failure() {
        let cb = CircuitBreaker::new("test", 3, 2, 0);

        for _ in 0..3 {
            cb.record_failure().await;
        }

        assert!(cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_helper() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);

        let result: Result<i32, CircuitBreakerError<std::io::Error>> =
            with_circuit_breaker(&cb, || async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);

        let result: Result<i32, CircuitBreakerError<&str>> =
            with_circuit_breaker(&cb, || async { Err("error") }).await;
        assert!(matches!(
            result,
            Err(CircuitBreakerError::OperationFailed("error"))
        ));
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_rejects_when_open() {
        let cb = CircuitBreaker::new("test", 1, 1, 3600);
        cb.record_failure().await;

        let mut called = false;
        let result: Result<i32, CircuitBreakerError<&str>> = with_circuit_breaker(&cb, || {
            called = true;
            async { Ok(1) }
        })
        .await;

        assert!(!called, "operation must not run while the circuit is open");
        match result {
            Err(CircuitBreakerError::CircuitOpen(err)) => {
                assert_eq!(err.circuit_name, "test");
                assert_eq!(err.to_string(), "Circuit breaker 'test' is open");
            }
            _ => panic!("expected a CircuitOpen error"),
        }
    }

    #[test]
    fn test_circuit_breaker_error_display() {
        let failed: CircuitBreakerError<&str> = CircuitBreakerError::OperationFailed("boom");
        assert_eq!(failed.to_string(), "Operation failed: boom");

        let open: CircuitBreakerError<&str> = CircuitBreakerError::CircuitOpen(CircuitOpenError {
            circuit_name: "x".to_string(),
        });
        assert_eq!(open.to_string(), "Circuit breaker 'x' is open");
    }

    #[tokio::test]
    async fn test_default_circuit_breakers() {
        let breakers = CircuitBreakers::default();
        assert_eq!(breakers.plc_directory.name(), "plc_directory");
        assert_eq!(breakers.relay_notification.name(), "relay_notification");
        assert_eq!(breakers.plc_directory.state().await, CircuitState::Closed);

        // Clones share the same underlying breakers.
        let cloned = breakers.clone();
        assert!(Arc::ptr_eq(&breakers.plc_directory, &cloned.plc_directory));
    }
}