// The world's most clever kitty cat
// at main 440 lines 12 kB view raw
1#![allow(unused)] 2 3use std::collections::HashMap; 4 5use log::debug; 6use serde::{Deserialize, Serialize}; 7use tokio::sync::oneshot; 8 9/// Some = Word, None = End Message 10pub type Token = Option<String>; 11pub type Weight = u16; 12 13#[derive(Default, Debug, Clone, Serialize, Deserialize)] 14pub struct Edges(HashMap<Token, Weight>, u64); 15 16#[derive(Default, Debug, Clone, Serialize, Deserialize)] 17pub struct Brain(HashMap<Token, Edges>); 18 19pub type TypingSender = oneshot::Sender<bool>; 20 21pub fn format_token(tok: &Token) -> String { 22 if let Some(w) = tok { 23 w.clone() 24 } else { 25 "~END".to_string() 26 } 27} 28 29impl Edges { 30 fn increment_token(&mut self, tok: &Token) { 31 if let Some(w) = self.0.get_mut(tok) { 32 *w = w.saturating_add(1); 33 } else { 34 self.0.insert(tok.clone(), 1); 35 } 36 self.1 = self.1.saturating_add(1); 37 } 38 39 fn merge_from(&mut self, other: Self) { 40 self.0.reserve(other.0.len()); 41 for (k, v) in other.0.into_iter() { 42 if let Some(w) = self.0.get_mut(&k) { 43 *w = w.saturating_add(v); 44 } else { 45 self.0.insert(k, v); 46 } 47 self.1 = self.1.saturating_add(v as u64); 48 } 49 } 50 51 fn sample(&self, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> { 52 let total_dist = if !allow_end && let Some(weight) = self.0.get(&None) { 53 self.1 - *weight as u64 54 } else { 55 self.1 56 }; 57 let mut dist_left = rand.f64() * total_dist as f64; 58 59 for (tok, weight) in self.0.iter().filter(|(tok, _)| allow_end || tok.is_some()) { 60 dist_left -= *weight as f64; 61 if dist_left < 0.0 { 62 return Some(tok); 63 } 64 } 65 None 66 } 67 68 pub fn iter_weights(&self) -> impl Iterator<Item = (&Token, Weight, f64)> { 69 self.0 70 .iter() 71 .map(|(k, v)| (k, *v, (*v as f64) / (self.1 as f64))) 72 } 73} 74 75const FORCE_REPLIES: bool = cfg!(test) || (option_env!("BINGUS_FORCE_REPLY").is_some()); 76 77impl Brain { 78 fn normalize_token(word: &str) -> Token { 79 let w = if word.starts_with("http://") || 
word.starts_with("https://") { 80 word.to_string() 81 } else { 82 word.to_ascii_lowercase() 83 }; 84 Some(w) 85 } 86 87 fn parse(msg: &str) -> impl Iterator<Item = Token> { 88 msg.split_whitespace() 89 .filter_map(|w| { 90 // Filter out pings, they can get annoying 91 if w.starts_with("<@") && w.ends_with(">") { 92 None 93 } else { 94 Some(Self::normalize_token(w)) 95 } 96 }) 97 .chain(std::iter::once(None)) 98 } 99 100 fn should_reply(rand: &mut fastrand::Rng, is_self: bool) -> bool { 101 let chance = if is_self { 45 } else { 80 }; 102 let roll = rand.u8(0..=100); 103 104 (FORCE_REPLIES) || roll <= chance 105 } 106 107 fn extract_final_word(msg: &str) -> Option<String> { 108 msg.split_whitespace() 109 .last() 110 .and_then(Self::normalize_token) 111 } 112 113 fn random_token(&self, rand: &mut fastrand::Rng) -> Option<&Token> { 114 let len = self.0.len(); 115 if len == 0 { 116 None 117 } else { 118 let i = rand.usize(..len); 119 self.0.keys().nth(i) 120 } 121 } 122 123 pub fn ingest(&mut self, msg: &str) -> bool { 124 // Using reduce instead of .any here to prevent short circuting 125 Self::parse(msg) 126 .map_windows(|[from, to]| { 127 if let Some(edge) = self.0.get_mut(from) { 128 edge.increment_token(to); 129 false 130 } else { 131 let new = Edges(HashMap::from_iter([(to.clone(), 1)]), 1); 132 self.0.insert(from.clone(), new); 133 true 134 } 135 }) 136 .reduce(|acc, c| acc || c) 137 .unwrap_or_default() 138 } 139 140 pub fn merge_from(&mut self, other: Self) { 141 for (k, v) in other.0.into_iter() { 142 if let Some(edges) = self.0.get_mut(&k) { 143 edges.merge_from(v); 144 } else { 145 self.0.insert(k, v); 146 } 147 } 148 } 149 150 fn next_from(&self, tok: &Token, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> { 151 // Get the edges for the current token 152 // If we have that token, sample its edges 153 // Otherwise, if we don't know that token, and allow_end is false, try to pick a random token instead 154 self.0 155 .get(tok) 156 
.and_then(|edges| edges.sample(rand, allow_end)) 157 .or_else(|| { 158 if allow_end { 159 None 160 } else { 161 self.random_token(rand) 162 } 163 }) 164 } 165 166 pub fn respond( 167 &self, 168 msg: &str, 169 is_self: bool, 170 force_reply: bool, 171 mut typing_oneshot: Option<TypingSender>, 172 ) -> Option<String> { 173 const MAX_TOKENS: usize = 20; 174 175 let mut rng = fastrand::Rng::new(); 176 177 // Roll if we should reply 178 if !force_reply && !Self::should_reply(&mut rng, is_self) { 179 debug!("Failed roll"); 180 return None; 181 } 182 183 // Get the final token 184 let last_token = Self::extract_final_word(msg); 185 186 let mut current_token = if let Some(t) = last_token { 187 // We found a word at the end of the previous message 188 &Some(t) 189 } else { 190 // We couldn't find a word at the end of the last message, pick a random one 191 // If we *still* don't have a token, return early 192 self.random_token(&mut rng)? 193 }; 194 195 let mut chain = Vec::with_capacity(MAX_TOKENS); 196 let sep = String::from(" "); 197 198 while let Some(next @ Some(s)) = self.next_from(current_token, &mut rng, !chain.is_empty()) 199 && chain.len() <= MAX_TOKENS 200 { 201 chain.push(s); 202 if let Some(typ) = typing_oneshot.take() { 203 typ.send(true).ok(); 204 } 205 current_token = next; 206 } 207 208 if let Some(typ) = typing_oneshot.take() { 209 typ.send(false).ok(); 210 } 211 212 if chain.is_empty() { 213 None 214 } else { 215 let s = chain 216 .into_iter() 217 .intersperse(&sep) 218 .cloned() 219 .collect::<String>(); 220 Some(s).filter(|s| !s.trim().is_empty()) 221 } 222 } 223 224 pub fn word_count(&self) -> usize { 225 self.0.len() 226 } 227 228 pub fn get_weights(&self, tok: &str) -> Option<&Edges> { 229 self.0.get(&Self::normalize_token(tok)) 230 } 231 232 fn legacy_token_format(tok: &Token) -> String { 233 tok.as_ref() 234 .map(|s| format!("W-{s}")) 235 .unwrap_or_else(|| String::from("E--")) 236 } 237 238 pub fn as_legacy_hashmap(&self) -> HashMap<String, 
HashMap<String, Weight>> { 239 self.0 240 .iter() 241 .map(|(k, v)| { 242 let map = 243 v.0.iter() 244 .map(|(t, w)| (Self::legacy_token_format(t), *w)) 245 .collect(); 246 (Self::legacy_token_format(k), map) 247 }) 248 .collect() 249 } 250 251 fn read_legacy_token(s: String) -> Token { 252 match s.as_str() { 253 "E--" => None, 254 word => Some(word.strip_prefix("W-").unwrap_or(word).to_string()), 255 } 256 } 257 258 pub fn from_legacy_hashmap(map: HashMap<String, HashMap<String, Weight>>) -> Self { 259 Self( 260 map.into_iter() 261 .map(|(k, v)| { 262 let sum = v.values().map(|w| *w as u64).sum::<u64>(); 263 let edges = Edges( 264 v.into_iter() 265 .map(|(t, w)| (Self::read_legacy_token(t), w)) 266 .collect(), 267 sum, 268 ); 269 (Self::read_legacy_token(k), edges) 270 }) 271 .collect(), 272 ) 273 } 274} 275 276#[cfg(test)] 277mod tests { 278 279 use super::*; 280 use std::default::Default; 281 282 extern crate test; 283 284 use test::Bencher; 285 286 #[test] 287 fn ingest_parse() { 288 let tokens = Brain::parse("Hello world").collect::<Vec<_>>(); 289 assert_eq!( 290 tokens, 291 vec![Some("hello".to_string()), Some("world".to_string()), None] 292 ); 293 } 294 295 #[test] 296 fn ingest_url() { 297 let tokens = Brain::parse("https://example.com/CAPS-PATH").collect::<Vec<_>>(); 298 assert_eq!( 299 tokens, 300 vec![Some("https://example.com/CAPS-PATH".to_string()), None] 301 ); 302 } 303 304 #[test] 305 fn ingest_ping() { 306 let tokens = Brain::parse("hi <@1234567>").collect::<Vec<_>>(); 307 assert_eq!(tokens, vec![Some("hi".to_string()), None]); 308 } 309 310 #[test] 311 fn basic_chain() { 312 let mut brain = Brain::default(); 313 brain.ingest("hello world"); 314 let hello_edges = brain 315 .0 316 .get(&Some("hello".to_string())) 317 .expect("Hello edges not created"); 318 assert_eq!( 319 hello_edges.0, 320 HashMap::from_iter([(Some("world".to_string()), 1)]) 321 ); 322 let reply = brain.respond("hello", false, false, None); 323 assert_eq!(reply, 
Some("world".to_string())); 324 } 325 326 #[test] 327 fn at_least_1_token() { 328 let mut brain = Brain::default(); 329 brain.ingest("hello world"); 330 for _ in 0..100 { 331 brain.ingest("hello"); 332 } 333 334 for _ in 0..100 { 335 // I'm too lazy to mock lazyrand LOL!! 336 let reply = brain.respond("hello", false, false, None); 337 assert_eq!(reply, Some("world".to_string())); 338 } 339 } 340 341 #[test] 342 fn none_on_empty() { 343 let mut brain = Brain::default(); 344 345 let reply = brain.respond("hello", false, false, None); 346 assert_eq!(reply, None); 347 } 348 349 #[test] 350 fn random_on_end() { 351 let mut brain = Brain::default(); 352 brain.ingest("world hello"); 353 354 let reply = brain.respond("hello", false, false, None); 355 assert!(reply.is_some()); 356 } 357 358 #[test] 359 fn long_chain() { 360 const LETTERS: &str = "abcdefghijklmnopqrstuvwxyz"; 361 let msg = LETTERS 362 .chars() 363 .map(|c| c.to_string()) 364 .intersperse(" ".to_string()) 365 .collect::<String>(); 366 let mut brain = Brain::default(); 367 brain.ingest(&msg); 368 let reply = brain.respond("a", false, false, None); 369 let expected = LETTERS 370 .chars() 371 .skip(1) 372 .take(21) 373 .map(|c| c.to_string()) 374 .intersperse(" ".to_string()) 375 .collect::<String>(); 376 assert_eq!(reply, Some(expected)); 377 } 378 379 #[test] 380 fn merge_brain() { 381 let mut brain1 = Brain::default(); 382 let mut brain2 = Brain::default(); 383 384 brain1.ingest("hello world"); 385 brain2.ingest("hello world"); 386 brain2.ingest("hello world"); 387 brain2.ingest("other word"); 388 389 brain1.merge_from(brain2); 390 391 let hello_edges = brain1 392 .0 393 .get(&Some("hello".to_string())) 394 .expect("Hello edges not created"); 395 assert_eq!( 396 hello_edges.0, 397 HashMap::from_iter([(Some("world".to_string()), 3)]) 398 ); 399 400 let new_edges = brain1 401 .0 402 .get(&Some("other".to_string())) 403 .expect("New edges not created"); 404 assert_eq!( 405 new_edges.0, 406 
HashMap::from_iter([(Some("word".to_string()), 1)]) 407 ); 408 } 409 410 #[bench] 411 fn bench_learn(b: &mut Bencher) { 412 b.iter(|| { 413 let mut brain = Brain::default(); 414 brain.ingest( 415 "your name is bingus the discord bot and this message is a test for benchmarking", 416 ); 417 }); 418 } 419 420 #[bench] 421 fn bench_respond(b: &mut Bencher) { 422 let mut brain = Brain::default(); 423 brain.ingest( 424 "your name is bingus the discord bot and this message is a test for benchmarking", 425 ); 426 b.iter(|| { 427 brain.respond("your", false, true, None); 428 }); 429 } 430 431 include!("lorem.rs"); 432 433 #[bench] 434 fn bench_learn_large(b: &mut Bencher) { 435 b.iter(|| { 436 let mut brain = Brain::default(); 437 brain.ingest(LOREM); 438 }); 439 } 440}