💣 Machine learning which might blow up in your face 💣

Merge pull request #25 from HuwCampbell/topic/generative-adversarial

Topic/generative adversarial

Authored by Huw Campbell and committed by GitHub
c3f0373f ac9b43ea

+200 -1
+19
grenade.cabal
··· 130 130 , MonadRandom 131 131 , vector 132 132 133 + executable gan-mnist 134 + ghc-options: -Wall -threaded -O2 135 + main-is: main/gan-mnist.hs 136 + build-depends: base 137 + , grenade 138 + , attoparsec 139 + , bytestring 140 + , cereal 141 + , either 142 + , optparse-applicative == 0.13.* 143 + , text == 1.2.* 144 + , mtl >= 2.2.1 && < 2.3 145 + , hmatrix >= 0.18 && < 0.19 146 + , transformers 147 + , semigroups 148 + , singletons 149 + , MonadRandom 150 + , vector 151 + 133 152 executable recurrent 134 153 ghc-options: -Wall -threaded -O2 135 154 main-is: main/recurrent.hs
+167
main/gan-mnist.hs
··· 1 + {-# LANGUAGE BangPatterns #-} 2 + {-# LANGUAGE DataKinds #-} 3 + {-# LANGUAGE GADTs #-} 4 + {-# LANGUAGE ScopedTypeVariables #-} 5 + {-# LANGUAGE TypeOperators #-} 6 + {-# LANGUAGE TupleSections #-} 7 + {-# LANGUAGE TypeFamilies #-} 8 + {-# LANGUAGE FlexibleContexts #-} 9 + 10 + -- This is a simple generative adversarial network to make pictures 11 + -- of numbers similar to those in MNIST. 12 + -- 13 + -- It demonstrates a different usage of the library, within a few hours 14 + -- was producing examples like this: 15 + -- 16 + -- --. 17 + -- .=-.--..#=### 18 + -- -##==#########. 19 + -- #############- 20 + -- -###-.=..-.-== 21 + -- ###- 22 + -- .###- 23 + -- .####...==-. 24 + -- -####=--.=##= 25 + -- -##=- -## 26 + -- =## 27 + -- -##= 28 + -- -###- 29 + -- .####. 30 + -- .#####. 31 + -- ...---=#####- 32 + -- .=#########. . 33 + -- .#######=. . 34 + -- . =-. 35 + -- 36 + -- It's a 5! 37 + -- 38 + import Control.Applicative 39 + import Control.Monad 40 + import Control.Monad.Random 41 + import Control.Monad.Trans.Except 42 + 43 + import qualified Data.Attoparsec.Text as A 44 + import qualified Data.ByteString as B 45 + import Data.List ( foldl' ) 46 + import Data.Semigroup ( (<>) ) 47 + import Data.Serialize 48 + import qualified Data.Text as T 49 + import qualified Data.Text.IO as T 50 + import qualified Data.Vector.Storable as V 51 + 52 + import qualified Numeric.LinearAlgebra.Static as SA 53 + import Numeric.LinearAlgebra.Data ( toLists ) 54 + 55 + import Options.Applicative 56 + 57 + import Grenade 58 + import Grenade.Utils.OneHot 59 + 60 + type Discriminator = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit] 61 + '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1] 62 + 63 + type Generator = Network '[ FullyConnected 100 10240, Relu, Reshape, Convolution 
10 10 5 5 1 1, Relu, Convolution 10 1 1 1 1 1, Logit, Reshape] 64 + '[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10, 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1, 'D3 28 28 1, 'D2 28 28 ] 65 + 66 + randomDiscriminator :: MonadRandom m => m Discriminator 67 + randomDiscriminator = randomNetwork 68 + 69 + randomGenerator :: MonadRandom m => m Generator 70 + randomGenerator = randomNetwork 71 + 72 + trainExample :: LearningParameters -> Discriminator -> Generator -> S ('D2 28 28) -> S ('D1 100) -> ( Discriminator, Generator ) 73 + trainExample rate discriminator generator realExample noiseSource 74 + = let (generatorTape, fakeExample) = runNetwork generator noiseSource 75 + 76 + (discriminatorTapeReal, guessReal) = runNetwork discriminator realExample 77 + (discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample 78 + 79 + (discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 ) 80 + (discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake 81 + 82 + (generator', _) = runGradient generator generatorTape (-push) 83 + 84 + newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ] 85 + newGenerator = applyUpdate rate generator generator' 86 + in ( newDiscriminator, newGenerator ) 87 + 88 + 89 + ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> ExceptT String IO (Discriminator, Generator) 90 + ganTest (discriminator0, generator0) iterations trainFile rate = do 91 + trainData <- fmap fst <$> readMNIST trainFile 92 + 93 + lift $ foldM (runIteration trainData) ( discriminator0, generator0 ) [1..iterations] 94 + 95 + where 96 + 97 + showShape' :: S ('D2 a b) -> IO () 98 + showShape' (S2D mm) = putStrLn $ 99 + let m = SA.extract mm 100 + ms = toLists m 101 + render n' | n' <= 0.2 = ' ' 102 + | n' <= 0.4 = '.' 
103 + | n' <= 0.6 = '-' 104 + | n' <= 0.8 = '=' 105 + | otherwise = '#' 106 + 107 + px = (fmap . fmap) render ms 108 + in unlines px 109 + 110 + runIteration :: [S ('D2 28 28)] -> (Discriminator, Generator) -> Int -> IO (Discriminator, Generator) 111 + runIteration trainData ( !discriminator, !generator ) _ = do 112 + trained' <- foldM ( \(!discriminatorX, !generatorX ) realExample -> do 113 + fakeExample <- randomOfShape 114 + return $ trainExample rate discriminatorX generatorX realExample fakeExample 115 + ) ( discriminator, generator ) trainData 116 + 117 + 118 + showShape' . snd . runNetwork (snd trained') =<< randomOfShape 119 + 120 + return trained' 121 + 122 + data GanOpts = GanOpts FilePath Int LearningParameters (Maybe FilePath) (Maybe FilePath) 123 + 124 + mnist' :: Parser GanOpts 125 + mnist' = GanOpts <$> argument str (metavar "TRAIN") 126 + <*> option auto (long "iterations" <> short 'i' <> value 15) 127 + <*> (LearningParameters 128 + <$> option auto (long "train_rate" <> short 'r' <> value 0.01) 129 + <*> option auto (long "momentum" <> value 0.9) 130 + <*> option auto (long "l2" <> value 0.0005) 131 + ) 132 + <*> optional (strOption (long "load")) 133 + <*> optional (strOption (long "save")) 134 + 135 + 136 + main :: IO () 137 + main = do 138 + GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm) 139 + putStrLn "Training stupidly simply GAN" 140 + nets0 <- case load of 141 + Just loadFile -> netLoad loadFile 142 + Nothing -> (,) <$> randomDiscriminator <*> randomGenerator 143 + 144 + res <- runExceptT $ ganTest nets0 iter mnist rate 145 + case res of 146 + Right nets1 -> case save of 147 + Just saveFile -> B.writeFile saveFile $ runPut (put nets1) 148 + Nothing -> return () 149 + 150 + Left err -> putStrLn err 151 + 152 + readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))] 153 + readMNIST mnist = ExceptT $ do 154 + mnistdata <- T.readFile mnist 155 + return $ traverse (A.parseOnly parseMNIST) (T.lines 
mnistdata) 156 + 157 + parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10)) 158 + parseMNIST = do 159 + Just lab <- oneHot <$> A.decimal 160 + pixels <- many (A.char ',' >> A.double) 161 + image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels) 162 + return (image, lab) 163 + 164 + netLoad :: FilePath -> IO (Discriminator, Generator) 165 + netLoad modelPath = do 166 + modelData <- B.readFile modelPath 167 + either fail return $ runGet (get :: Get (Discriminator, Generator)) modelData
+14 -1
src/Grenade/Layers/Reshape.hs
··· 21 21 -- 22 22 -- Flattens input down to D1 from either 2D or 3D data. 23 23 -- 24 - -- Can also be used to turn a 3D image with only one channel into a 2D image. 24 + -- Casts input D1 up to either 2D or 3D data if the shapes are good. 25 + -- 26 + -- Can also be used to turn a 3D image with only one channel into a 2D image 27 + -- or vice versa. 25 28 data Reshape = Reshape 26 29 deriving Show 27 30 ··· 49 52 type Tape Reshape ('D2 x y) ('D3 x y z) = () 50 53 runForwards _ (S2D y) = ((), S3D y) 51 54 runBackwards _ _ (S3D y) = ((), S2D y) 55 + 56 + instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer Reshape ('D1 a) ('D2 x y) where 57 + type Tape Reshape ('D1 a) ('D2 x y) = () 58 + runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) 59 + runBackwards _ _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y) 60 + 61 + instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer Reshape ('D1 a) ('D3 x y z) where 62 + type Tape Reshape ('D1 a) ('D3 x y z) = () 63 + runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) 64 + runBackwards _ _ (S3D y) = ((), fromJust' . fromStorable . flatten . extract $ y) 52 65 53 66 instance Serialize Reshape where 54 67 put _ = return ()