···17 Grenade provides an API for composing layers of a neural network
18 into a sequence parallel graph in a type safe manner; running
19 networks with reverse automatic differentiation to calculate their
20- gradients; and applying gradient decent for learning.
21 .
22 Documentation and examples are available on github
23 <https://github.com/HuwCampbell/grenade>.
···26 README.md
27 cbits/im2col.h
28 cbits/im2col.c
29- cbits/gradient_decent.h
30- cbits/gradient_decent.c
31 cbits/pad.h
32 cbits/pad.c
33···108 Grenade.Utils.OneHot
109110 includes: cbits/im2col.h
111- cbits/gradient_decent.h
112 cbits/pad.h
113 c-sources: cbits/im2col.c
114- cbits/gradient_decent.c
115 cbits/pad.c
116117 cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
···17 Grenade provides an API for composing layers of a neural network
18 into a sequence parallel graph in a type safe manner; running
19 networks with reverse automatic differentiation to calculate their
20+ gradients; and applying gradient descent for learning.
21 .
22 Documentation and examples are available on github
23 <https://github.com/HuwCampbell/grenade>.
···26 README.md
27 cbits/im2col.h
28 cbits/im2col.c
29+ cbits/gradient_descent.h
30+ cbits/gradient_descent.c
31 cbits/pad.h
32 cbits/pad.c
33···108 Grenade.Utils.OneHot
109110 includes: cbits/im2col.h
111+ cbits/gradient_descent.h
112 cbits/pad.h
113 c-sources: cbits/im2col.c
114+ cbits/gradient_descent.c
115 cbits/pad.c
116117 cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
+1-1
src/Grenade/Core/Network.hs
···139 = (GNil, o)
140141142--- | Apply one step of stochastic gradient decent across the network.
143applyUpdate :: LearningParameters
144 -> Network layers shapes
145 -> Gradients layers
···139 = (GNil, o)
140141142+-- | Apply one step of stochastic gradient descent across the network.
143applyUpdate :: LearningParameters
144 -> Network layers shapes
145 -> Gradients layers
···1{-# LANGUAGE ForeignFunctionInterface #-}
2module Grenade.Layers.Internal.Update (
3- decendMatrix
4- , decendVector
5 ) where
67import Data.Maybe ( fromJust )
···1718import System.IO.Unsafe ( unsafePerformIO )
1920-decendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
21-decendMatrix rate momentum regulariser weights gradient lastUpdate =
22 let (rows, cols) = size weights
23 len = rows * cols
24 -- Most gradients come in in ColumnMajor,
···29 weights' = flatten . tr . extract $ weights
30 gradient' = flatten . tr . extract $ gradient
31 lastUpdate' = flatten . tr . extract $ lastUpdate
32- (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
3334 -- Note that it's ColumnMajor, as we did a transpose before
35 -- using the internal vectors.
···37 mm = U.matrixFromVector U.ColumnMajor rows cols vm
38 in (fromJust . create $ mw, fromJust . create $ mm)
3940-decendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
41-decendVector rate momentum regulariser weights gradient lastUpdate =
42 let len = size weights
43 weights' = extract weights
44 gradient' = extract gradient
45 lastUpdate' = extract lastUpdate
46- (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
47 in (fromJust $ create vw, fromJust $ create vm)
4849-decendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
50-decendUnsafe len rate momentum regulariser weights gradient lastUpdate =
51 unsafePerformIO $ do
52 outWPtr <- mallocForeignPtrArray len
53 outMPtr <- mallocForeignPtrArray len
···60 withForeignPtr lPtr $ \lPtr' ->
61 withForeignPtr outWPtr $ \outWPtr' ->
62 withForeignPtr outMPtr $ \outMPtr' ->
63- decend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'
6465 return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)
6667foreign import ccall unsafe
68- decend_cpu
69 :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()
70
···1{-# LANGUAGE ForeignFunctionInterface #-}
2module Grenade.Layers.Internal.Update (
3+ descendMatrix
4+ , descendVector
5 ) where
67import Data.Maybe ( fromJust )
···1718import System.IO.Unsafe ( unsafePerformIO )
1920+descendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
21+descendMatrix rate momentum regulariser weights gradient lastUpdate =
22 let (rows, cols) = size weights
23 len = rows * cols
24 -- Most gradients come in in ColumnMajor,
···29 weights' = flatten . tr . extract $ weights
30 gradient' = flatten . tr . extract $ gradient
31 lastUpdate' = flatten . tr . extract $ lastUpdate
32+ (vw, vm) = descendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
3334 -- Note that it's ColumnMajor, as we did a transpose before
35 -- using the internal vectors.
···37 mm = U.matrixFromVector U.ColumnMajor rows cols vm
38 in (fromJust . create $ mw, fromJust . create $ mm)
3940+descendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
41+descendVector rate momentum regulariser weights gradient lastUpdate =
42 let len = size weights
43 weights' = extract weights
44 gradient' = extract gradient
45 lastUpdate' = extract lastUpdate
46+ (vw, vm) = descendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
47 in (fromJust $ create vw, fromJust $ create vm)
4849+descendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
50+descendUnsafe len rate momentum regulariser weights gradient lastUpdate =
51 unsafePerformIO $ do
52 outWPtr <- mallocForeignPtrArray len
53 outMPtr <- mallocForeignPtrArray len
···60 withForeignPtr lPtr $ \lPtr' ->
61 withForeignPtr outWPtr $ \outWPtr' ->
62 withForeignPtr outMPtr $ \outMPtr' ->
63+ descend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'
6465 return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)
6667foreign import ccall unsafe
68+ descend_cpu
69 :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()
70
+2-2
src/Grenade/Recurrent/Layers/LSTM.hs
···87 -- Utility function for updating with the momentum, gradients, and weights.
88 u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
89 u e (e -> weights) (e -> momentum) (e -> gradient) =
90- decendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
9192 v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
93 v e (e -> weights) (e -> momentum) (e -> gradient) =
94- decendVector learningRate learningMomentum learningRegulariser weights gradient momentum
9596 -- There's a lot of updates here, so to try and minimise the number of data copies
97 -- we'll create a mutable bucket for each.
···87 -- Utility function for updating with the momentum, gradients, and weights.
88 u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
89 u e (e -> weights) (e -> momentum) (e -> gradient) =
90+ descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
9192 v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
93 v e (e -> weights) (e -> momentum) (e -> gradient) =
94+ descendVector learningRate learningMomentum learningRegulariser weights gradient momentum
9596 -- There's a lot of updates here, so to try and minimise the number of data copies
97 -- we'll create a mutable bucket for each.