A pit full of rusty nails
at main 69 lines 2.1 kB view raw
1//! [`TokenWeights`] are representations of how common [`Token`]s are, and are paired up with 2//! a [`TokenPair`](crate::token::TokenPair) in a [`NailKov`](crate::NailKov). 3 4use indexmap::IndexMap; 5use rand::Rng; 6use rand_distr::{Distribution, weighted::WeightedAliasIndex}; 7 8use crate::{TokenHasher, error::NailError, token::Token}; 9 10/// A distribution of choices and their likelihood. 11#[derive(Clone, Debug)] 12pub struct TokenWeights { 13 /// Mappings of choice indexes to their likelihood. 14 dist: WeightedAliasIndex<u32>, 15 /// The actual choices 16 choices: Box<[Token]>, 17} 18 19impl Distribution<Token> for TokenWeights { 20 #[inline(always)] 21 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Token { 22 // SAFETY: The sampled index from `dist` will always correspond to a valid 23 // token in the `choices` slice. 24 unsafe { *self.choices.get_unchecked(self.dist.sample(rng)) } 25 } 26} 27 28/// Builder for [`TokenWeights`]. 29#[derive(Clone, Debug)] 30pub struct TokenWeightsBuilder { 31 /// Counts how many times a token is likely to appear. 32 occurrences: IndexMap<Token, u32, TokenHasher>, 33} 34 35impl TokenWeightsBuilder { 36 pub fn new() -> Self { 37 Self { 38 occurrences: IndexMap::with_hasher(Default::default()), 39 } 40 } 41 42 /// Creates a weighted distribution for the likelihood of tokens to appear. 43 pub fn build(self) -> Result<TokenWeights, NailError> { 44 let (choices, counts): (Vec<_>, Vec<_>) = self.occurrences.into_iter().unzip(); 45 46 if choices.is_empty() { 47 return Err(NailError::EmptyInput); 48 } 49 50 Ok(TokenWeights { 51 dist: WeightedAliasIndex::new(counts)?, 52 choices: choices.into(), 53 }) 54 } 55 56 /// Count an occurrence of this token, or add it if it hasn't been seen before. 57 pub fn add(&mut self, token: Token) { 58 self.occurrences 59 .entry(token) 60 .and_modify(|count| *count += 1) 61 .or_insert(1); 62 } 63} 64 65impl Default for TokenWeightsBuilder { 66 fn default() -> Self { 67 Self::new() 68 } 69}