this repo has no description
1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::sync::RwLock;
5
/// The three positions a [`CircuitBreaker`] can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are allowed; failures are counted towards the failure threshold.
    Closed,
    /// Calls are rejected until the configured timeout has passed.
    Open,
    /// Calls are allowed as probes: enough successes close the circuit,
    /// a failure opens it again.
    HalfOpen,
}
12
13pub struct CircuitBreaker {
14 name: String,
15 failure_threshold: u32,
16 success_threshold: u32,
17 timeout: Duration,
18 state: Arc<RwLock<CircuitState>>,
19 failure_count: AtomicU32,
20 success_count: AtomicU32,
21 last_failure_time: AtomicU64,
22}
23
24impl CircuitBreaker {
25 pub fn new(name: &str, failure_threshold: u32, success_threshold: u32, timeout_secs: u64) -> Self {
26 Self {
27 name: name.to_string(),
28 failure_threshold,
29 success_threshold,
30 timeout: Duration::from_secs(timeout_secs),
31 state: Arc::new(RwLock::new(CircuitState::Closed)),
32 failure_count: AtomicU32::new(0),
33 success_count: AtomicU32::new(0),
34 last_failure_time: AtomicU64::new(0),
35 }
36 }
37
38 pub async fn can_execute(&self) -> bool {
39 let state = self.state.read().await;
40 match *state {
41 CircuitState::Closed => true,
42 CircuitState::Open => {
43 let last_failure = self.last_failure_time.load(Ordering::SeqCst);
44 let now = std::time::SystemTime::now()
45 .duration_since(std::time::UNIX_EPOCH)
46 .unwrap()
47 .as_secs();
48
49 if now - last_failure >= self.timeout.as_secs() {
50 drop(state);
51 let mut state = self.state.write().await;
52 if *state == CircuitState::Open {
53 *state = CircuitState::HalfOpen;
54 self.success_count.store(0, Ordering::SeqCst);
55 tracing::info!(circuit = %self.name, "Circuit breaker transitioning to half-open");
56 return true;
57 }
58 }
59 false
60 }
61 CircuitState::HalfOpen => true,
62 }
63 }
64
65 pub async fn record_success(&self) {
66 let state = *self.state.read().await;
67
68 match state {
69 CircuitState::Closed => {
70 self.failure_count.store(0, Ordering::SeqCst);
71 }
72 CircuitState::HalfOpen => {
73 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
74 if count >= self.success_threshold {
75 let mut state = self.state.write().await;
76 *state = CircuitState::Closed;
77 self.failure_count.store(0, Ordering::SeqCst);
78 self.success_count.store(0, Ordering::SeqCst);
79 tracing::info!(circuit = %self.name, "Circuit breaker closed after successful recovery");
80 }
81 }
82 CircuitState::Open => {}
83 }
84 }
85
86 pub async fn record_failure(&self) {
87 let state = *self.state.read().await;
88
89 match state {
90 CircuitState::Closed => {
91 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
92 if count >= self.failure_threshold {
93 let mut state = self.state.write().await;
94 *state = CircuitState::Open;
95 let now = std::time::SystemTime::now()
96 .duration_since(std::time::UNIX_EPOCH)
97 .unwrap()
98 .as_secs();
99 self.last_failure_time.store(now, Ordering::SeqCst);
100 tracing::warn!(
101 circuit = %self.name,
102 failures = count,
103 "Circuit breaker opened after {} failures",
104 count
105 );
106 }
107 }
108 CircuitState::HalfOpen => {
109 let mut state = self.state.write().await;
110 *state = CircuitState::Open;
111 let now = std::time::SystemTime::now()
112 .duration_since(std::time::UNIX_EPOCH)
113 .unwrap()
114 .as_secs();
115 self.last_failure_time.store(now, Ordering::SeqCst);
116 self.success_count.store(0, Ordering::SeqCst);
117 tracing::warn!(circuit = %self.name, "Circuit breaker reopened after failure in half-open state");
118 }
119 CircuitState::Open => {}
120 }
121 }
122
123 pub async fn state(&self) -> CircuitState {
124 *self.state.read().await
125 }
126
127 pub fn name(&self) -> &str {
128 &self.name
129 }
130}
131
132#[derive(Clone)]
133pub struct CircuitBreakers {
134 pub plc_directory: Arc<CircuitBreaker>,
135 pub relay_notification: Arc<CircuitBreaker>,
136}
137
138impl Default for CircuitBreakers {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144impl CircuitBreakers {
145 pub fn new() -> Self {
146 Self {
147 plc_directory: Arc::new(CircuitBreaker::new("plc_directory", 5, 3, 60)),
148 relay_notification: Arc::new(CircuitBreaker::new("relay_notification", 10, 5, 30)),
149 }
150 }
151}
152
/// Error reported when a call is rejected because its circuit is open.
#[derive(Debug)]
pub struct CircuitOpenError {
    /// Name of the circuit that rejected the call.
    pub circuit_name: String,
}

impl std::fmt::Display for CircuitOpenError {
    /// Renders as `Circuit breaker '<name>' is open`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { circuit_name } = self;
        f.write_fmt(format_args!("Circuit breaker '{}' is open", circuit_name))
    }
}

