//! The world's most clever kitty cat
1#![allow(unused)]
2
3use std::collections::HashMap;
4
5use log::debug;
6use serde::{Deserialize, Serialize};
7use tokio::sync::oneshot;
8
9/// Some = Word, None = End Message
10pub type Token = Option<String>;
11pub type Weight = u16;
12
13#[derive(Default, Debug, Clone, Serialize, Deserialize)]
14pub struct Edges(HashMap<Token, Weight>, u64);
15
16#[derive(Default, Debug, Clone, Serialize, Deserialize)]
17pub struct Brain(HashMap<Token, Edges>);
18
19pub type TypingSender = oneshot::Sender<bool>;
20
21pub fn format_token(tok: &Token) -> String {
22 if let Some(w) = tok {
23 w.clone()
24 } else {
25 "~END".to_string()
26 }
27}
28
29impl Edges {
30 fn increment_token(&mut self, tok: &Token) {
31 if let Some(w) = self.0.get_mut(tok) {
32 *w = w.saturating_add(1);
33 } else {
34 self.0.insert(tok.clone(), 1);
35 }
36 self.1 = self.1.saturating_add(1);
37 }
38
39 fn merge_from(&mut self, other: Self) {
40 self.0.reserve(other.0.len());
41 for (k, v) in other.0.into_iter() {
42 if let Some(w) = self.0.get_mut(&k) {
43 *w = w.saturating_add(v);
44 } else {
45 self.0.insert(k, v);
46 }
47 self.1 = self.1.saturating_add(v as u64);
48 }
49 }
50
51 fn sample(&self, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> {
52 let total_dist = if !allow_end && let Some(weight) = self.0.get(&None) {
53 self.1 - *weight as u64
54 } else {
55 self.1
56 };
57 let mut dist_left = rand.f64() * total_dist as f64;
58
59 for (tok, weight) in self.0.iter().filter(|(tok, _)| allow_end || tok.is_some()) {
60 dist_left -= *weight as f64;
61 if dist_left < 0.0 {
62 return Some(tok);
63 }
64 }
65 None
66 }
67
68 pub fn iter_weights(&self) -> impl Iterator<Item = (&Token, Weight, f64)> {
69 self.0
70 .iter()
71 .map(|(k, v)| (k, *v, (*v as f64) / (self.1 as f64)))
72 }
73}
74
75const FORCE_REPLIES: bool = cfg!(test) || (option_env!("BINGUS_FORCE_REPLY").is_some());
76
77impl Brain {
78 fn normalize_token(word: &str) -> Token {
79 let w = if word.starts_with("http://") || word.starts_with("https://") {
80 word.to_string()
81 } else {
82 word.to_ascii_lowercase()
83 };
84 Some(w)
85 }
86
87 fn parse(msg: &str) -> impl Iterator<Item = Token> {
88 msg.split_whitespace()
89 .filter_map(|w| {
90 // Filter out pings, they can get annoying
91 if w.starts_with("<@") && w.ends_with(">") {
92 None
93 } else {
94 Some(Self::normalize_token(w))
95 }
96 })
97 .chain(std::iter::once(None))
98 }
99
100 fn should_reply(rand: &mut fastrand::Rng, is_self: bool) -> bool {
101 let chance = if is_self { 45 } else { 80 };
102 let roll = rand.u8(0..=100);
103
104 (FORCE_REPLIES) || roll <= chance
105 }
106
107 fn extract_final_word(msg: &str) -> Option<String> {
108 msg.split_whitespace()
109 .last()
110 .and_then(Self::normalize_token)
111 }
112
113 fn random_token(&self, rand: &mut fastrand::Rng) -> Option<&Token> {
114 let len = self.0.len();
115 if len == 0 {
116 None
117 } else {
118 let i = rand.usize(..len);
119 self.0.keys().nth(i)
120 }
121 }
122
123 pub fn ingest(&mut self, msg: &str) -> bool {
124 // Using reduce instead of .any here to prevent short circuting
125 Self::parse(msg)
126 .map_windows(|[from, to]| {
127 if let Some(edge) = self.0.get_mut(from) {
128 edge.increment_token(to);
129 false
130 } else {
131 let new = Edges(HashMap::from_iter([(to.clone(), 1)]), 1);
132 self.0.insert(from.clone(), new);
133 true
134 }
135 })
136 .reduce(|acc, c| acc || c)
137 .unwrap_or_default()
138 }
139
140 pub fn merge_from(&mut self, other: Self) {
141 for (k, v) in other.0.into_iter() {
142 if let Some(edges) = self.0.get_mut(&k) {
143 edges.merge_from(v);
144 } else {
145 self.0.insert(k, v);
146 }
147 }
148 }
149
150 fn next_from(&self, tok: &Token, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> {
151 // Get the edges for the current token
152 // If we have that token, sample its edges
153 // Otherwise, if we don't know that token, and allow_end is false, try to pick a random token instead
154 self.0
155 .get(tok)
156 .and_then(|edges| edges.sample(rand, allow_end))
157 .or_else(|| {
158 if allow_end {
159 None
160 } else {
161 self.random_token(rand)
162 }
163 })
164 }
165
166 pub fn respond(
167 &self,
168 msg: &str,
169 is_self: bool,
170 force_reply: bool,
171 mut typing_oneshot: Option<TypingSender>,
172 ) -> Option<String> {
173 const MAX_TOKENS: usize = 20;
174
175 let mut rng = fastrand::Rng::new();
176
177 // Roll if we should reply
178 if !force_reply && !Self::should_reply(&mut rng, is_self) {
179 debug!("Failed roll");
180 return None;
181 }
182
183 // Get the final token
184 let last_token = Self::extract_final_word(msg);
185
186 let mut current_token = if let Some(t) = last_token {
187 // We found a word at the end of the previous message
188 &Some(t)
189 } else {
190 // We couldn't find a word at the end of the last message, pick a random one
191 // If we *still* don't have a token, return early
192 self.random_token(&mut rng)?
193 };
194
195 let mut chain = Vec::with_capacity(MAX_TOKENS);
196 let sep = String::from(" ");
197
198 while let Some(next @ Some(s)) = self.next_from(current_token, &mut rng, !chain.is_empty())
199 && chain.len() <= MAX_TOKENS
200 {
201 chain.push(s);
202 if let Some(typ) = typing_oneshot.take() {
203 typ.send(true).ok();
204 }
205 current_token = next;
206 }
207
208 if let Some(typ) = typing_oneshot.take() {
209 typ.send(false).ok();
210 }
211
212 if chain.is_empty() {
213 None
214 } else {
215 let s = chain
216 .into_iter()
217 .intersperse(&sep)
218 .cloned()
219 .collect::<String>();
220 Some(s).filter(|s| !s.trim().is_empty())
221 }
222 }
223
224 pub fn word_count(&self) -> usize {
225 self.0.len()
226 }
227
228 pub fn get_weights(&self, tok: &str) -> Option<&Edges> {
229 self.0.get(&Self::normalize_token(tok))
230 }
231
232 fn legacy_token_format(tok: &Token) -> String {
233 tok.as_ref()
234 .map(|s| format!("W-{s}"))
235 .unwrap_or_else(|| String::from("E--"))
236 }
237
238 pub fn as_legacy_hashmap(&self) -> HashMap<String, HashMap<String, Weight>> {
239 self.0
240 .iter()
241 .map(|(k, v)| {
242 let map =
243 v.0.iter()
244 .map(|(t, w)| (Self::legacy_token_format(t), *w))
245 .collect();
246 (Self::legacy_token_format(k), map)
247 })
248 .collect()
249 }
250
251 fn read_legacy_token(s: String) -> Token {
252 match s.as_str() {
253 "E--" => None,
254 word => Some(word.strip_prefix("W-").unwrap_or(word).to_string()),
255 }
256 }
257
258 pub fn from_legacy_hashmap(map: HashMap<String, HashMap<String, Weight>>) -> Self {
259 Self(
260 map.into_iter()
261 .map(|(k, v)| {
262 let sum = v.values().map(|w| *w as u64).sum::<u64>();
263 let edges = Edges(
264 v.into_iter()
265 .map(|(t, w)| (Self::read_legacy_token(t), w))
266 .collect(),
267 sum,
268 );
269 (Self::read_legacy_token(k), edges)
270 })
271 .collect(),
272 )
273 }
274}
275
276#[cfg(test)]
277mod tests {
278
279 use super::*;
280 use std::default::Default;
281
282 extern crate test;
283
284 use test::Bencher;
285
286 #[test]
287 fn ingest_parse() {
288 let tokens = Brain::parse("Hello world").collect::<Vec<_>>();
289 assert_eq!(
290 tokens,
291 vec![Some("hello".to_string()), Some("world".to_string()), None]
292 );
293 }
294
295 #[test]
296 fn ingest_url() {
297 let tokens = Brain::parse("https://example.com/CAPS-PATH").collect::<Vec<_>>();
298 assert_eq!(
299 tokens,
300 vec![Some("https://example.com/CAPS-PATH".to_string()), None]
301 );
302 }
303
304 #[test]
305 fn ingest_ping() {
306 let tokens = Brain::parse("hi <@1234567>").collect::<Vec<_>>();
307 assert_eq!(tokens, vec![Some("hi".to_string()), None]);
308 }
309
310 #[test]
311 fn basic_chain() {
312 let mut brain = Brain::default();
313 brain.ingest("hello world");
314 let hello_edges = brain
315 .0
316 .get(&Some("hello".to_string()))
317 .expect("Hello edges not created");
318 assert_eq!(
319 hello_edges.0,
320 HashMap::from_iter([(Some("world".to_string()), 1)])
321 );
322 let reply = brain.respond("hello", false, false, None);
323 assert_eq!(reply, Some("world".to_string()));
324 }
325
326 #[test]
327 fn at_least_1_token() {
328 let mut brain = Brain::default();
329 brain.ingest("hello world");
330 for _ in 0..100 {
331 brain.ingest("hello");
332 }
333
334 for _ in 0..100 {
335 // I'm too lazy to mock lazyrand LOL!!
336 let reply = brain.respond("hello", false, false, None);
337 assert_eq!(reply, Some("world".to_string()));
338 }
339 }
340
341 #[test]
342 fn none_on_empty() {
343 let mut brain = Brain::default();
344
345 let reply = brain.respond("hello", false, false, None);
346 assert_eq!(reply, None);
347 }
348
349 #[test]
350 fn random_on_end() {
351 let mut brain = Brain::default();
352 brain.ingest("world hello");
353
354 let reply = brain.respond("hello", false, false, None);
355 assert!(reply.is_some());
356 }
357
358 #[test]
359 fn long_chain() {
360 const LETTERS: &str = "abcdefghijklmnopqrstuvwxyz";
361 let msg = LETTERS
362 .chars()
363 .map(|c| c.to_string())
364 .intersperse(" ".to_string())
365 .collect::<String>();
366 let mut brain = Brain::default();
367 brain.ingest(&msg);
368 let reply = brain.respond("a", false, false, None);
369 let expected = LETTERS
370 .chars()
371 .skip(1)
372 .take(21)
373 .map(|c| c.to_string())
374 .intersperse(" ".to_string())
375 .collect::<String>();
376 assert_eq!(reply, Some(expected));
377 }
378
379 #[test]
380 fn merge_brain() {
381 let mut brain1 = Brain::default();
382 let mut brain2 = Brain::default();
383
384 brain1.ingest("hello world");
385 brain2.ingest("hello world");
386 brain2.ingest("hello world");
387 brain2.ingest("other word");
388
389 brain1.merge_from(brain2);
390
391 let hello_edges = brain1
392 .0
393 .get(&Some("hello".to_string()))
394 .expect("Hello edges not created");
395 assert_eq!(
396 hello_edges.0,
397 HashMap::from_iter([(Some("world".to_string()), 3)])
398 );
399
400 let new_edges = brain1
401 .0
402 .get(&Some("other".to_string()))
403 .expect("New edges not created");
404 assert_eq!(
405 new_edges.0,
406 HashMap::from_iter([(Some("word".to_string()), 1)])
407 );
408 }
409
410 #[bench]
411 fn bench_learn(b: &mut Bencher) {
412 b.iter(|| {
413 let mut brain = Brain::default();
414 brain.ingest(
415 "your name is bingus the discord bot and this message is a test for benchmarking",
416 );
417 });
418 }
419
420 #[bench]
421 fn bench_respond(b: &mut Bencher) {
422 let mut brain = Brain::default();
423 brain.ingest(
424 "your name is bingus the discord bot and this message is a test for benchmarking",
425 );
426 b.iter(|| {
427 brain.respond("your", false, true, None);
428 });
429 }
430
431 include!("lorem.rs");
432
433 #[bench]
434 fn bench_learn_large(b: &mut Bencher) {
435 b.iter(|| {
436 let mut brain = Brain::default();
437 brain.ingest(LOREM);
438 });
439 }
440}