this repo has no description
1use std::sync::Arc; 2use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 3use std::time::Duration; 4use tokio::sync::RwLock; 5 6#[derive(Debug, Clone, Copy, PartialEq, Eq)] 7pub enum CircuitState { 8 Closed, 9 Open, 10 HalfOpen, 11} 12 13pub struct CircuitBreaker { 14 name: String, 15 failure_threshold: u32, 16 success_threshold: u32, 17 timeout: Duration, 18 state: Arc<RwLock<CircuitState>>, 19 failure_count: AtomicU32, 20 success_count: AtomicU32, 21 last_failure_time: AtomicU64, 22} 23 24impl CircuitBreaker { 25 pub fn new( 26 name: &str, 27 failure_threshold: u32, 28 success_threshold: u32, 29 timeout_secs: u64, 30 ) -> Self { 31 Self { 32 name: name.to_string(), 33 failure_threshold, 34 success_threshold, 35 timeout: Duration::from_secs(timeout_secs), 36 state: Arc::new(RwLock::new(CircuitState::Closed)), 37 failure_count: AtomicU32::new(0), 38 success_count: AtomicU32::new(0), 39 last_failure_time: AtomicU64::new(0), 40 } 41 } 42 43 pub async fn can_execute(&self) -> bool { 44 let state = self.state.read().await; 45 46 match *state { 47 CircuitState::Closed => true, 48 CircuitState::Open => { 49 let last_failure = self.last_failure_time.load(Ordering::SeqCst); 50 let now = std::time::SystemTime::now() 51 .duration_since(std::time::UNIX_EPOCH) 52 .unwrap() 53 .as_secs(); 54 55 if now - last_failure >= self.timeout.as_secs() { 56 drop(state); 57 let mut state = self.state.write().await; 58 if *state == CircuitState::Open { 59 *state = CircuitState::HalfOpen; 60 self.success_count.store(0, Ordering::SeqCst); 61 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open"); 62 return true; 63 } 64 } 65 false 66 } 67 CircuitState::HalfOpen => true, 68 } 69 } 70 71 pub async fn record_success(&self) { 72 let state = *self.state.read().await; 73 74 match state { 75 CircuitState::Closed => { 76 self.failure_count.store(0, Ordering::SeqCst); 77 } 78 CircuitState::HalfOpen => { 79 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1; 80 if count >= self.success_threshold { 81 let mut state = self.state.write().await; 82 *state = CircuitState::Closed; 83 self.failure_count.store(0, Ordering::SeqCst); 84 self.success_count.store(0, Ordering::SeqCst); 85 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery"); 86 } 87 } 88 CircuitState::Open => {} 89 } 90 } 91 92 pub async fn record_failure(&self) { 93 let state = *self.state.read().await; 94 95 match state { 96 CircuitState::Closed => { 97 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; 98 if count >= self.failure_threshold { 99 let mut state = self.state.write().await; 100 *state = CircuitState::Open; 101 let now = std::time::SystemTime::now() 102 .duration_since(std::time::UNIX_EPOCH) 103 .unwrap() 104 .as_secs(); 105 self.last_failure_time.store(now, Ordering::SeqCst); 106 tracing::warn!( 107 circuit = %self.name, 108 failures = count, 109 "Circuit breaker opened after {} failures", 110 count 111 ); 112 } 113 } 114 CircuitState::HalfOpen => { 115 let mut state = self.state.write().await; 116 *state = CircuitState::Open; 117 let now = std::time::SystemTime::now() 118 .duration_since(std::time::UNIX_EPOCH) 119 .unwrap() 120 .as_secs(); 121 self.last_failure_time.store(now, Ordering::SeqCst); 122 self.success_count.store(0, Ordering::SeqCst); 123 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state"); 124 } 125 CircuitState::Open => {} 126 } 127 } 128 129 pub async fn state(&self) -> CircuitState { 130 *self.state.read().await 131 } 132 133 pub fn name(&self) -> &str { 134 &self.name 135 } 136} 137 138#[derive(Clone)] 139pub struct CircuitBreakers { 140 pub plc_directory: Arc<CircuitBreaker>, 141 pub relay_notification: Arc<CircuitBreaker>, 142} 143 144impl Default for CircuitBreakers { 145 fn default() -> Self { 146 Self::new() 147 } 148} 149 150impl CircuitBreakers { 151 pub fn new() -> Self { 152 Self { 153 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)), 154 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)), 155 } 156 } 157} 158 159#[derive(Debug)] 160pub struct CircuitOpenError { 161 pub circuit_name: String, 162} 163 164impl std::fmt::Display for CircuitOpenError { 165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 166 write!(f, "Circuit breaker '{}' is open", self.circuit_name) 167 } 168} 169 170impl std::error::Error for CircuitOpenError {} 171 172pub async fn with_circuit_breaker<T, E, F, Fut>( 173 circuit: &CircuitBreaker, 174 operation: F, 175) -> Result<T, CircuitBreakerError<E>> 176where 177 F: FnOnce() -> Fut, 178 Fut: std::future::Future<Output = Result<T, E>>, 179{ 180 if !circuit.can_execute().await { 181 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError { 182 circuit_name: circuit.name().to_string(), 183 })); 184 } 185 186 match operation().await { 187 Ok(result) => { 188 circuit.record_success().await; 189 Ok(result) 190 } 191 Err(e) => { 192 circuit.record_failure().await; 193 Err(CircuitBreakerError::OperationFailed(e)) 194 } 195 } 196} 197 198#[derive(Debug)] 199pub enum CircuitBreakerError<E> { 200 CircuitOpen(CircuitOpenError), 201 OperationFailed(E), 202} 203 204impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 206 match self { 207 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e), 208 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e), 209 } 210 } 211} 212 213impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 214 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 215 match self { 216 CircuitBreakerError::CircuitOpen(e) => Some(e), 217 CircuitBreakerError::OperationFailed(e) => Some(e), 218 } 219 } 220} 221 222#[cfg(test)] 223mod tests { 224 use super::*; 225 226 #[tokio::test] 227 async fn test_circuit_breaker_starts_closed() { 228 let cb = CircuitBreaker::new("test", 3, 2, 10); 229 assert_eq!(cb.state().await, CircuitState::Closed); 230 assert!(cb.can_execute().await); 231 } 232 233 #[tokio::test] 234 async fn test_circuit_breaker_opens_after_failures() { 235 let cb = CircuitBreaker::new("test", 3, 2, 10); 236 237 cb.record_failure().await; 238 assert_eq!(cb.state().await, CircuitState::Closed); 239 240 cb.record_failure().await; 241 assert_eq!(cb.state().await, CircuitState::Closed); 242 243 cb.record_failure().await; 244 assert_eq!(cb.state().await, CircuitState::Open); 245 assert!(!cb.can_execute().await); 246 } 247 248 #[tokio::test] 249 async fn test_circuit_breaker_success_resets_failures() { 250 let cb = CircuitBreaker::new("test", 3, 2, 10); 251 252 cb.record_failure().await; 253 cb.record_failure().await; 254 cb.record_success().await; 255 256 cb.record_failure().await; 257 cb.record_failure().await; 258 assert_eq!(cb.state().await, CircuitState::Closed); 259 260 cb.record_failure().await; 261 assert_eq!(cb.state().await, CircuitState::Open); 262 } 263 264 #[tokio::test] 265 async fn test_circuit_breaker_half_open_closes_after_successes() { 266 let cb = CircuitBreaker::new("test", 3, 2, 0); 267 268 for _ in 0..3 { 269 cb.record_failure().await; 270 } 271 assert_eq!(cb.state().await, CircuitState::Open); 272 273 tokio::time::sleep(Duration::from_millis(100)).await; 274 assert!(cb.can_execute().await); 275 assert_eq!(cb.state().await, CircuitState::HalfOpen); 276 277 cb.record_success().await; 278 assert_eq!(cb.state().await, CircuitState::HalfOpen); 279 280 cb.record_success().await; 281 assert_eq!(cb.state().await, CircuitState::Closed); 282 } 283 284 #[tokio::test] 285 async fn test_circuit_breaker_half_open_reopens_on_failure() { 286 let cb = CircuitBreaker::new("test", 3, 2, 0); 287 288 for _ in 0..3 { 289 cb.record_failure().await; 290 } 291 292 tokio::time::sleep(Duration::from_millis(100)).await; 293 cb.can_execute().await; 294 295 cb.record_failure().await; 296 assert_eq!(cb.state().await, CircuitState::Open); 297 } 298 299 #[tokio::test] 300 async fn test_with_circuit_breaker_helper() { 301 let cb = CircuitBreaker::new("test", 3, 2, 10); 302 303 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 304 with_circuit_breaker(&cb, || async { Ok(42) }).await; 305 assert!(result.is_ok()); 306 assert_eq!(result.unwrap(), 42); 307 308 let result: Result<i32, CircuitBreakerError<&str>> = 309 with_circuit_breaker(&cb, || async { Err("error") }).await; 310 assert!(result.is_err()); 311 } 312}