A pit full of rusty nails
1//! [`TokenWeights`] are representations of how common [`Token`]s are, and are paired up with
2//! a [`TokenPair`](crate::token::TokenPair) in a [`NailKov`](crate::NailKov).
3
4use indexmap::IndexMap;
5use rand::Rng;
6use rand_distr::{Distribution, weighted::WeightedAliasIndex};
7
8use crate::{TokenHasher, error::NailError, token::Token};
9
10/// A distribution of choices and their likelihood.
11#[derive(Clone, Debug)]
12pub struct TokenWeights {
13 /// Mappings of choice indexes to their likelihood.
14 dist: WeightedAliasIndex<u32>,
15 /// The actual choices
16 choices: Box<[Token]>,
17}
18
19impl Distribution<Token> for TokenWeights {
20 #[inline(always)]
21 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Token {
22 // SAFETY: The sampled index from `dist` will always correspond to a valid
23 // token in the `choices` slice.
24 unsafe { *self.choices.get_unchecked(self.dist.sample(rng)) }
25 }
26}
27
28/// Builder for [`TokenWeights`].
29#[derive(Clone, Debug)]
30pub struct TokenWeightsBuilder {
31 /// Counts how many times a token is likely to appear.
32 occurrences: IndexMap<Token, u32, TokenHasher>,
33}
34
35impl TokenWeightsBuilder {
36 pub fn new() -> Self {
37 Self {
38 occurrences: IndexMap::with_hasher(Default::default()),
39 }
40 }
41
42 /// Creates a weighted distribution for the likelihood of tokens to appear.
43 pub fn build(self) -> Result<TokenWeights, NailError> {
44 let (choices, counts): (Vec<_>, Vec<_>) = self.occurrences.into_iter().unzip();
45
46 if choices.is_empty() {
47 return Err(NailError::EmptyInput);
48 }
49
50 Ok(TokenWeights {
51 dist: WeightedAliasIndex::new(counts)?,
52 choices: choices.into(),
53 })
54 }
55
56 /// Count an occurrence of this token, or add it if it hasn't been seen before.
57 pub fn add(&mut self, token: Token) {
58 self.occurrences
59 .entry(token)
60 .and_modify(|count| *count += 1)
61 .or_insert(1);
62 }
63}
64
65impl Default for TokenWeightsBuilder {
66 fn default() -> Self {
67 Self::new()
68 }
69}