A pit full of rusty nails
//! Crate defining a Markov chain implementation, and a string interner for use
//! with the Markov chain.
//!
4
5mod distribution;
6mod error;
7mod token;
8
9use core::hash::BuildHasherDefault;
10
11use crossbeam_utils::CachePadded;
12use error::NailError;
13use estr::IdentityHasher;
14use indexmap::IndexMap;
15use itertools::Itertools;
16use rand::{Rng, seq::IteratorRandom};
17use rand_distr::Distribution;
18
19use distribution::{TokenWeights, TokenWeightsBuilder};
20use token::{Token, TokenPair};
21use unicode_segmentation::UnicodeSegmentation;
22
23/// `nailkov` relies on `estr`'s precomputed hashes, so we avoid
24/// hashing ourselves and can just use the precomputed hashes instead.
25type TokenHasher = BuildHasherDefault<IdentityHasher>;
26
27#[derive(Clone, Debug)]
28pub struct NailKov {
29 chain: CachePadded<IndexMap<TokenPair, TokenWeights, TokenHasher>>,
30}
31
32pub struct NailKovIter<'a, R: Rng> {
33 rng: &'a mut R,
34 markov: &'a NailKov,
35 prev: TokenPair,
36}
37
38impl<R: Rng> Iterator for NailKovIter<'_, R> {
39 type Item = Token;
40
41 #[inline]
42 fn next(&mut self) -> Option<Self::Item> {
43 let dist = self.markov.chain.get(&self.prev)?;
44
45 let next_token = dist.sample(&mut self.rng);
46
47 self.prev = TokenPair::new(self.prev.right, next_token);
48
49 Some(next_token)
50 }
51}
52
53impl NailKov {
54 #[inline]
55 pub fn generate_tokens<'a, R: Rng>(&'a self, rng: &'a mut R) -> NailKovIter<'a, R> {
56 NailKovIter {
57 // A markov chain that was successfully built is never empty, so
58 // it will always return with a value, making unwrapping it safe to do.
59 prev: self.chain.keys().choose(rng).copied().unwrap(),
60 markov: self,
61 rng,
62 }
63 }
64}
65
66impl NailKov {
67 pub fn from_input(input: &str) -> Result<NailKov, NailError> {
68 NailBuilder::new(TokenHasher::new()).with_input(input)
69 }
70}
71
72struct NailBuilder {
73 chain: IndexMap<TokenPair, TokenWeightsBuilder, TokenHasher>,
74}
75
76impl NailBuilder {
77 fn new(hasher: TokenHasher) -> Self {
78 Self {
79 chain: IndexMap::with_hasher(hasher),
80 }
81 }
82
83 fn with_input(self, input: &str) -> Result<NailKov, NailError> {
84 self.feed_str(input)?.build()
85 }
86
87 fn build(self) -> Result<NailKov, NailError> {
88 if self.chain.is_empty() {
89 return Err(NailError::EmptyInput);
90 }
91
92 let chain: IndexMap<TokenPair, TokenWeights, TokenHasher> = self
93 .chain
94 .into_iter()
95 .flat_map(|(pair, dist)| {
96 dist.build()
97 .inspect_err(|err| tracing::error!("Weight error {pair:?}: {err}"))
98 .map(|build| (pair, build))
99 })
100 .collect();
101
102 if chain.is_empty() {
103 return Err(NailError::EmptyInput);
104 }
105
106 Ok(NailKov {
107 chain: CachePadded::new(chain),
108 })
109 }
110
111 /// Add the occurrence of `next` following `prev`.
112 fn add_token_pair(&mut self, prev: TokenPair, next: Token) {
113 match self.chain.get_mut(&prev) {
114 Some(builder) => {
115 builder.add(next);
116 }
117 None => {
118 let mut builder = TokenWeightsBuilder::new();
119 builder.add(next);
120 self.chain.insert(prev, builder);
121 }
122 }
123 }
124
125 fn feed_str(self, content: &str) -> Result<Self, NailError> {
126 self.feed_tokens(content.split_word_bounds().map(Token::from))
127 }
128
129 fn feed_tokens(mut self, tokens: impl Iterator<Item = Token>) -> Result<Self, NailError> {
130 let windows = tokens.tuple_windows();
131
132 if windows.size_hint().1.is_none() {
133 return Err(NailError::EmptyInput);
134 }
135
136 for (left, right, next) in windows {
137 self.add_token_pair(TokenPair::new(left, right), next);
138 }
139
140 Ok(self)
141 }
142}