The world's most clever kitty cat
at main 281 lines 8.6 kB view raw
1#![feature(iter_map_windows)] 2#![feature(iter_intersperse)] 3 4mod brain; 5mod cmd; 6mod on_message; 7mod status; 8 9pub mod prelude { 10 pub use anyhow::Context; 11 use std::result::Result as StdResult; 12 pub type Result<T = (), E = anyhow::Error> = StdResult<T, E>; 13} 14 15use std::{ 16 collections::HashSet, 17 fs::File, 18 path::{Path, PathBuf}, 19 sync::{ 20 Arc, 21 atomic::{AtomicBool, Ordering}, 22 }, 23}; 24 25use brotli::enc::{BrotliEncoderParams, backward_references::BrotliEncoderMode}; 26use log::{debug, error, info, warn}; 27use prelude::*; 28use tokio::{ 29 signal::unix::{SignalKind, signal}, 30 sync::RwLock, 31 time::{self, Duration}, 32}; 33use twilight_gateway::{ 34 CloseFrame, Event, EventTypeFlags, Intents, MessageSender, Shard, ShardId, StreamExt, 35}; 36use twilight_http::Client as HttpClient; 37use twilight_model::{ 38 application::interaction::InteractionData, 39 id::{ 40 Id, 41 marker::{ApplicationMarker, ChannelMarker, UserMarker}, 42 }, 43}; 44 45use crate::{ 46 brain::Brain, 47 cmd::{handle_app_command, register_all_commands}, 48 on_message::handle_discord_message, 49 status::update_status, 50}; 51 52pub type BrainHandle = RwLock<Brain>; 53 54#[derive(Debug)] 55pub struct BotContext { 56 http: HttpClient, 57 self_id: Id<UserMarker>, 58 app_id: Id<ApplicationMarker>, 59 owners: HashSet<Id<UserMarker>>, 60 brain_file_path: PathBuf, 61 reply_channels: HashSet<Id<ChannelMarker>>, 62 brain_handle: BrainHandle, 63 shard_sender: MessageSender, 64 pending_save: AtomicBool, 65} 66 67async fn handle_discord_event(event: Event, ctx: Arc<BotContext>) -> Result { 68 match event { 69 Event::MessageCreate(msg) => handle_discord_message(msg, ctx) 70 .await 71 .context("While handling a new message"), 72 Event::InteractionCreate(mut inter) => { 73 if let Some(InteractionData::ApplicationCommand(data)) = 74 std::mem::take(&mut inter.0.data) 75 { 76 handle_app_command(*data, ctx, inter.0) 77 .await 78 .context("While handling an app command") 79 } else { 80 Ok(()) 81 } 82 } 83 Event::Ready(ev) => { 84 info!("Connected to gateway as {}", ev.user.name); 85 let brain = ctx.brain_handle.read().await; 86 update_status(&brain, &ctx.shard_sender).context("Failed to update status on ready") 87 } 88 _ => Ok(()), 89 } 90} 91 92const BROTLI_BUF_SIZE: usize = 1024 * 1000; 93fn get_brotli_params() -> BrotliEncoderParams { 94 BrotliEncoderParams { 95 quality: 5, 96 mode: BrotliEncoderMode::BROTLI_MODE_TEXT, 97 ..Default::default() 98 } 99} 100 101fn load_brain(path: &Path) -> Result<Option<Brain>> { 102 if path.exists() { 103 let mut file = File::open(path).context("Failed to open brain file")?; 104 let mut brotli_stream = brotli::Decompressor::new(&mut file, BROTLI_BUF_SIZE); 105 rmp_serde::from_read(&mut brotli_stream) 106 .map(Some) 107 .context("Failed to decode brain file") 108 } else { 109 Ok(None) 110 } 111} 112 113async fn save_brain(ctx: Arc<BotContext>) -> Result { 114 let scratch_path = ctx.brain_file_path.with_file_name(format!( 115 "~{}", 116 ctx.brain_file_path.file_name().unwrap().to_str().unwrap() 117 )); 118 let mut file = File::create(&scratch_path).context("Failed to open brain file")?; 119 let mut brotli_writer = 120 brotli::CompressorWriter::with_params(&mut file, BROTLI_BUF_SIZE, &get_brotli_params()); 121 122 let brain = ctx.brain_handle.read().await; 123 rmp_serde::encode::write(&mut brotli_writer, &*brain) 124 .context("Failed to write serialized brain")?; 125 126 std::fs::rename(&scratch_path, &ctx.brain_file_path) 127 .context("Failed to override scratch file")?; 128 129 debug!("Saved brain file"); 130 Ok(()) 131} 132 133#[tokio::main] 134async fn main() -> Result { 135 let mut clog = colog::default_builder(); 136 clog.filter( 137 None, 138 if cfg!(debug_assertions) { 139 log::LevelFilter::Debug 140 } else { 141 log::LevelFilter::Info 142 }, 143 ); 144 clog.try_init().context("Failed to initialize colog")?; 145 146 info!("Start of bingus-bot {}", env!("CARGO_PKG_VERSION")); 147 148 // Config 149 let token_file = std::env::var("TOKEN_FILE").context("Missing TOKEN_FILE env var")?; 150 let reply_channels = std::env::var("REPLY_CHANNELS") 151 .context("Missing REPLY_CHANNELS env var")? 152 .split(",") 153 .filter_map(|s| { 154 if s.trim().is_empty() { 155 None 156 } else { 157 Some(s.trim().parse::<u64>().map(|c| Id::new(c))) 158 } 159 }) 160 .collect::<Result<_, _>>() 161 .context("Invalid channel IDs for REPLY_CHANNELS")?; 162 let brain_file_path = 163 PathBuf::from(std::env::var("BRAIN_FILE").unwrap_or_else(|_| "brain.msgpackz".to_string())); 164 let intents = Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT; 165 166 // Read token 167 let token = std::fs::read_to_string(token_file).context("Failed to read bot token")?; 168 let token = token.trim(); 169 170 // Read Brain 171 let brain = if let Some(brain) = load_brain(&brain_file_path)? { 172 info!("Loading brain from {brain_file_path:?}"); 173 brain 174 } else { 175 info!("Creating new brain file at {brain_file_path:?}"); 176 Brain::default() 177 }; 178 let brain_handle = RwLock::new(brain); 179 180 // Init 181 let mut shard = Shard::new(ShardId::ONE, token.to_string(), intents); 182 let http = HttpClient::new(token.to_string()); 183 184 let app = http 185 .current_user_application() 186 .await 187 .context("Failed to get current App")? 188 .model() 189 .await 190 .context("Failed to deserialize")?; 191 192 let app_id = app.id; 193 194 let self_id = app.bot.context("App is not a bot!")?.id; 195 196 let owners = if let Some(user) = app.owner { 197 HashSet::from_iter([user.id]) 198 } else if let Some(team) = app.team { 199 team.members.iter().map(|m| m.user.id).collect() 200 } else { 201 warn!("No Owner?? Bingus is free!!!"); 202 HashSet::new() 203 }; 204 205 let context = Arc::new(BotContext { 206 http, 207 self_id, 208 app_id, 209 owners, 210 reply_channels, 211 brain_file_path, 212 brain_handle, 213 shard_sender: shard.sender(), 214 pending_save: AtomicBool::new(false), 215 }); 216 217 info!("Registering Commands..."); 218 register_all_commands(context.clone()).await?; 219 220 let mut interval = time::interval(Duration::from_secs(60)); 221 interval.tick().await; 222 223 let mut sigterm = signal(SignalKind::terminate()).context("Failed to listen to SIGTERM")?; 224 225 info!("Connecting to gateway..."); 226 227 loop { 228 tokio::select! { 229 230 biased; 231 232 Ok(()) = tokio::signal::ctrl_c() => { 233 info!("SIGINT: Closing connection and saving"); 234 shard.close(CloseFrame::NORMAL); 235 } 236 _ = sigterm.recv() => { 237 info!("SIGTERM: Closing connection and saving"); 238 shard.close(CloseFrame::NORMAL); 239 } 240 _ = interval.tick() => { 241 debug!("Save Interval"); 242 if context.pending_save.load(Ordering::Relaxed) { 243 let ctx = context.clone(); 244 tokio::spawn(async move { 245 if let Err(why) = save_brain(ctx.clone()).await { 246 error!("Failed to save brain file:\n{why:?}"); 247 } 248 ctx.pending_save.store(false, Ordering::Relaxed); 249 }); 250 } 251 }, 252 opt = shard.next_event(EventTypeFlags::all()) => { 253 match opt { 254 Some(Ok(Event::GatewayClose(_))) | None => { 255 info!("Disconnected from Discord"); 256 break; 257 } 258 Some(Ok(event)) => { 259 let ctx = context.clone(); 260 tokio::spawn(async move { 261 if let Err(why) = handle_discord_event(event, ctx).await { 262 error!("Error while processing Discord event:\n{why:?}"); 263 } 264 }); 265 } 266 Some(Err(why)) => { 267 warn!("Failed to receive event:\n{why:?}"); 268 } 269 } 270 } 271 } 272 } 273 274 if context.pending_save.load(Ordering::Relaxed) { 275 save_brain(context) 276 .await 277 .context("Failed to write brain file on exit")?; 278 } 279 280 Ok(()) 281}