This repository has no description.
1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 2use std::sync::Arc; 3use std::time::Duration; 4use tokio::sync::RwLock; 5 6#[derive(Debug, Clone, Copy, PartialEq, Eq)] 7pub enum CircuitState { 8 Closed, 9 Open, 10 HalfOpen, 11} 12 13pub struct CircuitBreaker { 14 name: String, 15 failure_threshold: u32, 16 success_threshold: u32, 17 timeout: Duration, 18 state: Arc<RwLock<CircuitState>>, 19 failure_count: AtomicU32, 20 success_count: AtomicU32, 21 last_failure_time: AtomicU64, 22} 23 24impl CircuitBreaker { 25 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 26 Self { 27 name: name.to_string(), 28 failure_threshold, 29 success_threshold, 30 timeout: Duration::from_secs(timeout_secs), 31 state: Arc::new(RwLock::new(CircuitState::Closed)), 32 failure_count: AtomicU32::new(0), 33 success_count: AtomicU32::new(0), 34 last_failure_time: AtomicU64::new(0), 35 } 36 } 37 38 pub async fn can_execute(&self) -> bool { 39 let state = self.state.read().await; 40 41 match *state { 42 CircuitState::Closed => true, 43 CircuitState::Open => { 44 let last_failure = self.last_failure_time.load(Ordering::SeqCst); 45 let now = std::time::SystemTime::now() 46 .duration_since(std::time::UNIX_EPOCH) 47 .unwrap() 48 .as_secs(); 49 50 if now - last_failure >= self.timeout.as_secs() { 51 drop(state); 52 let mut state = self.state.write().await; 53 if *state == CircuitState::Open { 54 *state = CircuitState::HalfOpen; 55 self.success_count.store(0, Ordering::SeqCst); 56 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open"); 57 return true; 58 } 59 } 60 false 61 } 62 CircuitState::HalfOpen => true, 63 } 64 } 65 66 pub async fn record_success(&self) { 67 let state = *self.state.read().await; 68 69 match state { 70 CircuitState::Closed => { 71 self.failure_count.store(0, Ordering::SeqCst); 72 } 73 CircuitState::HalfOpen => { 74 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1; 75 if count 
>= self.success_threshold { 76 let mut state = self.state.write().await; 77 *state = CircuitState::Closed; 78 self.failure_count.store(0, Ordering::SeqCst); 79 self.success_count.store(0, Ordering::SeqCst); 80 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery"); 81 } 82 } 83 CircuitState::Open => {} 84 } 85 } 86 87 pub async fn record_failure(&self) { 88 let state = *self.state.read().await; 89 90 match state { 91 CircuitState::Closed => { 92 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; 93 if count >= self.failure_threshold { 94 let mut state = self.state.write().await; 95 *state = CircuitState::Open; 96 let now = std::time::SystemTime::now() 97 .duration_since(std::time::UNIX_EPOCH) 98 .unwrap() 99 .as_secs(); 100 self.last_failure_time.store(now, Ordering::SeqCst); 101 tracing::warn!( 102 circuit = %self.name, 103 failures = count, 104 "Circuit breaker opened after {} failures", 105 count 106 ); 107 } 108 } 109 CircuitState::HalfOpen => { 110 let mut state = self.state.write().await; 111 *state = CircuitState::Open; 112 let now = std::time::SystemTime::now() 113 .duration_since(std::time::UNIX_EPOCH) 114 .unwrap() 115 .as_secs(); 116 self.last_failure_time.store(now, Ordering::SeqCst); 117 self.success_count.store(0, Ordering::SeqCst); 118 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state"); 119 } 120 CircuitState::Open => {} 121 } 122 } 123 124 pub async fn state(&self) -> CircuitState { 125 *self.state.read().await 126 } 127 128 pub fn name(&self) -> &str { 129 &self.name 130 } 131} 132 133#[derive(Clone)] 134pub struct CircuitBreakers { 135 pub plc_directory: Arc<CircuitBreaker>, 136 pub relay_notification: Arc<CircuitBreaker>, 137} 138 139impl Default for CircuitBreakers { 140 fn default() -> Self { 141 Self::new() 142 } 143} 144 145impl CircuitBreakers { 146 pub fn new() -> Self { 147 Self { 148 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 
5, 3, 60)), 149 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)), 150 } 151 } 152} 153 154#[derive(Debug)] 155pub struct CircuitOpenError { 156 pub circuit_name: String, 157} 158 159impl std::fmt::Display for CircuitOpenError { 160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 161 write!(f, "Circuit breaker '{}' is open", self.circuit_name) 162 } 163} 164 165impl std::error::Error for CircuitOpenError {} 166 167pub async fn with_circuit_breaker<T, E, F, Fut>( 168 circuit: &CircuitBreaker, 169 operation: F, 170) -> Result<T, CircuitBreakerError<E>> 171where 172 F: FnOnce() -> Fut, 173 Fut: std::future::Future<Output = Result<T, E>>, 174{ 175 if !circuit.can_execute().await { 176 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError { 177 circuit_name: circuit.name().to_string(), 178 })); 179 } 180 181 match operation().await { 182 Ok(result) => { 183 circuit.record_success().await; 184 Ok(result) 185 } 186 Err(e) => { 187 circuit.record_failure().await; 188 Err(CircuitBreakerError::OperationFailed(e)) 189 } 190 } 191} 192 193#[derive(Debug)] 194pub enum CircuitBreakerError<E> { 195 CircuitOpen(CircuitOpenError), 196 OperationFailed(E), 197} 198 199impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 201 match self { 202 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e), 203 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e), 204 } 205 } 206} 207 208impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 209 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 210 match self { 211 CircuitBreakerError::CircuitOpen(e) => Some(e), 212 CircuitBreakerError::OperationFailed(e) => Some(e), 213 } 214 } 215} 216 217#[cfg(test)] 218mod tests { 219 use super::*; 220 221 #[tokio::test] 222 async fn test_circuit_breaker_starts_closed() { 223 let cb 
= CircuitBreaker::new("test", 3, 2, 10); 224 assert_eq!(cb.state().await, CircuitState::Closed); 225 assert!(cb.can_execute().await); 226 } 227 228 #[tokio::test] 229 async fn test_circuit_breaker_opens_after_failures() { 230 let cb = CircuitBreaker::new("test", 3, 2, 10); 231 232 cb.record_failure().await; 233 assert_eq!(cb.state().await, CircuitState::Closed); 234 235 cb.record_failure().await; 236 assert_eq!(cb.state().await, CircuitState::Closed); 237 238 cb.record_failure().await; 239 assert_eq!(cb.state().await, CircuitState::Open); 240 assert!(!cb.can_execute().await); 241 } 242 243 #[tokio::test] 244 async fn test_circuit_breaker_success_resets_failures() { 245 let cb = CircuitBreaker::new("test", 3, 2, 10); 246 247 cb.record_failure().await; 248 cb.record_failure().await; 249 cb.record_success().await; 250 251 cb.record_failure().await; 252 cb.record_failure().await; 253 assert_eq!(cb.state().await, CircuitState::Closed); 254 255 cb.record_failure().await; 256 assert_eq!(cb.state().await, CircuitState::Open); 257 } 258 259 #[tokio::test] 260 async fn test_circuit_breaker_half_open_closes_after_successes() { 261 let cb = CircuitBreaker::new("test", 3, 2, 0); 262 263 for _ in 0..3 { 264 cb.record_failure().await; 265 } 266 assert_eq!(cb.state().await, CircuitState::Open); 267 268 tokio::time::sleep(Duration::from_millis(100)).await; 269 assert!(cb.can_execute().await); 270 assert_eq!(cb.state().await, CircuitState::HalfOpen); 271 272 cb.record_success().await; 273 assert_eq!(cb.state().await, CircuitState::HalfOpen); 274 275 cb.record_success().await; 276 assert_eq!(cb.state().await, CircuitState::Closed); 277 } 278 279 #[tokio::test] 280 async fn test_circuit_breaker_half_open_reopens_on_failure() { 281 let cb = CircuitBreaker::new("test", 3, 2, 0); 282 283 for _ in 0..3 { 284 cb.record_failure().await; 285 } 286 287 tokio::time::sleep(Duration::from_millis(100)).await; 288 cb.can_execute().await; 289 290 cb.record_failure().await; 291 
assert_eq!(cb.state().await, CircuitState::Open); 292 } 293 294 #[tokio::test] 295 async fn test_with_circuit_breaker_helper() { 296 let cb = CircuitBreaker::new("test", 3, 2, 10); 297 298 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 299 with_circuit_breaker(&cb, || async { Ok(42) }).await; 300 assert!(result.is_ok()); 301 assert_eq!(result.unwrap(), 42); 302 303 let result: Result<i32, CircuitBreakerError<&str>> = 304 with_circuit_breaker(&cb, || async { Err("error") }).await; 305 assert!(result.is_err()); 306 } 307}