···1+{-# LANGUAGE BangPatterns #-}
2+{-# LANGUAGE DataKinds #-}
3+{-# LANGUAGE GADTs #-}
4+{-# LANGUAGE ScopedTypeVariables #-}
5+{-# LANGUAGE TypeOperators #-}
6+{-# LANGUAGE TupleSections #-}
7+{-# LANGUAGE TypeFamilies #-}
8+{-# LANGUAGE FlexibleContexts #-}
9+10+-- This is a simple generative adversarial network to make pictures
11+-- of numbers similar to those in MNIST.
12+--
-- It demonstrates a different usage of the library; within a few hours
-- it was producing examples like this:
15+--
16+-- --.
17+-- .=-.--..#=###
18+-- -##==#########.
19+-- #############-
20+-- -###-.=..-.-==
21+-- ###-
22+-- .###-
23+-- .####...==-.
24+-- -####=--.=##=
25+-- -##=- -##
26+-- =##
27+-- -##=
28+-- -###-
29+-- .####.
30+-- .#####.
31+-- ...---=#####-
32+-- .=#########. .
33+-- .#######=. .
34+-- . =-.
35+--
36+-- It's a 5!
37+--
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Control.Monad.Trans.Except

import qualified Data.Attoparsec.Text as A
import qualified Data.ByteString as B
import Data.List ( foldl' )
import Data.Semigroup ( (<>) )
import Data.Serialize
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Vector.Storable as V

import qualified Numeric.LinearAlgebra.Static as SA
import Numeric.LinearAlgebra.Data ( toLists )

import Options.Applicative

import System.Exit ( exitFailure )
import System.IO ( hPutStrLn, stderr )

import Grenade
import Grenade.Utils.OneHot
-- | The discriminator: a convolutional network scoring a 28x28 image
--   with a single output passed through a final 'Logit' activation
--   (real vs generated judgement).
type Discriminator = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit]
                             '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1]
-- | The generator: expands a 100-element noise vector through a fully
--   connected layer and two convolutions into a 28x28 image, with a
--   final 'Logit' activation before the last 'Reshape'.
type Generator = Network '[ FullyConnected 100 10240, Relu, Reshape, Convolution 10 10 5 5 1 1, Relu, Convolution 10 1 1 1 1 1, Logit, Reshape]
                         '[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10, 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1, 'D3 28 28 1, 'D2 28 28 ]
-- | Build a discriminator with randomly initialised weights.
randomDiscriminator :: MonadRandom m => m Discriminator
randomDiscriminator = randomNetwork
-- | Build a generator with randomly initialised weights.
randomGenerator :: MonadRandom m => m Generator
randomGenerator = randomNetwork
-- | Run a single GAN training step on one real example and one noise
--   vector, returning the updated discriminator and generator.
--
--   The discriminator is pushed towards scoring the real example as 1
--   and the generated example as 0; the generator is updated against
--   the gradient the discriminator propagated back through the fake.
trainExample :: LearningParameters -> Discriminator -> Generator -> S ('D2 28 28) -> S ('D1 100) -> ( Discriminator, Generator )
trainExample rate discriminator generator realExample noiseSource =
    ( newDiscriminator, newGenerator )
  where
    -- Forward pass: the generator turns the noise vector into a fake image.
    (generatorTape, fakeExample)  = runNetwork generator noiseSource

    -- The discriminator judges both the real and the generated example.
    (realTape, guessReal)         = runNetwork discriminator realExample
    (fakeTape, guessFake)         = runNetwork discriminator fakeExample

    -- Discriminator gradients: real examples should score 1 (hence the
    -- guessReal - 1 error), fakes should score 0 (error is guessFake itself).
    (discGradReal, _)             = runGradient discriminator realTape ( guessReal - 1 )
    (discGradFake, pushedBack)    = runGradient discriminator fakeTape guessFake

    -- Generator gradient: train it to fool the discriminator, i.e. in the
    -- opposite direction of what the discriminator pushed back.
    (generatorGrad, _)            = runGradient generator generatorTape (negate pushedBack)

    -- The discriminator gets ten times the regularisation of the generator.
    discriminatorRate             = rate { learningRegulariser = learningRegulariser rate * 10 }

    newDiscriminator              = foldl' (applyUpdate discriminatorRate) discriminator [ discGradReal, discGradFake ]
    newGenerator                  = applyUpdate rate generator generatorGrad
-- | Load the MNIST training data and run the requested number of GAN
--   training iterations, printing a rendered sample image after each one.
ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> ExceptT String IO (Discriminator, Generator)
ganTest nets0 iterations trainFile rate = do
    -- We only need the images; the one-hot labels are discarded.
    trainData <- fmap fst <$> readMNIST trainFile
    lift $ foldM (runIteration trainData) nets0 [1 .. iterations]
  where
    -- Render a 28x28 (or any 2D) activation as coarse ASCII art.
    showShape' :: S ('D2 a b) -> IO ()
    showShape' (S2D mm) = putStrLn (unlines picture)
      where
        picture = (fmap . fmap) toChar (toLists (SA.extract mm))
        toChar v | v <= 0.2  = ' '
                 | v <= 0.4  = '.'
                 | v <= 0.6  = '-'
                 | v <= 0.8  = '='
                 | otherwise = '#'

    -- One pass over the whole training set, then print a fresh sample
    -- from the updated generator.
    runIteration :: [S ('D2 28 28)] -> (Discriminator, Generator) -> Int -> IO (Discriminator, Generator)
    runIteration trainData ( !disc, !gen ) _ = do
      let step ( !d, !g ) real = do
            noise <- randomOfShape
            return (trainExample rate d g real noise)
      nets' <- foldM step ( disc, gen ) trainData

      sample <- randomOfShape
      showShape' (snd (runNetwork (snd nets') sample))

      return nets'
121+122+data GanOpts = GanOpts FilePath Int LearningParameters (Maybe FilePath) (Maybe FilePath)
-- | Parser for the command line options.
mnist' :: Parser GanOpts
mnist' =
  GanOpts
    <$> argument str (metavar "TRAIN")
    <*> option auto (long "iterations" <> short 'i' <> value 15)
    <*> learningParams
    <*> optional (strOption (long "load"))
    <*> optional (strOption (long "save"))
  where
    -- Learning rate, momentum, and L2 regularisation flags.
    learningParams =
      LearningParameters
        <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
        <*> option auto (long "momentum" <> value 0.9)
        <*> option auto (long "l2" <> value 0.0005)
-- | Parse the options, train (or resume training) the GAN, and
--   optionally persist the resulting networks.
--
--   Fixes: banner typo ("simply" -> "simple"); failures are now reported
--   on stderr with a non-zero exit code instead of stdout with exit 0.
main :: IO ()
main = do
  GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
  putStrLn "Training stupidly simple GAN"
  -- Either resume from a saved model or start from random weights.
  nets0 <- case load of
    Just loadFile -> netLoad loadFile
    Nothing       -> (,) <$> randomDiscriminator <*> randomGenerator

  res <- runExceptT $ ganTest nets0 iter mnist rate
  case res of
    Right nets1 -> case save of
      Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
      Nothing       -> return ()

    -- Report errors where scripts and shells expect them.
    Left err -> hPutStrLn stderr err >> exitFailure
-- | Read a CSV-style MNIST file and parse every line into an image and
--   its one-hot label, failing with the first parse error encountered.
readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
readMNIST path = ExceptT $ do
  contents <- T.readFile path
  pure (traverse (A.parseOnly parseMNIST) (T.lines contents))
-- | Parse one MNIST row: a decimal label followed by comma-separated
--   pixel values.
--
--   Fix: the original bound the label with a partial pattern
--   (@Just lab <- oneHot <$> A.decimal@), so an out-of-range label died
--   with an opaque pattern-match failure; now it fails with an explicit
--   message, matching the style of the image-size check below.
parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
parseMNIST = do
  lab    <- maybe (fail "Label was not a valid one-hot index") pure . oneHot =<< A.decimal
  pixels <- many (A.char ',' >> A.double)
  image  <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
  return (image, lab)
-- | Load a previously serialised (discriminator, generator) pair from
--   disk, raising an IO error if decoding fails.
netLoad :: FilePath -> IO (Discriminator, Generator)
netLoad modelPath =
  B.readFile modelPath >>= either fail return . runGet (get :: Get (Discriminator, Generator))
+14-1
src/Grenade/Layers/Reshape.hs
···21--
22-- Flattens input down to D1 from either 2D or 3D data.
23--
24--- Can also be used to turn a 3D image with only one channel into a 2D image.
00025data Reshape = Reshape
26 deriving Show
27···49 type Tape Reshape ('D2 x y) ('D3 x y z) = ()
50 runForwards _ (S2D y) = ((), S3D y)
51 runBackwards _ _ (S3D y) = ((), S2D y)
00000000005253instance Serialize Reshape where
54 put _ = return ()
···21--
22-- Flattens input down to D1 from either 2D or 3D data.
23--
24+-- Casts input D1 up to either 2D or 3D data if the shapes are good.
25+--
26+-- Can also be used to turn a 3D image with only one channel into a 2D image
27+-- or vice versa.
28data Reshape = Reshape
29 deriving Show
30···52 type Tape Reshape ('D2 x y) ('D3 x y z) = ()
53 runForwards _ (S2D y) = ((), S3D y)
54 runBackwards _ _ (S3D y) = ((), S2D y)
-- | Reshape a flattened vector back up into a 2D image.  The constraint
--   a ~ (x * y) makes the element counts agree, so the fromJust' on the
--   'fromStorable' result is expected to be safe -- verify in fromStorable.
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer Reshape ('D1 a) ('D2 x y) where
  type Tape Reshape ('D1 a) ('D2 x y) = ()
  runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
  runBackwards _ _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y)
-- | Reshape a flattened vector back up into a 3D image.  The constraint
--   a ~ (x * y * z) makes the element counts agree, so the fromJust' on
--   the 'fromStorable' result is expected to be safe -- verify in fromStorable.
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer Reshape ('D1 a) ('D3 x y z) where
  type Tape Reshape ('D1 a) ('D3 x y z) = ()
  runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
  runBackwards _ _ (S3D y) = ((), fromJust' . fromStorable . flatten . extract $ y)
6566instance Serialize Reshape where
67 put _ = return ()