this repo has no description
1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 2use std::sync::Arc; 3use std::time::Duration; 4use tokio::sync::RwLock; 5#[derive(Debug, Clone, Copy, PartialEq, Eq)] 6pub enum CircuitState { 7 Closed, 8 Open, 9 HalfOpen, 10} 11pub struct CircuitBreaker { 12 name: String, 13 failure_threshold: u32, 14 success_threshold: u32, 15 timeout: Duration, 16 state: Arc<RwLock<CircuitState>>, 17 failure_count: AtomicU32, 18 success_count: AtomicU32, 19 last_failure_time: AtomicU64, 20} 21impl CircuitBreaker { 22 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self { 23 Self { 24 name: name.to_string(), 25 failure_threshold, 26 success_threshold, 27 timeout: Duration::from_secs(timeout_secs), 28 state: Arc::new(RwLock::new(CircuitState::Closed)), 29 failure_count: AtomicU32::new(0), 30 success_count: AtomicU32::new(0), 31 last_failure_time: AtomicU64::new(0), 32 } 33 } 34 pub async fn can_execute(&self) -> bool { 35 let state = self.state.read().await; 36 match *state { 37 CircuitState::Closed => true, 38 CircuitState::Open => { 39 let last_failure = self.last_failure_time.load(Ordering::SeqCst); 40 let now = std::time::SystemTime::now() 41 .duration_since(std::time::UNIX_EPOCH) 42 .unwrap() 43 .as_secs(); 44 if now - last_failure >= self.timeout.as_secs() { 45 drop(state); 46 let mut state = self.state.write().await; 47 if *state == CircuitState::Open { 48 *state = CircuitState::HalfOpen; 49 self.success_count.store(0, Ordering::SeqCst); 50 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open"); 51 return true; 52 } 53 } 54 false 55 } 56 CircuitState::HalfOpen => true, 57 } 58 } 59 pub async fn record_success(&self) { 60 let state = *self.state.read().await; 61 match state { 62 CircuitState::Closed => { 63 self.failure_count.store(0, Ordering::SeqCst); 64 } 65 CircuitState::HalfOpen => { 66 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1; 67 if count >= self.success_threshold { 68 let mut state = self.state.write().await; 69 *state = CircuitState::Closed; 70 self.failure_count.store(0, Ordering::SeqCst); 71 self.success_count.store(0, Ordering::SeqCst); 72 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery"); 73 } 74 } 75 CircuitState::Open => {} 76 } 77 } 78 pub async fn record_failure(&self) { 79 let state = *self.state.read().await; 80 match state { 81 CircuitState::Closed => { 82 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1; 83 if count >= self.failure_threshold { 84 let mut state = self.state.write().await; 85 *state = CircuitState::Open; 86 let now = std::time::SystemTime::now() 87 .duration_since(std::time::UNIX_EPOCH) 88 .unwrap() 89 .as_secs(); 90 self.last_failure_time.store(now, Ordering::SeqCst); 91 tracing::warn!( 92 circuit = %self.name, 93 failures = count, 94 "Circuit breaker opened after {} failures", 95 count 96 ); 97 } 98 } 99 CircuitState::HalfOpen => { 100 let mut state = self.state.write().await; 101 *state = CircuitState::Open; 102 let now = std::time::SystemTime::now() 103 .duration_since(std::time::UNIX_EPOCH) 104 .unwrap() 105 .as_secs(); 106 self.last_failure_time.store(now, Ordering::SeqCst); 107 self.success_count.store(0, Ordering::SeqCst); 108 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state"); 109 } 110 CircuitState::Open => {} 111 } 112 } 113 pub async fn state(&self) -> CircuitState { 114 *self.state.read().await 115 } 116 pub fn name(&self) -> &str { 117 &self.name 118 } 119} 120#[derive(Clone)] 121pub struct CircuitBreakers { 122 pub plc_directory: Arc<CircuitBreaker>, 123 pub relay_notification: Arc<CircuitBreaker>, 124} 125impl Default for CircuitBreakers { 126 fn default() -> Self { 127 Self::new() 128 } 129} 130impl CircuitBreakers { 131 pub fn new() -> Self { 132 Self { 133 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)), 134 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)), 135 } 136 } 137} 138#[derive(Debug)] 139pub struct CircuitOpenError { 140 pub circuit_name: String, 141} 142impl std::fmt::Display for CircuitOpenError { 143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 144 write!(f, "Circuit breaker '{}' is open", self.circuit_name) 145 } 146} 147impl std::error::Error for CircuitOpenError {} 148pub async fn with_circuit_breaker<T, E, F, Fut>( 149 circuit: &CircuitBreaker, 150 operation: F, 151) -> Result<T, CircuitBreakerError<E>> 152where 153 F: FnOnce() -> Fut, 154 Fut: std::future::Future<Output = Result<T, E>>, 155{ 156 if !circuit.can_execute().await { 157 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError { 158 circuit_name: circuit.name().to_string(), 159 })); 160 } 161 match operation().await { 162 Ok(result) => { 163 circuit.record_success().await; 164 Ok(result) 165 } 166 Err(e) => { 167 circuit.record_failure().await; 168 Err(CircuitBreakerError::OperationFailed(e)) 169 } 170 } 171} 172#[derive(Debug)] 173pub enum CircuitBreakerError<E> { 174 CircuitOpen(CircuitOpenError), 175 OperationFailed(E), 176} 177impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> { 178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 179 match self { 180 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e), 181 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e), 182 } 183 } 184} 185impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> { 186 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 187 match self { 188 CircuitBreakerError::CircuitOpen(e) => Some(e), 189 CircuitBreakerError::OperationFailed(e) => Some(e), 190 } 191 } 192} 193#[cfg(test)] 194mod tests { 195 use super::*; 196 #[tokio::test] 197 async fn test_circuit_breaker_starts_closed() { 198 let cb = CircuitBreaker::new("test", 3, 2, 10); 199 assert_eq!(cb.state().await, CircuitState::Closed); 200 assert!(cb.can_execute().await); 201 } 202 #[tokio::test] 203 async fn test_circuit_breaker_opens_after_failures() { 204 let cb = CircuitBreaker::new("test", 3, 2, 10); 205 cb.record_failure().await; 206 assert_eq!(cb.state().await, CircuitState::Closed); 207 cb.record_failure().await; 208 assert_eq!(cb.state().await, CircuitState::Closed); 209 cb.record_failure().await; 210 assert_eq!(cb.state().await, CircuitState::Open); 211 assert!(!cb.can_execute().await); 212 } 213 #[tokio::test] 214 async fn test_circuit_breaker_success_resets_failures() { 215 let cb = CircuitBreaker::new("test", 3, 2, 10); 216 cb.record_failure().await; 217 cb.record_failure().await; 218 cb.record_success().await; 219 cb.record_failure().await; 220 cb.record_failure().await; 221 assert_eq!(cb.state().await, CircuitState::Closed); 222 cb.record_failure().await; 223 assert_eq!(cb.state().await, CircuitState::Open); 224 } 225 #[tokio::test] 226 async fn test_circuit_breaker_half_open_closes_after_successes() { 227 let cb = CircuitBreaker::new("test", 3, 2, 0); 228 for _ in 0..3 { 229 cb.record_failure().await; 230 } 231 assert_eq!(cb.state().await, CircuitState::Open); 232 tokio::time::sleep(Duration::from_millis(100)).await; 233 assert!(cb.can_execute().await); 234 assert_eq!(cb.state().await, CircuitState::HalfOpen); 235 cb.record_success().await; 236 assert_eq!(cb.state().await, CircuitState::HalfOpen); 237 cb.record_success().await; 238 assert_eq!(cb.state().await, CircuitState::Closed); 239 } 240 #[tokio::test] 241 async fn test_circuit_breaker_half_open_reopens_on_failure() { 242 let cb = CircuitBreaker::new("test", 3, 2, 0); 243 for _ in 0..3 { 244 cb.record_failure().await; 245 } 246 tokio::time::sleep(Duration::from_millis(100)).await; 247 cb.can_execute().await; 248 cb.record_failure().await; 249 assert_eq!(cb.state().await, CircuitState::Open); 250 } 251 #[tokio::test] 252 async fn test_with_circuit_breaker_helper() { 253 let cb = CircuitBreaker::new("test", 3, 2, 10); 254 let result: Result<i32, CircuitBreakerError<std::io::Error>> = 255 with_circuit_breaker(&cb, || async { Ok(42) }).await; 256 assert!(result.is_ok()); 257 assert_eq!(result.unwrap(), 42); 258 let result: Result<i32, CircuitBreakerError<&str>> = 259 with_circuit_breaker(&cb, || async { Err("error") }).await; 260 assert!(result.is_err()); 261 } 262}