{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}

-- This is a simple generative adversarial network to make pictures
-- of numbers similar to those in MNIST.
--
-- It demonstrates a different usage of the library; within a few hours
-- it was producing examples like this:
--
-- --.
-- .=-.--..#=###
-- -##==#########.
-- #############-
-- -###-.=..-.-==
-- ###-
-- .###-
-- .####...==-.
-- -####=--.=##=
-- -##=- -##
-- =##
-- -##=
-- -###-
-- .####.
-- .#####.
-- ...---=#####-
-- .=#########. .
-- .#######=. .
-- . =-.
--
-- It's a 5!
--
import           Control.Applicative
import           Control.Monad
import           Control.Monad.Random
import           Control.Monad.Trans.Except

import qualified Data.Attoparsec.Text as A
import qualified Data.ByteString as B
import           Data.List ( foldl' )
import           Data.Semigroup ( (<>) )
import           Data.Serialize
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Vector.Storable as V

import qualified Numeric.LinearAlgebra.Static as SA
import           Numeric.LinearAlgebra.Data ( toLists )

import           Options.Applicative

import           Grenade
import           Grenade.Utils.OneHot
-- | The discriminator: a small convolutional classifier that maps a
--   28x28 image to a single logistic score (real vs. generated).
type Discriminator =
  Network
    '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu
     , Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu
     , FullyConnected 256 80, Logit, FullyConnected 80 1, Logit ]
    '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10
     , 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256
     , 'D1 80, 'D1 80, 'D1 1, 'D1 1 ]
-- | The generator: expands a 100-dimensional noise vector up to a
--   28x28 image via a fully connected layer and two convolutions.
type Generator =
  Network
    '[ FullyConnected 100 10240, Relu, Reshape
     , Convolution 10 10 5 5 1 1, Relu
     , Convolution 10 1 1 1 1 1, Logit, Reshape ]
    '[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10
     , 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1
     , 'D3 28 28 1, 'D2 28 28 ]
