The world's most clever kitty cat

Add brain persistence

bwc9876.dev f4bb01ad 9102acbe

verified
+51 -19
+1 -1
.gitignore
···
  4   4   dist/
  5   5   result
  6   6   __pycache__
  7     -
      7 + brain.msgpack
  8   8   *.env
+28 -8
src/bingus/cogs/markov.py
···
  4   4   import discord
  5   5   from discord.ext import commands
  6   6   from discord.message import Message
      7 + from pathlib import Path
  7   8   from ..lib.markov import MarkovChain
  8   9   from ..lib.permissions import require_owner
  9  10
···
 14  15           self.reply_channels = [
 15  16               int(x) for x in os.getenv("Markov.REPLY_CHANNELS", "0").split(",")
 16  17           ]
 17     -         self.markov = MarkovChain({})
     18 +         self.chain_file = Path(os.getenv("Markov.BRAIN_FILE", "brain.msgpack"))
     19 +         if self.chain_file.is_file():
     20 +             print(f"Attempting load from {self.chain_file}...")
     21 +             try:
     22 +                 self.markov = MarkovChain.load_from_file(self.chain_file)
     23 +                 print("Load Complete")
     24 +             except Exception as E:
     25 +                 print(f"Error while loading\n{E}")
     26 +         else:
     27 +             self.markov = MarkovChain({})
 18  28
 19  29       async def update_words(self):
 20  30           amount = len(self.markov.edges.keys())
     31 +         try:
     32 +             self.markov.save_to_file(self.chain_file)
     33 +         except Exception as E:
     34 +             print(f"Error while saving\n{E}")
     35 +
 21  36           await self.bot.change_presence(
 22  37               activity=discord.CustomActivity(name=f"I know {amount} words!")
 23  38           )
···
 25  40       @require_owner
 26  41       @commands.slash_command()
 27  42       async def dump_chain(self, ctx: discord.ApplicationContext):
 28     -         o = self.markov.dump()
 29     -         fd = io.BytesIO(o.encode())
 30     -         await ctx.respond(ephemeral=True, file=discord.File(fd, filename="brain.json"))
     43 +         o = self.markov.dumpb()
     44 +         fd = io.BytesIO(o)
     45 +         await ctx.respond(
     46 +             ephemeral=True, file=discord.File(fd, filename="brain.msgpack")
     47 +         )
 31  48
 32  49       @require_owner
 33  50       @commands.slash_command()
 34  51       async def load_chain(
 35  52           self, ctx: discord.ApplicationContext, raw: discord.Option(discord.Attachment)
 36  53       ):
 37     -         j = (await raw.read()).decode("utf-8")
 38     -         new = MarkovChain.load(j)
     54 +         new = MarkovChain.loadb(await raw.read())
 39  55           self.markov.merge(new)
 40  56           await ctx.respond("Imported", ephemeral=True)
 41  57           await self.update_words()
···
 43  59       @require_owner
 44  60       @commands.slash_command()
 45  61       async def scan_history(self, ctx: discord.ApplicationContext):
 46     -         await ctx.defer()
     62 +         await ctx.defer(ephemeral=True)
 47  63           async for msg in ctx.history(limit=None):
 48  64               if msg.author.id != self.bot.application_id:
 49  65                   self.markov.learn(msg.content)
 50     -         await ctx.respond("> Bingus Learned!")
     66 +         await ctx.respond("> Bingus Learned!", ephemeral=True)
 51  67           await self.update_words()
 52  68
 53  69       @commands.slash_command()
···
 77  93               await ctx.respond(head, file=discord.File(fd, filename="weights.txt"))
 78  94           else:
 79  95               await ctx.respond(f"{head}:\n{msg}")
     96 +
     97 +     @commands.Cog.listener()
     98 +     async def on_ready(self):
     99 +         await self.update_words()
 80 100
 81 101       @commands.Cog.listener()
 82 102       async def on_message(self, msg: Message):
+22 -10
src/bingus/lib/markov.py
···
  1   1   from dataclasses import dataclass
  2   2   import random
  3     - import json
  4   3   from typing import Optional
      4 + from pathlib import Path
      5 + from msgpack import packb, unpackb
  5   6
  6   7
  7   8   @dataclass
···
190 191           else:
191 192               return self._chain(None, max_length=max_length)
192 193
193     -     def dump(self) -> str:
194     -         return json.dumps(
195     -             {
196     -                 token_ser(e): {token_ser(k): v for k, v in w.to_tokens.items()}
197     -                 for e, w in self.edges.items()
198     -             }
199     -         )
    194 +     def save_to_file(self, path: Path):
    195 +         if not path.parent.exists():
    196 +             path.parent.mkdir(parents=True)
    197 +         path.write_bytes(self.dumpb())
    198 +
    199 +     def load_from_file(path: Path):
    200 +         return MarkovChain.loadb(path.read_bytes())
    201 +
    202 +     def dumpb(self):
    203 +         return packb(self.ser())
    204 +
    205 +     def loadb(dat):
    206 +         return MarkovChain.deser(unpackb(dat))
    207 +
    208 +     def ser(self):
    209 +         return {
    210 +             token_ser(e): {token_ser(k): v for k, v in w.to_tokens.items()}
    211 +             for e, w in self.edges.items()
    212 +         }
200 213
201     -     def load(source: str):
202     -         dat = json.loads(source)
    214 +     def deser(dat):
203 215           edges = {
204 216               token_de(e): StateTransitions({token_de(k): v for k, v in w.items()})
205 217               for e, w in dat.items()