this repo has no description
1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::sync::RwLock;
/// The three states a [`CircuitBreaker`] can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: calls are allowed and consecutive failures are counted.
    Closed,
    /// Tripped: calls are rejected until the open timeout has elapsed.
    Open,
    /// Recovery probing: calls are allowed; enough successes close the circuit,
    /// a failure re-opens it.
    HalfOpen,
}
11pub struct CircuitBreaker {
12 name: String,
13 failure_threshold: u32,
14 success_threshold: u32,
15 timeout: Duration,
16 state: Arc<RwLock<CircuitState>>,
17 failure_count: AtomicU32,
18 success_count: AtomicU32,
19 last_failure_time: AtomicU64,
20}
21impl CircuitBreaker {
22 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
23 Self {
24 name: name.to_string(),
25 failure_threshold,
26 success_threshold,
27 timeout: Duration::from_secs(timeout_secs),
28 state: Arc::new(RwLock::new(CircuitState::Closed)),
29 failure_count: AtomicU32::new(0),
30 success_count: AtomicU32::new(0),
31 last_failure_time: AtomicU64::new(0),
32 }
33 }
34 pub async fn can_execute(&self) -> bool {
35 let state = self.state.read().await;
36 match *state {
37 CircuitState::Closed => true,
38 CircuitState::Open => {
39 let last_failure = self.last_failure_time.load(Ordering::SeqCst);
40 let now = std::time::SystemTime::now()
41 .duration_since(std::time::UNIX_EPOCH)
42 .unwrap()
43 .as_secs();
44 if now - last_failure >= self.timeout.as_secs() {
45 drop(state);
46 let mut state = self.state.write().await;
47 if *state == CircuitState::Open {
48 *state = CircuitState::HalfOpen;
49 self.success_count.store(0, Ordering::SeqCst);
50 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open");
51 return true;
52 }
53 }
54 false
55 }
56 CircuitState::HalfOpen => true,
57 }
58 }
59 pub async fn record_success(&self) {
60 let state = *self.state.read().await;
61 match state {
62 CircuitState::Closed => {
63 self.failure_count.store(0, Ordering::SeqCst);
64 }
65 CircuitState::HalfOpen => {
66 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
67 if count >= self.success_threshold {
68 let mut state = self.state.write().await;
69 *state = CircuitState::Closed;
70 self.failure_count.store(0, Ordering::SeqCst);
71 self.success_count.store(0, Ordering::SeqCst);
72 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
73 }
74 }
75 CircuitState::Open => {}
76 }
77 }
78 pub async fn record_failure(&self) {
79 let state = *self.state.read().await;
80 match state {
81 CircuitState::Closed => {
82 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
83 if count >= self.failure_threshold {
84 let mut state = self.state.write().await;
85 *state = CircuitState::Open;
86 let now = std::time::SystemTime::now()
87 .duration_since(std::time::UNIX_EPOCH)
88 .unwrap()
89 .as_secs();
90 self.last_failure_time.store(now, Ordering::SeqCst);
91 tracing::warn!(
92 circuit = %self.name,
93 failures = count,
94 "Circuit breaker opened after {} failures",
95 count
96 );
97 }
98 }
99 CircuitState::HalfOpen => {
100 let mut state = self.state.write().await;
101 *state = CircuitState::Open;
102 let now = std::time::SystemTime::now()
103 .duration_since(std::time::UNIX_EPOCH)
104 .unwrap()
105 .as_secs();
106 self.last_failure_time.store(now, Ordering::SeqCst);
107 self.success_count.store(0, Ordering::SeqCst);
108 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
109 }
110 CircuitState::Open => {}
111 }
112 }
113 pub async fn state(&self) -> CircuitState {
114 *self.state.read().await
115 }
116 pub fn name(&self) -> &str {
117 &self.name
118 }
119}
120#[derive(Clone)]
121pub struct CircuitBreakers {
122 pub plc_directory: Arc<CircuitBreaker>,
123 pub relay_notification: Arc<CircuitBreaker>,
124}
125impl Default for CircuitBreakers {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130impl CircuitBreakers {
131 pub fn new() -> Self {
132 Self {
133 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
134 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
135 }
136 }
137}
/// Returned when a call is rejected because its circuit breaker is open.
#[derive(Debug)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the call.
    pub circuit_name: String,
}

