···1717pub struct Brain(HashMap<Token, Edges>);
18181919pub type TypingSender = oneshot::Sender<bool>;
2020-pub type TypingReceiver = oneshot::Receiver<bool>;
21202221pub fn format_token(tok: &Token) -> String {
2322 if let Some(w) = tok {
···4948 }
5049 }
51505252- fn sample(&self, rand: &mut fastrand::Rng) -> Option<Token> {
5353- let mut dist_left = rand.f64() * self.1 as f64;
5454- for (tok, weight) in self.0.iter() {
5151+ fn sample(&self, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> {
5252+ let total_dist = if !allow_end && let Some(weight) = self.0.get(&None) {
5353+ self.1 - *weight as u64
5454+ } else {
5555+ self.1
5656+ };
5757+ let mut dist_left = rand.f64() * total_dist as f64;
5858+5959+ for (tok, weight) in self.0.iter().filter(|(tok, _)| allow_end || tok.is_some()) {
5560 dist_left -= *weight as f64;
5661 if dist_left < 0.0 {
5757- return Some(tok.clone());
6262+ return Some(tok);
5863 }
5964 }
6065 None
···161166162167 // Get our final token, or a random one if the message has nothing, or don't reply at all
163168 // if we have no tokens at all.
164164- let mut current_token =
165165- Self::extract_final_token(msg).or_else(|| self.random_token(&mut rng))?;
169169+ let last_token = Self::extract_final_token(msg).or_else(|| self.random_token(&mut rng))?;
170170+ let mut current_token = &last_token;
166171167172 let mut chain = Vec::with_capacity(MAX_TOKENS);
168173 let mut has_triggered_typing = false;
169174170170- while let Some(tok) = current_token
171171- && chain.len() <= MAX_TOKENS
172172- {
173173- if let Some(edges) = self.0.get(&Some(tok)) {
174174- let next = edges.sample(&mut rng).flatten();
175175- if let Some(ref s) = next {
176176- chain.push(s.clone());
177177- if !has_triggered_typing && let Some(typ) = typing_oneshot.take() {
178178- typ.send(true).ok();
175175+ while current_token.is_some() && chain.len() <= MAX_TOKENS {
176176+ if let Some(edges) = self.0.get(current_token) {
177177+ let next = edges.sample(&mut rng, chain.len() > 2);
178178+179179+ if let Some(ref tok) = next {
180180+ if let Some(s) = tok {
181181+ // Is this a non-ending token? If so, push it to our chain!
182182+ chain.push(s.clone());
183183+ if !has_triggered_typing && let Some(typ) = typing_oneshot.take() {
184184+ typ.send(true).ok();
185185+ }
186186+ current_token = tok;
187187+ } else {
188188+ // If we reached an end token, stop chaining
189189+ break;
179190 }
191191+ } else {
192192+ // If we failed to sample any tokens, we can't continue the chain
193193+ break;
180194 }
181181- current_token = next;
182195 } else {
183183- current_token = None;
196196+ // If we don't know the current word, we can't continue the chain
197197+ break;
184198 }
185199 }
186200···243257 );
244258 let reply = brain.respond("hello", false, None);
245259 assert_eq!(reply, Some("world".to_string()));
260260+ }
261261+262262+ #[test]
263263+ fn at_least_2_tokens() {
264264+ let mut brain = Brain::default();
265265+ brain.ingest("hello world");
266266+ brain.ingest("hello");
267267+ brain.ingest("hello");
268268+ brain.ingest("hello");
269269+270270+ for _ in 0..100 {
271271+ // I'm too lazy to mock lazyrand LOL!!
272272+ let reply = brain.respond("hello", false, None);
273273+ assert_eq!(reply, Some("world".to_string()));
274274+ }
246275 }
247276248277 #[test]