this repo has no description
1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4use tokio::sync::RwLock;
5
/// Position of a circuit breaker.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are admitted; failures are counted toward the opening threshold.
    Closed,
    /// Calls are rejected until the open timeout has elapsed.
    Open,
    /// Trial calls are admitted; enough successes close the circuit, a failure reopens it.
    HalfOpen,
}
12
13pub struct CircuitBreaker {
14 name: String,
15 failure_threshold: u32,
16 success_threshold: u32,
17 timeout: Duration,
18 state: Arc<RwLock<CircuitState>>,
19 failure_count: AtomicU32,
20 success_count: AtomicU32,
21 last_failure_time: AtomicU64,
22}
23
24impl CircuitBreaker {
25 pub fn new(
26 name: &str,
27 failure_threshold: u32,
28 success_threshold: u32,
29 timeout_secs: u64,
30 ) -> Self {
31 Self {
32 name: name.to_string(),
33 failure_threshold,
34 success_threshold,
35 timeout: Duration::from_secs(timeout_secs),
36 state: Arc::new(RwLock::new(CircuitState::Closed)),
37 failure_count: AtomicU32::new(0),
38 success_count: AtomicU32::new(0),
39 last_failure_time: AtomicU64::new(0),
40 }
41 }
42
43 pub async fn can_execute(&self) -> bool {
44 let state = self.state.read().await;
45
46 match *state {
47 CircuitState::Closed => true,
48 CircuitState::Open => {
49 let last_failure = self.last_failure_time.load(Ordering::SeqCst);
50 let now = std::time::SystemTime::now()
51 .duration_since(std::time::UNIX_EPOCH)
52 .unwrap()
53 .as_secs();
54
55 if now - last_failure >= self.timeout.as_secs() {
56 drop(state);
57 let mut state = self.state.write().await;
58 if *state == CircuitState::Open {
59 *state = CircuitState::HalfOpen;
60 self.success_count.store(0, Ordering::SeqCst);
61 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open");
62 return true;
63 }
64 }
65 false
66 }
67 CircuitState::HalfOpen => true,
68 }
69 }
70
71 pub async fn record_success(&self) {
72 let state = *self.state.read().await;
73
74 match state {
75 CircuitState::Closed => {
76 self.failure_count.store(0, Ordering::SeqCst);
77 }
78 CircuitState::HalfOpen => {
79 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
80 if count >= self.success_threshold {
81 let mut state = self.state.write().await;
82 *state = CircuitState::Closed;
83 self.failure_count.store(0, Ordering::SeqCst);
84 self.success_count.store(0, Ordering::SeqCst);
85 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
86 }
87 }
88 CircuitState::Open => {}
89 }
90 }
91
92 pub async fn record_failure(&self) {
93 let state = *self.state.read().await;
94
95 match state {
96 CircuitState::Closed => {
97 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
98 if count >= self.failure_threshold {
99 let mut state = self.state.write().await;
100 *state = CircuitState::Open;
101 let now = std::time::SystemTime::now()
102 .duration_since(std::time::UNIX_EPOCH)
103 .unwrap()
104 .as_secs();
105 self.last_failure_time.store(now, Ordering::SeqCst);
106 tracing::warn!(
107 circuit = %self.name,
108 failures = count,
109 "Circuit breaker opened after {} failures",
110 count
111 );
112 }
113 }
114 CircuitState::HalfOpen => {
115 let mut state = self.state.write().await;
116 *state = CircuitState::Open;
117 let now = std::time::SystemTime::now()
118 .duration_since(std::time::UNIX_EPOCH)
119 .unwrap()
120 .as_secs();
121 self.last_failure_time.store(now, Ordering::SeqCst);
122 self.success_count.store(0, Ordering::SeqCst);
123 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
124 }
125 CircuitState::Open => {}
126 }
127 }
128
129 pub async fn state(&self) -> CircuitState {
130 *self.state.read().await
131 }
132
133 pub fn name(&self) -> &str {
134 &self.name
135 }
136}
137
138#[derive(Clone)]
139pub struct CircuitBreakers {
140 pub plc_directory: Arc<CircuitBreaker>,
141 pub relay_notification: Arc<CircuitBreaker>,
142}
143
144impl Default for CircuitBreakers {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150impl CircuitBreakers {
151 pub fn new() -> Self {
152 Self {
153 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
154 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
155 }
156 }
157}
158
/// Error produced when a call is rejected because its circuit breaker is open.
#[derive(Debug)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the call.
    pub circuit_name: String,
}

