💣 Machine learning which might blow up in your face 💣

Merge pull request #23 from HuwCampbell/topic/recurrent-niceties

Topic/recurrent niceties

authored by

Huw Campbell and committed by
GitHub
7632c5eb 7e47a6ad

+327 -131
+2
grenade.cabal
··· 71 71 72 72 Grenade.Recurrent 73 73 74 + Grenade.Recurrent.Core 75 + Grenade.Recurrent.Core.Layer 74 76 Grenade.Recurrent.Core.Network 75 77 Grenade.Recurrent.Core.Runner 76 78
-1
main/recurrent.hs
··· 23 23 import Grenade 24 24 import Grenade.Recurrent 25 25 26 - {-# OPTIONS_GHC -fno-redundant-imports #-} 27 26 -- The defininition for our simple recurrent network. 28 27 -- This file just trains a network to generate a repeating sequence 29 28 -- of 0 0 1.
-1
src/Grenade/Core/Network.hs
··· 157 157 applyUpdate _ NNil GNil 158 158 = NNil 159 159 160 - 161 160 -- | A network can easily be created by hand with (:~>), but an easy way to 162 161 -- initialise a random network is with the randomNetwork. 163 162 class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
+2
src/Grenade/Core/Runner.hs
··· 42 42 (grads, _) = runGradient network tapes (output - target) 43 43 in grads 44 44 45 + 45 46 -- | Update a network with new weights after training with an instance. 46 47 train :: SingI (Last shapes) 47 48 => LearningParameters ··· 52 53 train rate network input output = 53 54 let grads = backPropagate network input output 54 55 in applyUpdate rate network grads 56 + 55 57 56 58 -- | Run the network with input and return the given output. 57 59 runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
+1
src/Grenade/Recurrent.hs
··· 2 2 module X 3 3 ) where 4 4 5 + import Grenade.Recurrent.Core.Layer as X 5 6 import Grenade.Recurrent.Core.Network as X 6 7 import Grenade.Recurrent.Core.Runner as X 7 8 import Grenade.Recurrent.Layers.BasicRecurrent as X
+6
src/Grenade/Recurrent/Core.hs
··· 1 + module Grenade.Recurrent.Core ( 2 + module X 3 + ) where 4 + 5 + import Grenade.Recurrent.Core.Layer as X 6 + import Grenade.Recurrent.Core.Network as X
+32
src/Grenade/Recurrent/Core/Layer.hs
··· 1 + {-# LANGUAGE DataKinds #-} 2 + {-# LANGUAGE TypeFamilies #-} 3 + {-# LANGUAGE MultiParamTypeClasses #-} 4 + {-# LANGUAGE FlexibleContexts #-} 5 + {-# LANGUAGE FlexibleInstances #-} 6 + module Grenade.Recurrent.Core.Layer ( 7 + RecurrentLayer (..) 8 + , RecurrentUpdateLayer (..) 9 + ) where 10 + 11 + import Data.Singletons ( SingI ) 12 + 13 + import Grenade.Core 14 + 15 + -- | Class for a recurrent layer. 16 + -- It's quite similar to a normal layer but for the input and output 17 + -- of an extra recurrent data shape. 18 + class UpdateLayer x => RecurrentUpdateLayer x where 19 + -- | Shape of data that is passed between each subsequent run of the layer 20 + type RecurrentShape x :: Shape 21 + 22 + class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where 23 + -- | Wengert Tape 24 + type RecTape x i o :: * 25 + -- | Used in training and scoring. Take the input from the previous 26 + -- layer, and give the output from this layer. 27 + runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (RecTape x i o, S (RecurrentShape x), S o) 28 + -- | Back propagate a step. Takes the current layer, the input that the 29 + -- layer gave from the input and the back propagated derivatives from 30 + -- the layer above. 31 + -- Returns the gradient layer and the derivatives to push back further. 32 + runRecurrentBackwards :: x -> RecTape x i o -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i)
+214 -32
src/Grenade/Recurrent/Core/Network.hs
··· 1 + {-# LANGUAGE CPP #-} 1 2 {-# LANGUAGE DataKinds #-} 2 3 {-# LANGUAGE GADTs #-} 3 4 {-# LANGUAGE TypeOperators #-} ··· 6 7 {-# LANGUAGE FlexibleContexts #-} 7 8 {-# LANGUAGE FlexibleInstances #-} 8 9 {-# LANGUAGE EmptyDataDecls #-} 10 + {-# LANGUAGE RankNTypes #-} 11 + {-# LANGUAGE BangPatterns #-} 12 + {-# LANGUAGE ScopedTypeVariables #-} 13 + 14 + #if __GLASGOW_HASKELL__ < 800 15 + {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 16 + #endif 17 + 9 18 module Grenade.Recurrent.Core.Network ( 10 19 Recurrent 11 20 , FeedForward 12 - , RecurrentLayer (..) 13 - , RecurrentUpdateLayer (..) 21 + 14 22 , RecurrentNetwork (..) 15 23 , RecurrentInputs (..) 16 - , CreatableRecurrent (..) 24 + , RecurrentTapes (..) 25 + , RecurrentGradients (..) 26 + 27 + , randomRecurrent 28 + , runRecurrentNetwork 29 + , runRecurrentGradient 30 + , applyRecurrentUpdate 17 31 ) where 18 32 19 33 20 34 import Control.Monad.Random ( MonadRandom ) 21 35 import Data.Singletons ( SingI ) 36 + import Data.Singletons.Prelude ( Head, Last ) 22 37 import Data.Serialize 23 38 import qualified Data.Vector.Storable as V 24 39 25 40 import Grenade.Core 41 + import Grenade.Recurrent.Core.Layer 26 42 27 43 import qualified Numeric.LinearAlgebra as LA 28 44 import qualified Numeric.LinearAlgebra.Static as LAS 29 45 30 - 31 46 -- | Witness type to say indicate we're building up with a normal feed 32 47 -- forward layer. 33 48 data FeedForward :: * -> * 34 49 -- | Witness type to say indicate we're building up with a recurrent layer. 35 50 data Recurrent :: * -> * 36 51 37 - -- | Class for a recurrent layer. 38 - -- It's quite similar to a normal layer but for the input and output 39 - -- of an extra recurrent data shape. 40 - class UpdateLayer x => RecurrentUpdateLayer x where 41 - -- | Shape of data that is passed between each subsequent run of the layer 42 - type RecurrentShape x :: Shape 43 - 44 - class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where 45 - -- | Wengert Tape 46 - type RecTape x i o :: * 47 - -- | Used in training and scoring. Take the input from the previous 48 - -- layer, and give the output from this layer. 49 - runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (RecTape x i o, S (RecurrentShape x), S o) 50 - -- | Back propagate a step. Takes the current layer, the input that the 51 - -- layer gave from the input and the back propagated derivatives from 52 - -- the layer above. 53 - -- Returns the gradient layer and the derivatives to push back further. 54 - runRecurrentBackwards :: x -> RecTape x i o -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i) 55 - 52 + -- | Type of a recurrent neural network. 53 + -- 54 + -- The [*] type specifies the types of the layers. 55 + -- 56 + -- The [Shape] type specifies the shapes of data passed between the layers. 57 + -- 58 + -- The definition is similar to a Network, but every layer in the 59 + -- type is tagged by whether it's a FeedForward Layer of a Recurrent layer. 60 + -- 61 + -- Often, to make the definitions more concise, one will use a type alias 62 + -- for these empty data types. 56 63 data RecurrentNetwork :: [*] -> [Shape] -> * where 57 64 RNil :: SingI i 58 65 => RecurrentNetwork '[] '[i] ··· 69 76 infixr 5 :~~> 70 77 infixr 5 :~@> 71 78 72 - instance Show (RecurrentNetwork '[] '[i]) where 73 - show RNil = "NNil" 74 - instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (FeedForward x ': xs) (i ': rs)) where 75 - show (x :~~> xs) = show x ++ "\n~~>\n" ++ show xs 76 - instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recurrent x ': xs) (i ': rs)) where 77 - show (x :~@> xs) = show x ++ "\n~~>\n" ++ show xs 79 + -- | Gradient of a network. 80 + -- 81 + -- Parameterised on the layers of the network. 82 + data RecurrentGradients :: [*] -> * where 83 + RGNil :: RecurrentGradients '[] 84 + 85 + (://>) :: UpdateLayer x 86 + => [Gradient x] 87 + -> RecurrentGradients xs 88 + -> RecurrentGradients (phantom x ': xs) 78 89 79 90 -- | Recurrent inputs (sideways shapes on an imaginary unrolled graph) 80 91 -- Parameterised on the layers of a Network. 81 92 data RecurrentInputs :: [*] -> * where 82 93 RINil :: RecurrentInputs '[] 94 + 83 95 (:~~+>) :: UpdateLayer x 84 - => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs) 96 + => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs) 97 + 85 98 (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x) 86 99 => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs) 87 - infixr 5 :~~+> 88 - infixr 5 :~@+> 100 + 101 + -- | All the information required to backpropogate 102 + -- through time safely. 103 + -- 104 + -- We index on the time step length as well, to ensure 105 + -- that that all Tape lengths are the same. 106 + data RecurrentTapes :: [*] -> [Shape] -> * where 107 + TRNil :: SingI i 108 + => RecurrentTapes '[] '[i] 109 + 110 + (:\~>) :: [Tape x i h] 111 + -> !(RecurrentTapes xs (h ': hs)) 112 + -> RecurrentTapes (FeedForward x ': xs) (i ': h ': hs) 113 + 114 + 115 + (:\@>) :: [RecTape x i h] 116 + -> !(RecurrentTapes xs (h ': hs)) 117 + -> RecurrentTapes (Recurrent x ': xs) (i ': h ': hs) 118 + 119 + 120 + runRecurrentNetwork :: forall shapes layers. 121 + RecurrentNetwork layers shapes 122 + -> RecurrentInputs layers 123 + -> [S (Head shapes)] 124 + -> (RecurrentTapes layers shapes, RecurrentInputs layers, [S (Last shapes)]) 125 + runRecurrentNetwork = 126 + go 127 + where 128 + go :: forall js sublayers. (Last js ~ Last shapes) 129 + => RecurrentNetwork sublayers js 130 + -> RecurrentInputs sublayers 131 + -> [S (Head js)] 132 + -> (RecurrentTapes sublayers js, RecurrentInputs sublayers, [S (Last js)]) 133 + -- This is a simple non-recurrent layer, just map it forwards 134 + go (layer :~~> n) (() :~~+> nIn) !xs 135 + = let tys = runForwards layer <$> xs 136 + feedForwardTapes = fst <$> tys 137 + forwards = snd <$> tys 138 + -- recursively run the rest of the network, and get the gradients from above. 139 + (newFN, ig, answer) = go n nIn forwards 140 + in (feedForwardTapes :\~> newFN, () :~~+> ig, answer) 141 + 142 + -- This is a recurrent layer, so we need to do a scan, first input to last, providing 143 + -- the recurrent shape output to the next layer. 144 + go (layer :~@> n) (recIn :~@+> nIn) !xs 145 + = let (recOut, tys) = goR layer recIn xs 146 + recurrentTapes = fst <$> tys 147 + forwards = snd <$> tys 148 + 149 + (newFN, ig, answer) = go n nIn forwards 150 + in (recurrentTapes :\@> newFN, recOut :~@+> ig, answer) 151 + 152 + -- Handle the output layer, bouncing the derivatives back down. 153 + -- We may not have a target for each example, so when we don't use 0 gradient. 154 + go RNil RINil !x 155 + = (TRNil, RINil, x) 156 + 157 + -- Helper function for recurrent layers 158 + -- Scans over the recurrent direction of the graph. 159 + goR !layer !recShape (x:xs) = 160 + let (tape, lerec, lepush) = runRecurrentForwards layer recShape x 161 + (rems, push) = goR layer lerec xs 162 + in (rems, (tape, lepush) : push) 163 + goR _ rin [] = (rin, []) 164 + 165 + runRecurrentGradient :: forall layers shapes. 166 + RecurrentNetwork layers shapes 167 + -> RecurrentTapes layers shapes 168 + -> RecurrentInputs layers 169 + -> [S (Last shapes)] 170 + -> (RecurrentGradients layers, RecurrentInputs layers, [S (Head shapes)]) 171 + runRecurrentGradient net tapes r o = 172 + go net tapes r 173 + where 174 + -- We have to be careful regarding the direction of the lists 175 + -- Inputs come in forwards, but our return value is backwards 176 + -- through time. 177 + go :: forall js ss. (Last js ~ Last shapes) 178 + => RecurrentNetwork ss js 179 + -> RecurrentTapes ss js 180 + -> RecurrentInputs ss 181 + -> (RecurrentGradients ss, RecurrentInputs ss, [S (Head js)]) 182 + -- This is a simple non-recurrent layer 183 + -- Run the rest of the network, then fmap the tapes and gradients 184 + go (layer :~~> n) (feedForwardTapes :\~> nTapes) (() :~~+> nRecs) = 185 + let (gradients, rins, feed) = go n nTapes nRecs 186 + backs = uncurry (runBackwards layer) <$> zip (reverse feedForwardTapes) feed 187 + in ((fst <$> backs) ://> gradients, () :~~+> rins, snd <$> backs) 188 + 189 + -- This is a recurrent layer 190 + -- Run the rest of the network, scan over the tapes in reverse 191 + go (layer :~@> n) (recurrentTapes :\@> nTapes) (recGrad :~@+> nRecs) = 192 + let (gradients, rins, feed) = go n nTapes nRecs 193 + backExamples = zip (reverse recurrentTapes) feed 194 + (rg, backs) = goX layer recGrad backExamples 195 + in ((fst <$> backs) ://> gradients, rg :~@+> rins, snd <$> backs) 196 + 197 + -- End of the road, so we reflect the given gradients backwards. 198 + -- Crucially, we reverse the list, so it's backwards in time as 199 + -- well. 200 + go RNil TRNil RINil 201 + = (RGNil, RINil, reverse o) 202 + 203 + -- Helper function for recurrent layers 204 + -- Scans over the recurrent direction of the graph. 205 + goX :: RecurrentLayer x i o => x -> S (RecurrentShape x) -> [(RecTape x i o, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)]) 206 + goX layer !lastback ((recTape, backgrad):xs) = 207 + let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recTape lastback backgrad 208 + (pushedback, ll) = goX layer recgrad xs 209 + in (pushedback, (layergrad, ingrad) : ll) 210 + goX _ !lastback [] = (lastback, []) 211 + 212 + -- | Apply a batch of gradients to the network 213 + -- Uses runUpdates which can be specialised for 214 + -- a layer. 215 + applyRecurrentUpdate :: LearningParameters 216 + -> RecurrentNetwork layers shapes 217 + -> RecurrentGradients layers 218 + -> RecurrentNetwork layers shapes 219 + applyRecurrentUpdate rate (layer :~~> rest) (gradient ://> grest) 220 + = runUpdates rate layer gradient :~~> applyRecurrentUpdate rate rest grest 221 + 222 + applyRecurrentUpdate rate (layer :~@> rest) (gradient ://> grest) 223 + = runUpdates rate layer gradient :~@> applyRecurrentUpdate rate rest grest 224 + 225 + applyRecurrentUpdate _ RNil RGNil 226 + = RNil 227 + 228 + 229 + instance Show (RecurrentNetwork '[] '[i]) where 230 + show RNil = "NNil" 231 + instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (FeedForward x ': xs) (i ': rs)) where 232 + show (x :~~> xs) = show x ++ "\n~~>\n" ++ show xs 233 + instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recurrent x ': xs) (i ': rs)) where 234 + show (x :~@> xs) = show x ++ "\n~~>\n" ++ show xs 235 + 89 236 90 237 -- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random 91 238 -- recurrent network and a set of random inputs for it is with the randomRecurrent. ··· 144 291 Just i <- fromStorable . V.fromList <$> getListOf get 145 292 rest <- get 146 293 return ( i :~@+> rest) 294 + 295 + 296 + -- Num instance for `RecurrentInputs layers` 297 + -- Not sure if this is really needed, as I only need a `fromInteger 0` at 298 + -- the moment for training, to create a null gradient on the recurrent 299 + -- edge. 300 + -- 301 + -- It does raise an interesting question though? Is a 0 gradient actually 302 + -- the best? 303 + -- 304 + -- I could imaging that weakly push back towards the optimum input could 305 + -- help make a more stable generator. 306 + instance (Num (RecurrentInputs '[])) where 307 + (+) _ _ = RINil 308 + (-) _ _ = RINil 309 + (*) _ _ = RINil 310 + abs _ = RINil 311 + signum _ = RINil 312 + fromInteger _ = RINil 313 + 314 + instance (UpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (FeedForward x ': ys))) where 315 + (+) (() :~~+> x) (() :~~+> y) = () :~~+> (x + y) 316 + (-) (() :~~+> x) (() :~~+> y) = () :~~+> (x - y) 317 + (*) (() :~~+> x) (() :~~+> y) = () :~~+> (x * y) 318 + abs (() :~~+> x) = () :~~+> abs x 319 + signum (() :~~+> x) = () :~~+> signum x 320 + fromInteger x = () :~~+> fromInteger x 321 + 322 + instance (SingI (RecurrentShape x), RecurrentUpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (Recurrent x ': ys))) where 323 + (+) (x :~@+> x') (y :~@+> y') = (x + y) :~@+> (x' + y') 324 + (-) (x :~@+> x') (y :~@+> y') = (x - y) :~@+> (x' - y') 325 + (*) (x :~@+> x') (y :~@+> y') = (x * y) :~@+> (x' * y') 326 + abs (x :~@+> x') = abs x :~@+> abs x' 327 + signum (x :~@+> x') = signum x :~@+> signum x' 328 + fromInteger x = fromInteger x :~@+> fromInteger x
+40 -90
src/Grenade/Recurrent/Core/Runner.hs
··· 16 16 module Grenade.Recurrent.Core.Runner ( 17 17 trainRecurrent 18 18 , runRecurrent 19 + , backPropagateRecurrent 19 20 ) where 20 21 21 22 import Data.Singletons.Prelude 22 23 import Grenade.Core 23 24 25 + import Grenade.Recurrent.Core.Layer 24 26 import Grenade.Recurrent.Core.Network 25 27 26 28 -- | Drive and network and collect its back propogated gradients. 27 - -- 28 - -- TODO: split this nicely into backpropagate and update. 29 - -- 30 - -- QUESTION: Should we return a list of gradients or the sum of 31 - -- the gradients? It's different taking into account 32 - -- momentum and L2. 33 - trainRecurrent :: forall shapes layers. SingI (Last shapes) 34 - => LearningParameters 35 - -> RecurrentNetwork layers shapes 36 - -> RecurrentInputs layers 37 - -> [(S (Head shapes), Maybe (S (Last shapes)))] 38 - -> (RecurrentNetwork layers shapes, RecurrentInputs layers) 39 - trainRecurrent rate network recinputs examples = 40 - updateBack $ go inputs network recinputs 41 - where 42 - inputs = fst <$> examples 43 - targets = snd <$> examples 44 - updateBack (a,recgrad,_) = (a,updateRecInputs rate recinputs recgrad) 29 + backPropagateRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers)) 30 + => RecurrentNetwork layers shapes 31 + -> RecurrentInputs layers 32 + -> [(S (Head shapes), Maybe (S (Last shapes)))] 33 + -> (RecurrentGradients layers, RecurrentInputs layers) 34 + backPropagateRecurrent network recinputs examples = 35 + let (tapes, _, guesses) = runRecurrentNetwork network recinputs inputs 45 36 46 - go :: forall js sublayers. (Last js ~ Last shapes) 47 - => [S (Head js)] -- ^ input vector 48 - -> RecurrentNetwork sublayers js -- ^ network to train 49 - -> RecurrentInputs sublayers 50 - -> (RecurrentNetwork sublayers js, RecurrentInputs sublayers, [S (Head js)]) 37 + backPropagations = zipWith makeError guesses targets 51 38 52 - -- This is a simple non-recurrent layer, just map it forwards 53 - -- Note we're doing training here, we could just return a list of gradients 54 - -- (and probably will in future). 55 - go !xs (layer :~~> n) (() :~~+> nIn) 56 - = let tys = runForwards layer <$> xs 57 - tapes = fst <$> tys 58 - ys = snd <$> tys 59 - -- recursively run the rest of the network, and get the gradients from above. 60 - (newFN, ig, grads) = go ys n nIn 61 - -- calculate the gradient for this layer to pass down, 62 - back = uncurry (runBackwards layer) <$> zip (reverse tapes) grads 63 - -- the new trained layer. 64 - newlayer = runUpdates rate layer (fst <$> back) 39 + (gradients, input', _) = runRecurrentGradient network tapes 0 backPropagations 65 40 66 - in (newlayer :~~> newFN, () :~~+> ig, snd <$> back) 41 + in (gradients, input') 67 42 68 - -- This is a recurrent layer, so we need to do a scan, first input to last, providing 69 - -- the recurrent shape output to the next layer. 70 - go !xs (layer :~@> n) (g :~@+> nIn) 71 - = let tys = scanlFrom layer g xs 72 - tapes = fst <$> tys 73 - ys = snd <$> tys 43 + where 74 44 75 - (newFN, ig, grads) = go ys n nIn 45 + inputs = fst <$> examples 46 + targets = snd <$> examples 76 47 77 - backExamples = zip (reverse tapes) grads 48 + makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes) 49 + makeError _ Nothing = 0 50 + makeError y (Just t) = y - t 78 51 79 - (rg, back) = myscanbackward layer backExamples 80 - -- the new trained layer. 81 - newlayer = runUpdates rate layer (fst <$> back) 82 - in (newlayer :~@> newFN, rg :~@+> ig, snd <$> back) 52 + 53 + trainRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers)) 54 + => LearningParameters 55 + -> RecurrentNetwork layers shapes 56 + -> RecurrentInputs layers 57 + -> [(S (Head shapes), Maybe (S (Last shapes)))] 58 + -> (RecurrentNetwork layers shapes, RecurrentInputs layers) 59 + trainRecurrent rate network recinputs examples = 60 + let (gradients, recinputs') = backPropagateRecurrent network recinputs examples 83 61 84 - -- Handle the output layer, bouncing the derivatives back down. 85 - -- We may not have a target for each example, so when we don't use 0 gradient. 86 - go !xs RNil RINil 87 - = (RNil, RINil, reverse (zipWith makeError xs targets)) 88 - where 89 - makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes) 90 - makeError _ Nothing = 0 91 - makeError y (Just t) = y - t 62 + newInputs = updateRecInputs rate recinputs recinputs' 92 63 93 - updateRecInputs :: forall sublayers. 94 - LearningParameters 95 - -> RecurrentInputs sublayers 96 - -> RecurrentInputs sublayers 97 - -> RecurrentInputs sublayers 64 + newNetwork = applyRecurrentUpdate rate network gradients 98 65 99 - updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys) 100 - = () :~~+> updateRecInputs l xs ys 66 + in (newNetwork, newInputs) 101 67 102 - updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys) 103 - = (realToFrac (learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys 68 + updateRecInputs :: LearningParameters 69 + -> RecurrentInputs sublayers 70 + -> RecurrentInputs sublayers 71 + -> RecurrentInputs sublayers 104 72 105 - updateRecInputs _ RINil RINil 106 - = RINil 73 + updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys) 74 + = () :~~+> updateRecInputs l xs ys 107 75 108 - scanlFrom :: forall x i o. RecurrentLayer x i o 109 - => x -- ^ the layer 110 - -> S (RecurrentShape x) -- ^ place to start 111 - -> [S i] -- ^ list of inputs to scan through 112 - -> [(RecTape x i o, S o)] -- ^ list of scan inputs and outputs 113 - scanlFrom !layer !recShape (x:xs) = 114 - let (tape, lerec, lepush) = runRecurrentForwards layer recShape x 115 - in (tape, lepush) : scanlFrom layer lerec xs 116 - scanlFrom _ _ [] = [] 76 + updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys) 77 + = (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys 117 78 118 - myscanbackward :: forall x i o. RecurrentLayer x i o 119 - => x -- ^ the layer 120 - -> [(RecTape x i o, S o)] -- ^ the list of inputs and output to scan over 121 - -> (S (RecurrentShape x), [(Gradient x, S i)]) -- ^ list of gradients to fold and inputs to backprop 122 - myscanbackward layer = 123 - goX 0 124 - where 125 - goX :: S (RecurrentShape x) -> [(RecTape x i o, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)]) 126 - goX !lastback ((recTape, backgrad):xs) = 127 - let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recTape lastback backgrad 128 - (pushedback, ll) = goX recgrad xs 129 - in (pushedback, (layergrad, ingrad) : ll) 130 - goX !lastback [] = (lastback, []) 79 + updateRecInputs _ RINil RINil 80 + = RINil 131 81 132 82 -- | Just forwards propagation with no training. 133 83 runRecurrent :: RecurrentNetwork layers shapes
+4 -1
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
··· 1 + {-# LANGUAGE CPP #-} 1 2 {-# LANGUAGE DataKinds #-} 2 3 {-# LANGUAGE GADTs #-} 3 4 {-# LANGUAGE RecordWildCards #-} ··· 8 9 {-# LANGUAGE UndecidableInstances #-} 9 10 10 11 -- GHC 7.10 doesn't see recurrent run functions as total. 12 + #if __GLASGOW_HASKELL__ < 800 11 13 {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 14 + #endif 12 15 module Grenade.Recurrent.Layers.BasicRecurrent ( 13 16 BasicRecurrent (..) 14 17 , randomBasicRecurrent ··· 25 28 import GHC.TypeLits 26 29 27 30 import Grenade.Core 28 - import Grenade.Recurrent.Core.Network 31 + import Grenade.Recurrent.Core 29 32 30 33 data BasicRecurrent :: Nat -- Input layer size 31 34 -> Nat -- Output layer size
+5 -2
src/Grenade/Recurrent/Layers/LSTM.hs
··· 1 1 {-# LANGUAGE BangPatterns #-} 2 + {-# LANGUAGE CPP #-} 2 3 {-# LANGUAGE DataKinds #-} 3 4 {-# LANGUAGE GADTs #-} 4 5 {-# LANGUAGE RankNTypes #-} ··· 11 12 {-# LANGUAGE ScopedTypeVariables #-} 12 13 13 14 -- GHC 7.10 doesn't see recurrent run functions as total. 15 + #if __GLASGOW_HASKELL__ < 800 14 16 {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 17 + #endif 18 + 15 19 module Grenade.Recurrent.Layers.LSTM ( 16 20 LSTM (..) 17 21 , LSTMWeights (..) ··· 29 33 import Numeric.LinearAlgebra.Static 30 34 31 35 import Grenade.Core 32 - 36 + import Grenade.Recurrent.Core 33 37 import Grenade.Layers.Internal.Update 34 38 35 - import Grenade.Recurrent.Core.Network 36 39 37 40 -- | Long Short Term Memory Recurrent unit 38 41 --
+5
test/Test/Grenade/Layers/Nonlinear.hs
··· 1 1 {-# LANGUAGE BangPatterns #-} 2 + {-# LANGUAGE CPP #-} 2 3 {-# LANGUAGE TemplateHaskell #-} 3 4 {-# LANGUAGE DataKinds #-} 4 5 {-# LANGUAGE KindSignatures #-} ··· 9 10 module Test.Grenade.Layers.Nonlinear where 10 11 11 12 import Data.Singletons 13 + 14 + #if __GLASGOW_HASKELL__ < 800 15 + import Data.Proxy 16 + #endif 12 17 13 18 import Grenade 14 19 import GHC.TypeLits
+6
test/Test/Grenade/Layers/PadCrop.hs
··· 1 1 {-# LANGUAGE BangPatterns #-} 2 + {-# LANGUAGE CPP #-} 2 3 {-# LANGUAGE TemplateHaskell #-} 3 4 {-# LANGUAGE DataKinds #-} 4 5 {-# LANGUAGE KindSignatures #-} 5 6 {-# LANGUAGE GADTs #-} 6 7 {-# LANGUAGE ScopedTypeVariables #-} 7 8 {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 9 + 10 + #if __GLASGOW_HASKELL__ < 800 11 + {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 12 + #endif 13 + 8 14 module Test.Grenade.Layers.PadCrop where 9 15 10 16 import Grenade
+10 -4
test/Test/Jack/TypeLits.hs
··· 22 22 Just n <- someNatVal <$> choose (1, 10) 23 23 return n 24 24 25 - genShape :: Jack (SomeSing Shape) 25 + #if __GLASGOW_HASKELL__ < 800 26 + type Shape' = ('KProxy :: KProxy Shape) 27 + #else 28 + type Shape' = Shape 29 + #endif 30 + 31 + genShape :: Jack (SomeSing Shape') 26 32 genShape 27 33 = oneOf [ 28 34 genD1 ··· 30 36 , genD3 31 37 ] 32 38 33 - genD1 :: Jack (SomeSing Shape) 39 + genD1 :: Jack (SomeSing Shape') 34 40 genD1 = do 35 41 n <- genNat 36 42 return $ case n of 37 43 SomeNat (_ :: Proxy x) -> SomeSing (sing :: Sing ('D1 x)) 38 44 39 - genD2 :: Jack (SomeSing Shape) 45 + genD2 :: Jack (SomeSing Shape') 40 46 genD2 = do 41 47 n <- genNat 42 48 m <- genNat 43 49 return $ case (n, m) of 44 50 (SomeNat (_ :: Proxy x), SomeNat (_ :: Proxy y)) -> SomeSing (sing :: Sing ('D2 x y)) 45 51 46 - genD3 :: Jack (SomeSing Shape) 52 + genD3 :: Jack (SomeSing Shape') 47 53 genD3 = do 48 54 n <- genNat 49 55 m <- genNat