use core::ops::Range; use alloc::collections::BTreeMap; use crate::dns::traits::DnsSerialize; pub(crate) const MAX_STR_LEN: u8 = !PTR_MASK; pub(crate) const PTR_MASK: u8 = 0b1100_0000; #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum DnsError { LabelTooLong, InvalidTxt, Unsupported, } impl core::fmt::Display for DnsError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Self::LabelTooLong => f.write_str("Encoding Error: Segment too long"), Self::InvalidTxt => f.write_str("Encoding Error: TXT segment is invalid"), Self::Unsupported => f.write_str("Encoding Error: Unsupported Record Type"), } } } impl core::error::Error for DnsError {} #[derive(Debug)] pub struct Encoder<'a, 'b> { output: &'b mut [u8], position: usize, lookup: BTreeMap<&'a str, u16>, reservation: Option, } impl<'a, 'b> Encoder<'a, 'b> { pub const fn new(buffer: &'b mut [u8]) -> Self { Self { output: buffer, position: 0, lookup: BTreeMap::new(), reservation: None, } } /// Takes a payload and encodes it, consuming the encoder and yielding the resulting /// slice. pub fn encode(mut self, payload: T) -> Result<&'b [u8], E> where E: core::error::Error, T: DnsSerialize<'a, Error = E>, { payload.serialize(&mut self)?; Ok(self.finish()) } pub(crate) fn finish(self) -> &'b [u8] { &self.output[..self.position] } fn increment(&mut self, amount: usize) { self.position += amount; } pub(crate) fn write_label(&mut self, mut label: &'a str) -> Result<(), DnsError> { loop { if let Some(pos) = self.get_label_position(label) { let [b1, b2] = u16::to_be_bytes(pos); self.write(&[b1 | PTR_MASK, b2]); return Ok(()); } let dot = label.find('.'); let end = dot.unwrap_or(label.len()); let segment = &label[..end]; let len = u8::try_from(segment.len()).map_err(|_| DnsError::LabelTooLong)?; if len > MAX_STR_LEN { return Err(DnsError::LabelTooLong); } self.store_label_position(label); self.write(&len.to_be_bytes()); self.write(segment.as_bytes()); match dot { Some(end) => { label = &label[end + 1..]; } None => { self.write(&[0]); return Ok(()); } } } } pub(crate) fn write(&mut self, bytes: &[u8]) { let len = bytes.len(); let end = self.position + len; self.output[self.position..end].copy_from_slice(bytes); self.increment(len); } fn get_label_position(&mut self, label: &str) -> Option { self.lookup.get(label).copied() } fn store_label_position(&mut self, label: &'a str) { self.lookup.insert(label, self.position as u16); } fn reserve_record_length(&mut self) { if self.reservation.is_none() { self.reservation = Some(self.position); self.increment(2); } } fn distance_from_reservation(&mut self) -> Option<(Range, u16)> { self.reservation .take() .map(|start| (start..(start + 2), (self.position - start - 2) as u16)) } fn write_record_length(&mut self) { if let Some((reservation, len)) = self.distance_from_reservation() { self.output[reservation].copy_from_slice(&len.to_be_bytes()); } } pub(crate) fn with_record_length(&mut self, encoding_scope: F) -> Result<(), E> where E: core::error::Error, F: FnOnce(&mut Self) -> Result<(), E>, { self.reserve_record_length(); encoding_scope(self)?; self.write_record_length(); Ok(()) } }