// No underlying cause: the default `source()` (returning `None`) is kept.
impl std::error::Error for CircuitOpenError {}
165
166pub async fn with_circuit_breaker<T, E, F, Fut>(
167 circuit: &CircuitBreaker,
168 operation: F,
169) -> Result<T, CircuitBreakerError<E>>
170where
171 F: FnOnce() -> Fut,
172 Fut: std::future::Future<Output = Result<T, E>>,
173{
174 if !circuit.can_execute().await {
175 return Err(CircuitBreakerError::CircuitOpen(CircuitOpenError {
176 circuit_name: circuit.name().to_string(),
177 }));
178 }
179
180 match operation().await {
181 Ok(result) => {
182 circuit.record_success().await;
183 Ok(result)
184 }
185 Err(e) => {
186 circuit.record_failure().await;
187 Err(CircuitBreakerError::OperationFailed(e))
188 }
189 }
190}
191
192#[derive(Debug)]
193pub enum CircuitBreakerError<E> {
194 CircuitOpen(CircuitOpenError),
195 OperationFailed(E),
196}
197
198impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 match self {
201 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
202 CircuitBreakerError::OperationFailed(e) => write!(f, "Operation failed: {}", e),
203 }
204 }
205}
206
207impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
208 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
209 match self {
210 CircuitBreakerError::CircuitOpen(e) => Some(e),
211 CircuitBreakerError::OperationFailed(e) => Some(e),
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `times` consecutive failures on `breaker`.
    async fn fail_repeatedly(breaker: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            breaker.record_failure().await;
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let breaker = CircuitBreaker::new("test", 3, 2, 10);

        assert_eq!(breaker.state().await, CircuitState::Closed);
        assert!(breaker.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let breaker = CircuitBreaker::new("test", 3, 2, 10);

        // The first two failures stay below the threshold of three.
        for _ in 0..2 {
            breaker.record_failure().await;
            assert_eq!(breaker.state().await, CircuitState::Closed);
        }

        breaker.record_failure().await;
        assert_eq!(breaker.state().await, CircuitState::Open);
        assert!(!breaker.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failures() {
        let breaker = CircuitBreaker::new("test", 3, 2, 10);

        fail_repeatedly(&breaker, 2).await;
        breaker.record_success().await;

        // The success wiped the counter, so two more failures are not enough.
        fail_repeatedly(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Closed);

        breaker.record_failure().await;
        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_closes_after_successes() {
        // A zero-second timeout lets the circuit be probed straight away.
        let breaker = CircuitBreaker::new("test", 3, 2, 0);

        fail_repeatedly(&breaker, 3).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(breaker.can_execute().await);
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);

        breaker.record_success().await;
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);

        breaker.record_success().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_reopens_on_failure() {
        let breaker = CircuitBreaker::new("test", 3, 2, 0);

        fail_repeatedly(&breaker, 3).await;

        tokio::time::sleep(Duration::from_millis(100)).await;
        // Called only for its side effect of moving the breaker to half-open.
        breaker.can_execute().await;

        breaker.record_failure().await;
        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_with_circuit_breaker_helper() {
        let breaker = CircuitBreaker::new("test", 3, 2, 10);

        let succeeded: Result<i32, CircuitBreakerError<std::io::Error>> =
            with_circuit_breaker(&breaker, || async { Ok(42) }).await;
        assert!(succeeded.is_ok());
        assert_eq!(succeeded.unwrap(), 42);

        let failed: Result<i32, CircuitBreakerError<&str>> =
            with_circuit_breaker(&breaker, || async { Err("error") }).await;
        assert!(failed.is_err());
    }
}
307}