impl std::fmt::Display for CircuitOpenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.circuit_name.as_str();
        write!(f, "Circuit breaker '{}' is open", name)
    }
}

// No underlying cause: `source()` keeps its default of `None`.
impl std::error::Error for CircuitOpenError {}
148pub async fn with_circuit_breaker<T, E, F, Fut>(
149 circuit: &CircuitBreaker,
150 operation: F,
151) -> Result<T, CircuitBreakerError<E>>
152where
153 F: FnOnce() -> Fut,
154 Fut: std::future::Future<Output = Result<T, E>>,
155{
156 if !circuit.can_execute().await {
157 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError {
158 circuit_name: circuit.name().to_string(),
159 }));
160 }
161 match operation().await {
162 Ok(result) => {
163 circuit.record_success().await;
164 Ok(result)
165 }
166 Err(e) => {
167 circuit.record_failure().await;
168 Err(CircuitBreakerError::OperationFailed(e))
169 }
170 }
171}
172#[derive(Debug)]
173pub enum CircuitBreakerError<E> {
174 CircuitOpen(CircuitOpenError),
175 OperationFailed(E),
176}
177impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 match self {
180 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
181 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
182 }
183 }
184}
185impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
186 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
187 match self {
188 CircuitBreakerError::CircuitOpen(e) => Some(e),
189 CircuitBreakerError::OperationFailed(e) => Some(e),
190 }
191 }
192}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
        assert!(!cb.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failures() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        cb.record_failure().await;
        cb.record_failure().await;
        cb.record_success().await;
        cb.record_failure().await;
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_closes_after_successes() {
        // A zero-second timeout has always elapsed by the next `can_execute`,
        // so no sleep is needed to leave the Open state.
        let cb = CircuitBreaker::new("test", 3, 2, 0);
        for _ in 0..3 {
            cb.record_failure().await;
        }
        assert_eq!(cb.state().await, CircuitState::Open);
        assert!(cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_reopens_on_failure() {
        let cb = CircuitBreaker::new("test", 3, 2, 0);
        for _ in 0..3 {
            cb.record_failure().await;
        }
        // Assert the transition instead of discarding it, so the failure below is
        // really recorded in the half-open state.
        assert!(cb.can_execute().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_open_ignores_outcomes() {
        let cb = CircuitBreaker::new("test", 1, 1, 60);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Open);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);
        assert!(!cb.can_execute().await);
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_helper() {
        let cb = CircuitBreaker::new("test", 3, 2, 10);
        let result: Result<i32, CircuitBreakerError<std::io::Error>> =
            with_circuit_breaker(&cb, || async { Ok(42) }).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);

        let result: Result<i32, CircuitBreakerError<&str>> =
            with_circuit_breaker(&cb, || async { Err("error") }).await;
        match result {
            Err(err @ CircuitBreakerError::OperationFailed("error")) => {
                assert_eq!(err.to_string(), "Operation failed: error");
            }
            other => panic!("expected OperationFailed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_rejects_when_open() {
        let cb = CircuitBreaker::new("test", 1, 1, 60);
        let result: Result<(), CircuitBreakerError<&str>> =
            with_circuit_breaker(&cb, || async { Err("boom") }).await;
        assert!(matches!(
            result,
            Err(CircuitBreakerError::OperationFailed("boom"))
        ));
        assert_eq!(cb.state().await, CircuitState::Open);

        let called = AtomicBool::new(false);
        let result: Result<(), CircuitBreakerError<&str>> = with_circuit_breaker(&cb, || {
            called.store(true, Ordering::SeqCst);
            async { Ok(()) }
        })
        .await;
        assert!(
            !called.load(Ordering::SeqCst),
            "operation must not run while the circuit is open"
        );
        match result {
            Err(CircuitBreakerError::CircuitOpen(err)) => {
                assert_eq!(err.circuit_name, "test");
                assert_eq!(err.to_string(), "Circuit breaker 'test' is open");
            }
            other => panic!("expected CircuitOpen, got {:?}", other),
        }
    }
}