-- | Create a discriminator with randomly initialised weights.
randomDiscriminator :: MonadRandom m => m Discriminator
randomDiscriminator = randomNetwork
-- | Create a generator with randomly initialised weights.
randomGenerator :: MonadRandom m => m Generator
randomGenerator = randomNetwork
-- | Run one adversarial training step on a single real example and a
--   single noise vector.
--
--   The discriminator is pushed towards 1 on the real example
--   (gradient @guessReal - 1@) and towards 0 on the generated one
--   (gradient @guessFake@). The generator is then updated with the
--   negated gradient pushed back through the discriminator, i.e. it is
--   rewarded for fooling the discriminator. The discriminator update
--   uses a 10x stronger regulariser than the generator's.
trainExample :: LearningParameters -> Discriminator -> Generator -> S ('D2 28 28) -> S ('D1 100) -> ( Discriminator, Generator )
trainExample rate discriminator generator realExample noiseSource
  = let (generatorTape, fakeExample)       = runNetwork generator noiseSource

        (discriminatorTapeReal, guessReal) = runNetwork discriminator realExample
        (discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample

        (discriminator'real, _)    = runGradient discriminator discriminatorTapeReal ( guessReal - 1 )
        (discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake

        -- Negated input gradient from the fake branch trains the generator.
        (generator', _)  = runGradient generator generatorTape (-push)

        newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10 }) discriminator [ discriminator'real, discriminator'fake ]
        newGenerator     = applyUpdate rate generator generator'
    in  ( newDiscriminator, newGenerator )
-- | Train the GAN for the given number of iterations over the MNIST
--   training file, printing one freshly generated sample as ASCII art
--   after each full pass, and return the trained networks.
ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> ExceptT String IO (Discriminator, Generator)
ganTest (discriminator0, generator0) iterations trainFile rate = do
  trainData <- fmap fst <$> readMNIST trainFile

  lift $ foldM (runIteration trainData) ( discriminator0, generator0 ) [1..iterations]

    where

  -- Render a 2D shape as ASCII art, one character per pixel,
  -- darker characters for larger values.
  showShape' :: S ('D2 a b) -> IO ()
  showShape' (S2D mm) = putStrLn $
    let m  = SA.extract mm
        ms = toLists m
        render n' | n' <= 0.2 = ' '
                  | n' <= 0.4 = '.'
                  | n' <= 0.6 = '-'
                  | n' <= 0.8 = '='
                  | otherwise = '#'
        px = (fmap . fmap) render ms
    in unlines px

  -- One epoch: run 'trainExample' over every real example, with a fresh
  -- noise vector for each, then show a sample from the new generator.
  runIteration :: [S ('D2 28 28)] -> (Discriminator, Generator) -> Int -> IO (Discriminator, Generator)
  runIteration trainData ( !discriminator, !generator ) _ = do
    trained' <- foldM ( \(!discriminatorX, !generatorX ) realExample -> do
                          fakeExample <- randomOfShape
                          return $ trainExample rate discriminatorX generatorX realExample fakeExample
                      ) ( discriminator, generator ) trainData

    showShape' . snd . runNetwork (snd trained') =<< randomOfShape

    return trained'
121121+122122+data GanOpts = GanOpts FilePath Int LearningParameters (Maybe FilePath) (Maybe FilePath)
-- | Option parser for 'GanOpts'. Defaults: 15 iterations,
--   rate 0.01, momentum 0.9, l2 regulariser 0.0005.
mnist' :: Parser GanOpts
mnist' = GanOpts <$> argument str (metavar "TRAIN")
                 <*> option auto (long "iterations" <> short 'i' <> value 15)
                 <*> (LearningParameters
                      <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
                      <*> option auto (long "momentum" <> value 0.9)
                      <*> option auto (long "l2" <> value 0.0005)
                      )
                 <*> optional (strOption (long "load"))
                 <*> optional (strOption (long "save"))
-- | Parse options, load or randomly initialise the networks, train,
--   and optionally save the result. Training errors are printed.
main :: IO ()
main = do
  GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
  putStrLn "Training stupidly simply GAN"
  nets0 <- case load of
    Just loadFile -> netLoad loadFile
    Nothing       -> (,) <$> randomDiscriminator <*> randomGenerator

  res <- runExceptT $ ganTest nets0 iter mnist rate
  case res of
    Right nets1 -> case save of
      Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
      Nothing       -> return ()

    Left err -> putStrLn err
-- | Read an MNIST CSV file: one example per line, a decimal label
--   followed by comma-separated pixel values. Any malformed line
--   fails the whole read with a parse error message.
readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
readMNIST mnist = ExceptT $ do
  mnistdata <- T.readFile mnist
  return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
-- | Parse one MNIST CSV row into an image and a one-hot label.
--
--   Replaces the original @Just lab <- …@ partial pattern bind, which
--   failed through MonadFail with a generic pattern-match message,
--   with explicit 'fail' calls carrying useful diagnostics.
parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
parseMNIST = do
  lab    <- maybe (fail "Label outside the one-hot range") pure . oneHot =<< A.decimal
  pixels <- many (A.char ',' >> A.double)
  image  <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
  return (image, lab)
-- | Load a serialised (Discriminator, Generator) pair from disk,
--   failing in IO if the bytes cannot be decoded.
netLoad :: FilePath -> IO (Discriminator, Generator)
netLoad modelPath = do
  modelData <- B.readFile modelPath
  either fail return $ runGet (get :: Get (Discriminator, Generator)) modelData
-- (The following fragment is from src/Grenade/Layers/Reshape.hs.)
--
-- Flattens input down to D1 from either 2D or 3D data.
--
-- Casts input D1 up to either 2D or 3D data if the shapes are good.
--
-- Can also be used to turn a 3D image with only one channel into a 2D image
-- or vice versa.
-- | Reshape layer: coerces between shapes with the same number of
--   elements. It has no parameters and learns nothing.
data Reshape = Reshape
  deriving Show
2730···4952 type Tape Reshape ('D2 x y) ('D3 x y z) = ()
5053 runForwards _ (S2D y) = ((), S3D y)
5154 runBackwards _ _ (S3D y) = ((), S2D y)
-- | Cast a flat vector up to a 2D matrix; only valid when the element
--   counts agree (@a ~ x * y@). The backward pass flattens back to 1D.
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer Reshape ('D1 a) ('D2 x y) where
  type Tape Reshape ('D1 a) ('D2 x y) = ()
  runForwards _ (S1D y)    = ((), fromJust' . fromStorable . extract $ y)
  runBackwards _ _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y)
-- | Cast a flat vector up to a 3D volume; only valid when the element
--   counts agree (@a ~ x * y * z@). The backward pass flattens back to 1D.
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer Reshape ('D1 a) ('D3 x y z) where
  type Tape Reshape ('D1 a) ('D3 x y z) = ()
  runForwards _ (S1D y)    = ((), fromJust' . fromStorable . extract $ y)
  runBackwards _ _ (S3D y) = ((), fromJust' . fromStorable . flatten . extract $ y)
52655366instance Serialize Reshape where
5467 put _ = return ()