this repo has no description
1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 2use std::sync::Arc; 3use std::time::Duration; 4use tokio::sync::RwLock; 5 6#[derive(Debug, Clone, Copy, PartialEq, Eq)] 7pub enum CircuitState { 8 Closed, 9 Open, 10 HalfOpen, 11} 12 13pub struct CircuitBreaker { 14 name: String, 15 failure_threshold: u32, 16 success_threshold: u32, 17 timeout: Duration, 18 state: Arc<RwLock<CircuitState>>, 19 failure_count: AtomicU32, 20 success_count: AtomicU32, 21 last_failure_time: AtomicU64, 22} 23 24impl CircuitBreaker { 25 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 26 Self { 27 name: name.to_string(), 28 failure_threshold, 29 success_threshold, 30 timeout: Duration::from_secs(timeout_secs), 31 state: Arc::new(RwLock::new(CircuitState::Closed)), 32 failure_count: AtomicU32::new(0), 33 success_count: AtomicU32::new(0), 34 last_failure_time: AtomicU64::new(0), 35 } 36 } 37 38 pub async fn can_execute(&self) -> bool { 39 let state = self.state.read().await; 40 match *state { 41 CircuitState::Closed => true, 42 CircuitState::Open => { 43 let last_failure = self.last_failure_time.load(Ordering::SeqCst); 44 let now = std::time::SystemTime::now() 45 .duration_since(std::time::UNIX_EPOCH) 46 .unwrap() 47 .as_secs(); 48 49 if now - last_failure >= self.timeout.as_secs() { 50 drop(state); 51 let mut state = self.state.write().await; 52 if *state == CircuitState::Open { 53 *state = CircuitState::HalfOpen; 54 self.success_count.store(0, Ordering::SeqCst); 55 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open"); 56 return true; 57 } 58 } 59 false 60 } 61 CircuitState::HalfOpen => true, 62 } 63 } 64 65 pub async fn record_success(&self) { 66 let state = *self.state.read().await; 67 68 match state { 69 CircuitState::Closed => { 70 self.failure_count.store(0, Ordering::SeqCst); 71 } 72 CircuitState::HalfOpen => { 73 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1; 74 if count >= self.success_threshold { 75 let mut state = self.state.write().await; 76 *state = CircuitState::Closed; 77 self.failure_count.store(0, Ordering::SeqCst); 78 self.success_count.store(0, Ordering::SeqCst); 79 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery"); 80 } 81 } 82 CircuitState::Open => {} 83 } 84 } 85 86 pub async fn record_failure(&self) { 87 let state = *self.state.read().await; 88 89 match state { 90 CircuitState::Closed => { 91 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; 92 if count >= self.failure_threshold { 93 let mut state = self.state.write().await; 94 *state = CircuitState::Open; 95 let now = std::time::SystemTime::now() 96 .duration_since(std::time::UNIX_EPOCH) 97 .unwrap() 98 .as_secs(); 99 self.last_failure_time.store(now, Ordering::SeqCst); 100 tracing::warn!( 101 circuit = %self.name, 102 failures = count, 103 "Circuit breaker opened after {} failures", 104 count 105 ); 106 } 107 } 108 CircuitState::HalfOpen => { 109 let mut state = self.state.write().await; 110 *state = CircuitState::Open; 111 let now = std::time::SystemTime::now() 112 .duration_since(std::time::UNIX_EPOCH) 113 .unwrap() 114 .as_secs(); 115 self.last_failure_time.store(now, Ordering::SeqCst); 116 self.success_count.store(0, Ordering::SeqCst); 117 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state"); 118 } 119 CircuitState::Open => {} 120 } 121 } 122 123 pub async fn state(&self) -> CircuitState { 124 *self.state.read().await 125 } 126 127 pub fn name(&self) -> &str { 128 &self.name 129 } 130} 131 132#[derive(Clone)] 133pub struct CircuitBreakers { 134 pub plc_directory: Arc<CircuitBreaker>, 135 pub relay_notification: Arc<CircuitBreaker>, 136} 137 138impl Default for CircuitBreakers { 139 fn default() -> Self { 140 Self::new() 141 } 142} 143 144impl CircuitBreakers { 145 pub fn new() -> Self { 146 Self { 147 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)), 148 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)), 149 } 150 } 151} 152 153#[derive(Debug)] 154pub struct CircuitOpenError { 155 pub circuit_name: String, 156} 157 158impl std::fmt::Display for CircuitOpenError { 159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 160 write!(f, "Circuit breaker '{}' is open", self.circuit_name) 161 } 162} 163 164impl std::error::Error for CircuitOpenError {} 165 166pub async fn with_circuit_breaker<T, E, F, Fut>( 167 circuit: &CircuitBreaker, 168 operation: F, 169) -> Result<T, CircuitBreakerError<E>> 170where 171 F: FnOnce() -> Fut, 172 Fut: std::future::Future<Output = Result<T, E>>, 173{ 174 if !circuit.can_execute().await { 175 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError { 176 circuit_name: circuit.name().to_string(), 177 })); 178 } 179 180 match operation().await { 181 Ok(result) => { 182 circuit.record_success().await; 183 Ok(result) 184 } 185 Err(e) => { 186 circuit.record_failure().await; 187 Err(CircuitBreakerError::OperationFailed(e)) 188 } 189 } 190} 191 192#[derive(Debug)] 193pub enum CircuitBreakerError<E> { 194 CircuitOpen(CircuitOpenError), 195 OperationFailed(E), 196} 197 198impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 200 match self { 201 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e), 202 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e), 203 } 204 } 205} 206 207impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 208 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 209 match self { 210 CircuitBreakerError::CircuitOpen(e) => Some(e), 211 CircuitBreakerError::OperationFailed(e) => Some(e), 212 } 213 } 214} 215 216#[cfg(test)] 217mod tests { 218 use super::*; 219 220 #[tokio::test] 221 async fn test_circuit_breaker_starts_closed() { 222 let cb = CircuitBreaker::new("test", 3, 2, 10); 223 assert_eq!(cb.state().await, CircuitState::Closed); 224 assert!(cb.can_execute().await); 225 } 226 227 #[tokio::test] 228 async fn test_circuit_breaker_opens_after_failures() { 229 let cb = CircuitBreaker::new("test", 3, 2, 10); 230 231 cb.record_failure().await; 232 assert_eq!(cb.state().await, CircuitState::Closed); 233 234 cb.record_failure().await; 235 assert_eq!(cb.state().await, CircuitState::Closed); 236 237 cb.record_failure().await; 238 assert_eq!(cb.state().await, CircuitState::Open); 239 assert!(!cb.can_execute().await); 240 } 241 242 #[tokio::test] 243 async fn test_circuit_breaker_success_resets_failures() { 244 let cb = CircuitBreaker::new("test", 3, 2, 10); 245 246 cb.record_failure().await; 247 cb.record_failure().await; 248 cb.record_success().await; 249 250 cb.record_failure().await; 251 cb.record_failure().await; 252 assert_eq!(cb.state().await, CircuitState::Closed); 253 254 cb.record_failure().await; 255 assert_eq!(cb.state().await, CircuitState::Open); 256 } 257 258 #[tokio::test] 259 async fn test_circuit_breaker_half_open_closes_after_successes() { 260 let cb = CircuitBreaker::new("test", 3, 2, 0); 261 262 for _ in 0..3 { 263 cb.record_failure().await; 264 } 265 assert_eq!(cb.state().await, CircuitState::Open); 266 267 tokio::time::sleep(Duration::from_millis(100)).await; 268 269 assert!(cb.can_execute().await); 270 assert_eq!(cb.state().await, CircuitState::HalfOpen); 271 272 cb.record_success().await; 273 assert_eq!(cb.state().await, CircuitState::HalfOpen); 274 275 cb.record_success().await; 276 assert_eq!(cb.state().await, CircuitState::Closed); 277 } 278 279 #[tokio::test] 280 async fn test_circuit_breaker_half_open_reopens_on_failure() { 281 let cb = CircuitBreaker::new("test", 3, 2, 0); 282 283 for _ in 0..3 { 284 cb.record_failure().await; 285 } 286 287 tokio::time::sleep(Duration::from_millis(100)).await; 288 cb.can_execute().await; 289 290 cb.record_failure().await; 291 assert_eq!(cb.state().await, CircuitState::Open); 292 } 293 294 #[tokio::test] 295 async fn test_with_circuit_breaker_helper() { 296 let cb = CircuitBreaker::new("test", 3, 2, 10); 297 298 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 299 with_circuit_breaker(&cb, || async { Ok(42) }).await; 300 assert!(result.is_ok()); 301 assert_eq!(result.unwrap(), 42); 302 303 let result: Result<i32, CircuitBreakerError<&str>> = 304 with_circuit_breaker(&cb, || async { Err("error") }).await; 305 assert!(result.is_err()); 306 } 307}