The world's most clever kitty cat

Require at least 2 tokens before allowing end

bwc9876.dev c158cdd6 4267b47d

verified
+47 -18
+47 -18
src/brain.rs
··· 17 17 pub struct Brain(HashMap<Token, Edges>); 18 18 19 19 pub type TypingSender = oneshot::Sender<bool>; 20 - pub type TypingReceiver = oneshot::Receiver<bool>; 21 20 22 21 pub fn format_token(tok: &Token) -> String { 23 22 if let Some(w) = tok { ··· 49 48 } 50 49 } 51 50 52 - fn sample(&self, rand: &mut fastrand::Rng) -> Option<Token> { 53 - let mut dist_left = rand.f64() * self.1 as f64; 54 - for (tok, weight) in self.0.iter() { 51 + fn sample(&self, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> { 52 + let total_dist = if !allow_end && let Some(weight) = self.0.get(&None) { 53 + self.1 - *weight as u64 54 + } else { 55 + self.1 56 + }; 57 + let mut dist_left = rand.f64() * total_dist as f64; 58 + 59 + for (tok, weight) in self.0.iter().filter(|(tok, _)| allow_end || tok.is_some()) { 55 60 dist_left -= *weight as f64; 56 61 if dist_left < 0.0 { 57 - return Some(tok.clone()); 62 + return Some(tok); 58 63 } 59 64 } 60 65 None ··· 161 166 162 167 // Get our final token, or a random one if the message has nothing, or don't reply at all 163 168 // if we have no tokens at all. 164 - let mut current_token = 165 - Self::extract_final_token(msg).or_else(|| self.random_token(&mut rng))?; 169 + let last_token = Self::extract_final_token(msg).or_else(|| self.random_token(&mut rng))?; 170 + let mut current_token = &last_token; 166 171 167 172 let mut chain = Vec::with_capacity(MAX_TOKENS); 168 173 let mut has_triggered_typing = false; 169 174 170 - while let Some(tok) = current_token 171 - && chain.len() <= MAX_TOKENS 172 - { 173 - if let Some(edges) = self.0.get(&Some(tok)) { 174 - let next = edges.sample(&mut rng).flatten(); 175 - if let Some(ref s) = next { 176 - chain.push(s.clone()); 177 - if !has_triggered_typing && let Some(typ) = typing_oneshot.take() { 178 - typ.send(true).ok(); 175 + while current_token.is_some() && chain.len() <= MAX_TOKENS { 176 + if let Some(edges) = self.0.get(current_token) { 177 + let next = edges.sample(&mut rng, chain.len() > 2); 178 + 179 + if let Some(ref tok) = next { 180 + if let Some(s) = tok { 181 + // Is this a non-ending token? If so, push it to our chain! 182 + chain.push(s.clone()); 183 + if !has_triggered_typing && let Some(typ) = typing_oneshot.take() { 184 + typ.send(true).ok(); 185 + } 186 + current_token = tok; 187 + } else { 188 + // If we reached an end token, stop chaining 189 + break; 179 190 } 191 + } else { 192 + // If we failed to sample any tokens, we can't continue the chain 193 + break; 180 194 } 181 - current_token = next; 182 195 } else { 183 - current_token = None; 196 + // If we don't know the current word, we can't continue the chain 197 + break; 184 198 } 185 199 } 186 200 ··· 243 257 ); 244 258 let reply = brain.respond("hello", false, None); 245 259 assert_eq!(reply, Some("world".to_string())); 260 + } 261 + 262 + #[test] 263 + fn at_least_2_tokens() { 264 + let mut brain = Brain::default(); 265 + brain.ingest("hello world"); 266 + brain.ingest("hello"); 267 + brain.ingest("hello"); 268 + brain.ingest("hello"); 269 + 270 + for _ in 0..100 { 271 + // I'm too lazy to mock lazyrand LOL!! 272 + let reply = brain.respond("hello", false, None); 273 + assert_eq!(reply, Some("world".to_string())); 274 + } 246 275 } 247 276 248 277 #[test]