impl std::fmt::Display for CircuitOpenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.circuit_name.as_str();
        write!(f, "Circuit breaker '{}' is open", name)
    }
}

/// Leaf error: it wraps no underlying cause, so `source()` is `None`.
impl std::error::Error for CircuitOpenError {}
171
172pub async fn with_circuit_breaker<T, E, F, Fut>(
173 circuit: &CircuitBreaker,
174 operation: F,
175) -> Result<T, CircuitBreakerError<E>>
176where
177 F: FnOnce() -> Fut,
178 Fut: std::future::Future<Output = Result<T, E>>,
179{
180 if !circuit.can_execute().await {
181 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError {
182 circuit_name: circuit.name().to_string(),
183 }));
184 }
185
186 match operation().await {
187 Ok(result) => {
188 circuit.record_success().await;
189 Ok(result)
190 }
191 Err(e) => {
192 circuit.record_failure().await;
193 Err(CircuitBreakerError::OperationFailed(e))
194 }
195 }
196}
197
198#[derive(Debug)]
199pub enum CircuitBreakerError<E> {
200 CircuitOpen(CircuitOpenError),
201 OperationFailed(E),
202}
203
204impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 match self {
207 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
208 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
209 }
210 }
211}
212
213impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
214 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
215 match self {
216 CircuitBreakerError::CircuitOpen(e) => Some(e),
217 CircuitBreakerError::OperationFailed(e) => Some(e),
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `times` consecutive failures on `cb`.
    async fn fail(cb: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            cb.record_failure().await;
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);

        // The first two failures stay below the threshold of 3.
        for _ in 0..2 {
            cb.record_failure().await;
            assert_eq!(cb.state().await, CircuitState::Closed);
        }

        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
        // The 10 s timeout has not elapsed, so calls are rejected.
        assert!(!cb.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);

        fail(&cb, 2).await;
        cb.record_success().await;

        // The success wiped the earlier failures, so two more are not enough...
        fail(&cb, 2).await;
        assert_eq!(cb.state().await, CircuitState::Closed);

        // ...but a third consecutive one is.
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_closes_after_successes() {
        // With a zero-second timeout the open period has elapsed immediately, so no sleep
        // is needed to reach half-open.
        let cb = CircuitBreaker::new("test", 3, 2, 0);

        fail(&cb, 3).await;
        assert_eq!(cb.state().await, CircuitState::Open);

        assert!(cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_reopens_on_failure() {
        let cb = CircuitBreaker::new("test", 3, 2, 0);

        fail(&cb, 3).await;

        // Pin the precondition. Failures are ignored while open, so without these asserts
        // the final check would pass even if the breaker never reached half-open.
        assert!(cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_helper() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);

        let result: Result<i32, CircuitBreakerError<std::io::Error>> =
            with_circuit_breaker(&cb, || async { Ok(42) }).await;
        assert!(matches!(result, Ok(42)));

        let result: Result<i32, CircuitBreakerError<&str>> =
            with_circuit_breaker(&cb, || async { Err("error") }).await;
        assert!(matches!(
            result,
            Err(CircuitBreakerError::OperationFailed("error"))
        ));
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_rejects_when_open() {
        let cb = CircuitBreaker::new("test", 1, 1, 60);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);

        let called = std::cell::Cell::new(false);
        let result: Result<i32, CircuitBreakerError<&str>> = with_circuit_breaker(&cb, || {
            called.set(true);
            async { Ok(1) }
        })
        .await;

        // The operation must not even be constructed while the circuit is open.
        assert!(!called.get());
        match result {
            Err(CircuitBreakerError::CircuitOpen(err)) => assert_eq!(err.circuit_name, "test"),
            other => panic!("expected CircuitOpen, got {:?}", other),
        }
    }
}
312}