···2323import Grenade
2424import Grenade.Recurrent
25252626-{-# OPTIONS_GHC -fno-redundant-imports #-}
2726-- The defininition for our simple recurrent network.
2827-- This file just trains a network to generate a repeating sequence
2928-- of 0 0 1.
-1
src/Grenade/Core/Network.hs
···157157applyUpdate _ NNil GNil
158158 = NNil
159159160160-161160-- | A network can easily be created by hand with (:~>), but an easy way to
162161-- initialise a random network is with the randomNetwork.
163162class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
+2
src/Grenade/Core/Runner.hs
···4242 (grads, _) = runGradient network tapes (output - target)
4343 in grads
44444545+4546-- | Update a network with new weights after training with an instance.
4647train :: SingI (Last shapes)
4748 => LearningParameters
···5253train rate network input output =
5354 let grads = backPropagate network input output
5455 in applyUpdate rate network grads
5656+55575658-- | Run the network with input and return the given output.
5759runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
+1
src/Grenade/Recurrent.hs
···22 module X
33 ) where
4455+import Grenade.Recurrent.Core.Layer as X
56import Grenade.Recurrent.Core.Network as X
67import Grenade.Recurrent.Core.Runner as X
78import Grenade.Recurrent.Layers.BasicRecurrent as X
+6
src/Grenade/Recurrent/Core.hs
···11+module Grenade.Recurrent.Core (
22+ module X
33+ ) where
44+55+import Grenade.Recurrent.Core.Layer as X
66+import Grenade.Recurrent.Core.Network as X
+32
src/Grenade/Recurrent/Core/Layer.hs
···11+{-# LANGUAGE DataKinds #-}
22+{-# LANGUAGE TypeFamilies #-}
33+{-# LANGUAGE MultiParamTypeClasses #-}
44+{-# LANGUAGE FlexibleContexts #-}
55+{-# LANGUAGE FlexibleInstances #-}
66+module Grenade.Recurrent.Core.Layer (
77+ RecurrentLayer (..)
88+ , RecurrentUpdateLayer (..)
99+ ) where
1010+1111+import Data.Singletons ( SingI )
1212+1313+import Grenade.Core
1414+1515+-- | Class for a recurrent layer.
1616+-- It's quite similar to a normal layer but for the input and output
1717+-- of an extra recurrent data shape.
1818+class UpdateLayer x => RecurrentUpdateLayer x where
1919+ -- | Shape of data that is passed between each subsequent run of the layer
2020+ type RecurrentShape x :: Shape
2121+2222+class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
2323+ -- | Wengert Tape
2424+ type RecTape x i o :: *
2525+ -- | Used in training and scoring. Take the input from the previous
2626+ -- layer, and give the output from this layer.
2727+ runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (RecTape x i o, S (RecurrentShape x), S o)
2828+ -- | Back propagate a step. Takes the current layer, the input that the
2929+ -- layer gave from the input and the back propagated derivatives from
3030+ -- the layer above.
3131+ -- Returns the gradient layer and the derivatives to push back further.
3232+ runRecurrentBackwards :: x -> RecTape x i o -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i)
+214-32
src/Grenade/Recurrent/Core/Network.hs
···11+{-# LANGUAGE CPP #-}
12{-# LANGUAGE DataKinds #-}
23{-# LANGUAGE GADTs #-}
34{-# LANGUAGE TypeOperators #-}
···67{-# LANGUAGE FlexibleContexts #-}
78{-# LANGUAGE FlexibleInstances #-}
89{-# LANGUAGE EmptyDataDecls #-}
1010+{-# LANGUAGE RankNTypes #-}
1111+{-# LANGUAGE BangPatterns #-}
1212+{-# LANGUAGE ScopedTypeVariables #-}
1313+1414+#if __GLASGOW_HASKELL__ < 800
1515+{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1616+#endif
1717+918module Grenade.Recurrent.Core.Network (
1019 Recurrent
1120 , FeedForward
1212- , RecurrentLayer (..)
1313- , RecurrentUpdateLayer (..)
2121+1422 , RecurrentNetwork (..)
1523 , RecurrentInputs (..)
1616- , CreatableRecurrent (..)
2424+ , RecurrentTapes (..)
2525+ , RecurrentGradients (..)
2626+2727+ , randomRecurrent
2828+ , runRecurrentNetwork
2929+ , runRecurrentGradient
3030+ , applyRecurrentUpdate
1731 ) where
183219332034import Control.Monad.Random ( MonadRandom )
2135import Data.Singletons ( SingI )
3636+import Data.Singletons.Prelude ( Head, Last )
2237import Data.Serialize
2338import qualified Data.Vector.Storable as V
24392540import Grenade.Core
4141+import Grenade.Recurrent.Core.Layer
26422743import qualified Numeric.LinearAlgebra as LA
2844import qualified Numeric.LinearAlgebra.Static as LAS
29453030-3146-- | Witness type to say indicate we're building up with a normal feed
3247-- forward layer.
3348data FeedForward :: * -> *
3449-- | Witness type to say indicate we're building up with a recurrent layer.
3550data Recurrent :: * -> *
36513737--- | Class for a recurrent layer.
3838--- It's quite similar to a normal layer but for the input and output
3939--- of an extra recurrent data shape.
4040-class UpdateLayer x => RecurrentUpdateLayer x where
4141- -- | Shape of data that is passed between each subsequent run of the layer
4242- type RecurrentShape x :: Shape
4343-4444-class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
4545- -- | Wengert Tape
4646- type RecTape x i o :: *
4747- -- | Used in training and scoring. Take the input from the previous
4848- -- layer, and give the output from this layer.
4949- runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (RecTape x i o, S (RecurrentShape x), S o)
5050- -- | Back propagate a step. Takes the current layer, the input that the
5151- -- layer gave from the input and the back propagated derivatives from
5252- -- the layer above.
5353- -- Returns the gradient layer and the derivatives to push back further.
5454- runRecurrentBackwards :: x -> RecTape x i o -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i)
5555-5252+-- | Type of a recurrent neural network.
5353+--
5454+-- The [*] type specifies the types of the layers.
5555+--
5656+-- The [Shape] type specifies the shapes of data passed between the layers.
5757+--
5858+-- The definition is similar to a Network, but every layer in the
5959+-- type is tagged by whether it's a FeedForward Layer of a Recurrent layer.
6060+--
6161+-- Often, to make the definitions more concise, one will use a type alias
6262+-- for these empty data types.
5663data RecurrentNetwork :: [*] -> [Shape] -> * where
5764 RNil :: SingI i
5865 => RecurrentNetwork '[] '[i]
···6976infixr 5 :~~>
7077infixr 5 :~@>
71787272-instance Show (RecurrentNetwork '[] '[i]) where
7373- show RNil = "NNil"
7474-instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (FeedForward x ': xs) (i ': rs)) where
7575- show (x :~~> xs) = show x ++ "\n~~>\n" ++ show xs
7676-instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recurrent x ': xs) (i ': rs)) where
7777- show (x :~@> xs) = show x ++ "\n~~>\n" ++ show xs
7979+-- | Gradient of a network.
8080+--
8181+-- Parameterised on the layers of the network.
8282+data RecurrentGradients :: [*] -> * where
8383+ RGNil :: RecurrentGradients '[]
8484+8585+ (://>) :: UpdateLayer x
8686+ => [Gradient x]
8787+ -> RecurrentGradients xs
8888+ -> RecurrentGradients (phantom x ': xs)
78897990-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
8091-- Parameterised on the layers of a Network.
8192data RecurrentInputs :: [*] -> * where
8293 RINil :: RecurrentInputs '[]
9494+8395 (:~~+>) :: UpdateLayer x
8484- => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)
9696+ => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)
9797+8598 (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x)
8699 => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs)
8787-infixr 5 :~~+>
8888-infixr 5 :~@+>
100100+101101+-- | All the information required to backpropogate
102102+-- through time safely.
103103+--
104104+-- We index on the time step length as well, to ensure
105105+-- that that all Tape lengths are the same.
106106+data RecurrentTapes :: [*] -> [Shape] -> * where
107107+ TRNil :: SingI i
108108+ => RecurrentTapes '[] '[i]
109109+110110+ (:\~>) :: [Tape x i h]
111111+ -> !(RecurrentTapes xs (h ': hs))
112112+ -> RecurrentTapes (FeedForward x ': xs) (i ': h ': hs)
113113+114114+115115+ (:\@>) :: [RecTape x i h]
116116+ -> !(RecurrentTapes xs (h ': hs))
117117+ -> RecurrentTapes (Recurrent x ': xs) (i ': h ': hs)
118118+119119+120120+runRecurrentNetwork :: forall shapes layers.
121121+ RecurrentNetwork layers shapes
122122+ -> RecurrentInputs layers
123123+ -> [S (Head shapes)]
124124+ -> (RecurrentTapes layers shapes, RecurrentInputs layers, [S (Last shapes)])
125125+runRecurrentNetwork =
126126+ go
127127+ where
128128+ go :: forall js sublayers. (Last js ~ Last shapes)
129129+ => RecurrentNetwork sublayers js
130130+ -> RecurrentInputs sublayers
131131+ -> [S (Head js)]
132132+ -> (RecurrentTapes sublayers js, RecurrentInputs sublayers, [S (Last js)])
133133+ -- This is a simple non-recurrent layer, just map it forwards
134134+ go (layer :~~> n) (() :~~+> nIn) !xs
135135+ = let tys = runForwards layer <$> xs
136136+ feedForwardTapes = fst <$> tys
137137+ forwards = snd <$> tys
138138+ -- recursively run the rest of the network, and get the gradients from above.
139139+ (newFN, ig, answer) = go n nIn forwards
140140+ in (feedForwardTapes :\~> newFN, () :~~+> ig, answer)
141141+142142+ -- This is a recurrent layer, so we need to do a scan, first input to last, providing
143143+ -- the recurrent shape output to the next layer.
144144+ go (layer :~@> n) (recIn :~@+> nIn) !xs
145145+ = let (recOut, tys) = goR layer recIn xs
146146+ recurrentTapes = fst <$> tys
147147+ forwards = snd <$> tys
148148+149149+ (newFN, ig, answer) = go n nIn forwards
150150+ in (recurrentTapes :\@> newFN, recOut :~@+> ig, answer)
151151+152152+ -- Handle the output layer, bouncing the derivatives back down.
153153+ -- We may not have a target for each example, so when we don't use 0 gradient.
154154+ go RNil RINil !x
155155+ = (TRNil, RINil, x)
156156+157157+ -- Helper function for recurrent layers
158158+ -- Scans over the recurrent direction of the graph.
159159+ goR !layer !recShape (x:xs) =
160160+ let (tape, lerec, lepush) = runRecurrentForwards layer recShape x
161161+ (rems, push) = goR layer lerec xs
162162+ in (rems, (tape, lepush) : push)
163163+ goR _ rin [] = (rin, [])
164164+165165+runRecurrentGradient :: forall layers shapes.
166166+ RecurrentNetwork layers shapes
167167+ -> RecurrentTapes layers shapes
168168+ -> RecurrentInputs layers
169169+ -> [S (Last shapes)]
170170+ -> (RecurrentGradients layers, RecurrentInputs layers, [S (Head shapes)])
171171+runRecurrentGradient net tapes r o =
172172+ go net tapes r
173173+ where
174174+ -- We have to be careful regarding the direction of the lists
175175+ -- Inputs come in forwards, but our return value is backwards
176176+ -- through time.
177177+ go :: forall js ss. (Last js ~ Last shapes)
178178+ => RecurrentNetwork ss js
179179+ -> RecurrentTapes ss js
180180+ -> RecurrentInputs ss
181181+ -> (RecurrentGradients ss, RecurrentInputs ss, [S (Head js)])
182182+ -- This is a simple non-recurrent layer
183183+ -- Run the rest of the network, then fmap the tapes and gradients
184184+ go (layer :~~> n) (feedForwardTapes :\~> nTapes) (() :~~+> nRecs) =
185185+ let (gradients, rins, feed) = go n nTapes nRecs
186186+ backs = uncurry (runBackwards layer) <$> zip (reverse feedForwardTapes) feed
187187+ in ((fst <$> backs) ://> gradients, () :~~+> rins, snd <$> backs)
188188+189189+ -- This is a recurrent layer
190190+ -- Run the rest of the network, scan over the tapes in reverse
191191+ go (layer :~@> n) (recurrentTapes :\@> nTapes) (recGrad :~@+> nRecs) =
192192+ let (gradients, rins, feed) = go n nTapes nRecs
193193+ backExamples = zip (reverse recurrentTapes) feed
194194+ (rg, backs) = goX layer recGrad backExamples
195195+ in ((fst <$> backs) ://> gradients, rg :~@+> rins, snd <$> backs)
196196+197197+ -- End of the road, so we reflect the given gradients backwards.
198198+ -- Crucially, we reverse the list, so it's backwards in time as
199199+ -- well.
200200+ go RNil TRNil RINil
201201+ = (RGNil, RINil, reverse o)
202202+203203+ -- Helper function for recurrent layers
204204+ -- Scans over the recurrent direction of the graph.
205205+ goX :: RecurrentLayer x i o => x -> S (RecurrentShape x) -> [(RecTape x i o, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)])
206206+ goX layer !lastback ((recTape, backgrad):xs) =
207207+ let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recTape lastback backgrad
208208+ (pushedback, ll) = goX layer recgrad xs
209209+ in (pushedback, (layergrad, ingrad) : ll)
210210+ goX _ !lastback [] = (lastback, [])
211211+212212+-- | Apply a batch of gradients to the network
213213+-- Uses runUpdates which can be specialised for
214214+-- a layer.
215215+applyRecurrentUpdate :: LearningParameters
216216+ -> RecurrentNetwork layers shapes
217217+ -> RecurrentGradients layers
218218+ -> RecurrentNetwork layers shapes
219219+applyRecurrentUpdate rate (layer :~~> rest) (gradient ://> grest)
220220+ = runUpdates rate layer gradient :~~> applyRecurrentUpdate rate rest grest
221221+222222+applyRecurrentUpdate rate (layer :~@> rest) (gradient ://> grest)
223223+ = runUpdates rate layer gradient :~@> applyRecurrentUpdate rate rest grest
224224+225225+applyRecurrentUpdate _ RNil RGNil
226226+ = RNil
227227+228228+229229+instance Show (RecurrentNetwork '[] '[i]) where
230230+ show RNil = "NNil"
231231+instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (FeedForward x ': xs) (i ': rs)) where
232232+ show (x :~~> xs) = show x ++ "\n~~>\n" ++ show xs
233233+instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recurrent x ': xs) (i ': rs)) where
234234+ show (x :~@> xs) = show x ++ "\n~~>\n" ++ show xs
235235+8923690237-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
91238-- recurrent network and a set of random inputs for it is with the randomRecurrent.
···144291 Just i <- fromStorable . V.fromList <$> getListOf get
145292 rest <- get
146293 return ( i :~@+> rest)
294294+295295+296296+-- Num instance for `RecurrentInputs layers`
297297+-- Not sure if this is really needed, as I only need a `fromInteger 0` at
298298+-- the moment for training, to create a null gradient on the recurrent
299299+-- edge.
300300+--
301301+-- It does raise an interesting question though? Is a 0 gradient actually
302302+-- the best?
303303+--
304304+-- I could imaging that weakly push back towards the optimum input could
305305+-- help make a more stable generator.
306306+instance (Num (RecurrentInputs '[])) where
307307+ (+) _ _ = RINil
308308+ (-) _ _ = RINil
309309+ (*) _ _ = RINil
310310+ abs _ = RINil
311311+ signum _ = RINil
312312+ fromInteger _ = RINil
313313+314314+instance (UpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (FeedForward x ': ys))) where
315315+ (+) (() :~~+> x) (() :~~+> y) = () :~~+> (x + y)
316316+ (-) (() :~~+> x) (() :~~+> y) = () :~~+> (x - y)
317317+ (*) (() :~~+> x) (() :~~+> y) = () :~~+> (x * y)
318318+ abs (() :~~+> x) = () :~~+> abs x
319319+ signum (() :~~+> x) = () :~~+> signum x
320320+ fromInteger x = () :~~+> fromInteger x
321321+322322+instance (SingI (RecurrentShape x), RecurrentUpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (Recurrent x ': ys))) where
323323+ (+) (x :~@+> x') (y :~@+> y') = (x + y) :~@+> (x' + y')
324324+ (-) (x :~@+> x') (y :~@+> y') = (x - y) :~@+> (x' - y')
325325+ (*) (x :~@+> x') (y :~@+> y') = (x * y) :~@+> (x' * y')
326326+ abs (x :~@+> x') = abs x :~@+> abs x'
327327+ signum (x :~@+> x') = signum x :~@+> signum x'
328328+ fromInteger x = fromInteger x :~@+> fromInteger x
+40-90
src/Grenade/Recurrent/Core/Runner.hs
···1616module Grenade.Recurrent.Core.Runner (
1717 trainRecurrent
1818 , runRecurrent
1919+ , backPropagateRecurrent
1920 ) where
20212122import Data.Singletons.Prelude
2223import Grenade.Core
23242525+import Grenade.Recurrent.Core.Layer
2426import Grenade.Recurrent.Core.Network
25272628-- | Drive and network and collect its back propogated gradients.
2727---
2828--- TODO: split this nicely into backpropagate and update.
2929---
3030--- QUESTION: Should we return a list of gradients or the sum of
3131--- the gradients? It's different taking into account
3232--- momentum and L2.
3333-trainRecurrent :: forall shapes layers. SingI (Last shapes)
3434- => LearningParameters
3535- -> RecurrentNetwork layers shapes
3636- -> RecurrentInputs layers
3737- -> [(S (Head shapes), Maybe (S (Last shapes)))]
3838- -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
3939-trainRecurrent rate network recinputs examples =
4040- updateBack $ go inputs network recinputs
4141- where
4242- inputs = fst <$> examples
4343- targets = snd <$> examples
4444- updateBack (a,recgrad,_) = (a,updateRecInputs rate recinputs recgrad)
2929+backPropagateRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers))
3030+ => RecurrentNetwork layers shapes
3131+ -> RecurrentInputs layers
3232+ -> [(S (Head shapes), Maybe (S (Last shapes)))]
3333+ -> (RecurrentGradients layers, RecurrentInputs layers)
3434+backPropagateRecurrent network recinputs examples =
3535+ let (tapes, _, guesses) = runRecurrentNetwork network recinputs inputs
45364646- go :: forall js sublayers. (Last js ~ Last shapes)
4747- => [S (Head js)] -- ^ input vector
4848- -> RecurrentNetwork sublayers js -- ^ network to train
4949- -> RecurrentInputs sublayers
5050- -> (RecurrentNetwork sublayers js, RecurrentInputs sublayers, [S (Head js)])
3737+ backPropagations = zipWith makeError guesses targets
51385252- -- This is a simple non-recurrent layer, just map it forwards
5353- -- Note we're doing training here, we could just return a list of gradients
5454- -- (and probably will in future).
5555- go !xs (layer :~~> n) (() :~~+> nIn)
5656- = let tys = runForwards layer <$> xs
5757- tapes = fst <$> tys
5858- ys = snd <$> tys
5959- -- recursively run the rest of the network, and get the gradients from above.
6060- (newFN, ig, grads) = go ys n nIn
6161- -- calculate the gradient for this layer to pass down,
6262- back = uncurry (runBackwards layer) <$> zip (reverse tapes) grads
6363- -- the new trained layer.
6464- newlayer = runUpdates rate layer (fst <$> back)
3939+ (gradients, input', _) = runRecurrentGradient network tapes 0 backPropagations
65406666- in (newlayer :~~> newFN, () :~~+> ig, snd <$> back)
4141+ in (gradients, input')
67426868- -- This is a recurrent layer, so we need to do a scan, first input to last, providing
6969- -- the recurrent shape output to the next layer.
7070- go !xs (layer :~@> n) (g :~@+> nIn)
7171- = let tys = scanlFrom layer g xs
7272- tapes = fst <$> tys
7373- ys = snd <$> tys
4343+ where
74447575- (newFN, ig, grads) = go ys n nIn
4545+ inputs = fst <$> examples
4646+ targets = snd <$> examples
76477777- backExamples = zip (reverse tapes) grads
4848+ makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
4949+ makeError _ Nothing = 0
5050+ makeError y (Just t) = y - t
78517979- (rg, back) = myscanbackward layer backExamples
8080- -- the new trained layer.
8181- newlayer = runUpdates rate layer (fst <$> back)
8282- in (newlayer :~@> newFN, rg :~@+> ig, snd <$> back)
5252+5353+trainRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers))
5454+ => LearningParameters
5555+ -> RecurrentNetwork layers shapes
5656+ -> RecurrentInputs layers
5757+ -> [(S (Head shapes), Maybe (S (Last shapes)))]
5858+ -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
5959+trainRecurrent rate network recinputs examples =
6060+ let (gradients, recinputs') = backPropagateRecurrent network recinputs examples
83618484- -- Handle the output layer, bouncing the derivatives back down.
8585- -- We may not have a target for each example, so when we don't use 0 gradient.
8686- go !xs RNil RINil
8787- = (RNil, RINil, reverse (zipWith makeError xs targets))
8888- where
8989- makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
9090- makeError _ Nothing = 0
9191- makeError y (Just t) = y - t
6262+ newInputs = updateRecInputs rate recinputs recinputs'
92639393- updateRecInputs :: forall sublayers.
9494- LearningParameters
9595- -> RecurrentInputs sublayers
9696- -> RecurrentInputs sublayers
9797- -> RecurrentInputs sublayers
6464+ newNetwork = applyRecurrentUpdate rate network gradients
98659999- updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
100100- = () :~~+> updateRecInputs l xs ys
6666+ in (newNetwork, newInputs)
10167102102- updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
103103- = (realToFrac (learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
6868+updateRecInputs :: LearningParameters
6969+ -> RecurrentInputs sublayers
7070+ -> RecurrentInputs sublayers
7171+ -> RecurrentInputs sublayers
10472105105- updateRecInputs _ RINil RINil
106106- = RINil
7373+updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
7474+ = () :~~+> updateRecInputs l xs ys
10775108108-scanlFrom :: forall x i o. RecurrentLayer x i o
109109- => x -- ^ the layer
110110- -> S (RecurrentShape x) -- ^ place to start
111111- -> [S i] -- ^ list of inputs to scan through
112112- -> [(RecTape x i o, S o)] -- ^ list of scan inputs and outputs
113113-scanlFrom !layer !recShape (x:xs) =
114114- let (tape, lerec, lepush) = runRecurrentForwards layer recShape x
115115- in (tape, lepush) : scanlFrom layer lerec xs
116116-scanlFrom _ _ [] = []
7676+updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
7777+ = (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
11778118118-myscanbackward :: forall x i o. RecurrentLayer x i o
119119- => x -- ^ the layer
120120- -> [(RecTape x i o, S o)] -- ^ the list of inputs and output to scan over
121121- -> (S (RecurrentShape x), [(Gradient x, S i)]) -- ^ list of gradients to fold and inputs to backprop
122122-myscanbackward layer =
123123- goX 0
124124- where
125125- goX :: S (RecurrentShape x) -> [(RecTape x i o, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)])
126126- goX !lastback ((recTape, backgrad):xs) =
127127- let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recTape lastback backgrad
128128- (pushedback, ll) = goX recgrad xs
129129- in (pushedback, (layergrad, ingrad) : ll)
130130- goX !lastback [] = (lastback, [])
7979+updateRecInputs _ RINil RINil
8080+ = RINil
1318113282-- | Just forwards propagation with no training.
13383runRecurrent :: RecurrentNetwork layers shapes
+4-1
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
···11+{-# LANGUAGE CPP #-}
12{-# LANGUAGE DataKinds #-}
23{-# LANGUAGE GADTs #-}
34{-# LANGUAGE RecordWildCards #-}
···89{-# LANGUAGE UndecidableInstances #-}
9101011-- GHC 7.10 doesn't see recurrent run functions as total.
1212+#if __GLASGOW_HASKELL__ < 800
1113{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1414+#endif
1215module Grenade.Recurrent.Layers.BasicRecurrent (
1316 BasicRecurrent (..)
1417 , randomBasicRecurrent
···2528import GHC.TypeLits
26292730import Grenade.Core
2828-import Grenade.Recurrent.Core.Network
3131+import Grenade.Recurrent.Core
29323033data BasicRecurrent :: Nat -- Input layer size
3134 -> Nat -- Output layer size
+5-2
src/Grenade/Recurrent/Layers/LSTM.hs
···11{-# LANGUAGE BangPatterns #-}
22+{-# LANGUAGE CPP #-}
23{-# LANGUAGE DataKinds #-}
34{-# LANGUAGE GADTs #-}
45{-# LANGUAGE RankNTypes #-}
···1112{-# LANGUAGE ScopedTypeVariables #-}
12131314-- GHC 7.10 doesn't see recurrent run functions as total.
1515+#if __GLASGOW_HASKELL__ < 800
1416{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1717+#endif
1818+1519module Grenade.Recurrent.Layers.LSTM (
1620 LSTM (..)
1721 , LSTMWeights (..)
···2933import Numeric.LinearAlgebra.Static
30343135import Grenade.Core
3232-3636+import Grenade.Recurrent.Core
3337import Grenade.Layers.Internal.Update
34383535-import Grenade.Recurrent.Core.Network
36393740-- | Long Short Term Memory Recurrent unit
3841--
+5
test/Test/Grenade/Layers/Nonlinear.hs
···11{-# LANGUAGE BangPatterns #-}
22+{-# LANGUAGE CPP #-}
23{-# LANGUAGE TemplateHaskell #-}
34{-# LANGUAGE DataKinds #-}
45{-# LANGUAGE KindSignatures #-}
···910module Test.Grenade.Layers.Nonlinear where
10111112import Data.Singletons
1313+1414+#if __GLASGOW_HASKELL__ < 800
1515+import Data.Proxy
1616+#endif
12171318import Grenade
1419import GHC.TypeLits
+6
test/Test/Grenade/Layers/PadCrop.hs
···11{-# LANGUAGE BangPatterns #-}
22+{-# LANGUAGE CPP #-}
23{-# LANGUAGE TemplateHaskell #-}
34{-# LANGUAGE DataKinds #-}
45{-# LANGUAGE KindSignatures #-}
56{-# LANGUAGE GADTs #-}
67{-# LANGUAGE ScopedTypeVariables #-}
78{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
99+1010+#if __GLASGOW_HASKELL__ < 800
1111+{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1212+#endif
1313+814module Test.Grenade.Layers.PadCrop where
9151016import Grenade
+10-4
test/Test/Jack/TypeLits.hs
···2222 Just n <- someNatVal <$> choose (1, 10)
2323 return n
24242525-genShape :: Jack (SomeSing Shape)
2525+#if __GLASGOW_HASKELL__ < 800
2626+type Shape' = ('KProxy :: KProxy Shape)
2727+#else
2828+type Shape' = Shape
2929+#endif
3030+3131+genShape :: Jack (SomeSing Shape')
2632genShape
2733 = oneOf [
2834 genD1
···3036 , genD3
3137 ]
32383333-genD1 :: Jack (SomeSing Shape)
3939+genD1 :: Jack (SomeSing Shape')
3440genD1 = do
3541 n <- genNat
3642 return $ case n of
3743 SomeNat (_ :: Proxy x) -> SomeSing (sing :: Sing ('D1 x))
38443939-genD2 :: Jack (SomeSing Shape)
4545+genD2 :: Jack (SomeSing Shape')
4046genD2 = do
4147 n <- genNat
4248 m <- genNat
4349 return $ case (n, m) of
4450 (SomeNat (_ :: Proxy x), SomeNat (_ :: Proxy y)) -> SomeSing (sing :: Sing ('D2 x y))
45514646-genD3 :: Jack (SomeSing Shape)
5252+genD3 :: Jack (SomeSing Shape')
4753genD3 = do
4854 n <- genNat
4955 m <- genNat