use embassy_futures::select::select; use embassy_net::tcp::{TcpReader, TcpSocket, TcpWriter}; use embassy_sync::{ blocking_mutex::raw::NoopRawMutex, lazy_lock::LazyLock, mutex::Mutex, once_lock::OnceLock, }; use sachy_fmt::unwrap; use snow::{Builder, Keypair, TransportState, params::NoiseParams}; use crate::{ constants::NOISE_PSK, errors::PicoError, rpc::RpcServer, updates::UpdateConnection, utils::try_buffer, }; pub static NOISE_PROTO: &str = "Noise_XXpsk3_25519_ChaChaPoly_BLAKE2s"; static PARAMS: LazyLock = LazyLock::new(|| unwrap!(NOISE_PROTO.parse(), "Unable to parse Noise proto schema")); static LOCAL_PRIVATE_KEY: OnceLock = OnceLock::new(); pub struct NoiseSession { transport: Mutex, } impl NoiseSession { pub async fn initialize<'device>(tcp: &mut TcpSocket<'device>) -> Result { let state = noise_handshake(tcp).await?; Ok(Self { transport: Mutex::new(state), }) } pub async fn run<'device>(self, tcp: &mut TcpSocket<'device>) { let (reader, writer) = tcp.split(); select(self.read_loop(reader), self.write_loop(writer)).await; } async fn read_loop<'device>(&self, mut reader: TcpReader<'device>) { let mut buffer = unwrap!(try_buffer(8192)); let (packet_buf, msg_buf) = buffer.split_at_mut(4096); loop { let Ok(received) = noise_recv(&mut reader, packet_buf).await else { break; }; if let Ok(msg) = self.transport.lock().await.read_message(received, msg_buf) && let Ok(req) = striker_proto::receive_request(&mut msg_buf[..msg]) { let Some((resp, resp_tx)) = RpcServer::handle_request(req) .await .zip(UpdateConnection::can_update()) else { break; }; resp_tx.try_send(resp).ok(); } } } async fn write_loop<'device>(&self, mut writer: TcpWriter<'device>) { let outgoing = UpdateConnection::get_receiver(); let mut buffer = unwrap!(try_buffer(8192)); let (msg_buf, enc_buf) = buffer.split_at_mut(4096); loop { let data = outgoing.receive().await; let packet = unwrap!(striker_proto::send_response(data, msg_buf)); let written = unwrap!( self.transport.lock().await.write_message(packet, enc_buf), "Payload too big" ); if noise_send(&mut writer, &enc_buf[..written]).await.is_err() { break; } if writer.flush().await.is_err() { break; } } } } async fn noise_handshake<'device>( tcp: &mut TcpSocket<'device>, ) -> Result { let builder = Builder::new(PARAMS.get().clone()); let static_key = LOCAL_PRIVATE_KEY .get_or_init(|| unwrap!(builder.generate_keypair(), "Failed to generate key pair")); let mut noise = builder .local_private_key(&static_key.private)? .psk(3, &NOISE_PSK)? .build_responder()?; let (mut reader, mut writer) = tcp.split(); let mut buffer = try_buffer(4096)?; let (payload, packet) = buffer.split_at_mut(2048); noise.read_message(noise_recv(&mut reader, packet).await?, payload)?; let len = noise.write_message(&[], payload)?; noise_send(&mut writer, &payload[..len]).await?; noise.read_message(noise_recv(&mut reader, packet).await?, payload)?; let transport = noise.into_transport_mode()?; Ok(transport) } /// Hyper-basic stream transport receiver. 16-bit BE size followed by payload. async fn noise_recv<'device, 'buffer>( stream: &mut TcpReader<'device>, packet: &'buffer mut [u8], ) -> Result<&'buffer [u8], PicoError> { loop { if let Some(written) = stream .read_with(|buf| { buf.split_at_checked(2).map_or((0, None), |(size, rest)| { let mut msg_len_buf = [0u8; 2]; msg_len_buf.copy_from_slice(size); let buf_size = usize::from(u16::from_be_bytes(msg_len_buf)); packet[..buf_size].copy_from_slice(&rest[..buf_size]); (2 + buf_size, Some(buf_size)) }) }) .await? { return Ok(&packet[..written]); } } } async fn noise_send<'device>( stream: &mut TcpWriter<'device>, payload: &[u8], ) -> Result<(), PicoError> { let len = u16::try_from(payload.len())?; while !stream .write_with(|buf| { buf.split_at_mut_checked(2) .map_or((0, false), |(msg_size, rest)| { msg_size.copy_from_slice(&len.to_be_bytes()); rest[..payload.len()].copy_from_slice(payload); (2 + payload.len(), true) }) }) .await? {} stream.flush().await?; Ok(()) }