···1717 Grenade provides an API for composing layers of a neural network
1818 into a sequence parallel graph in a type safe manner; running
1919 networks with reverse automatic differentiation to calculate their
2020- gradients; and applying gradient decent for learning.
2020+ gradients; and applying gradient descent for learning.
2121 .
2222 Documentation and examples are available on github
2323 <https://github.com/HuwCampbell/grenade>.
···2626 README.md
2727 cbits/im2col.h
2828 cbits/im2col.c
2929- cbits/gradient_decent.h
3030- cbits/gradient_decent.c
2929+ cbits/gradient_descent.h
3030+ cbits/gradient_descent.c
3131 cbits/pad.h
3232 cbits/pad.c
3333···108108 Grenade.Utils.OneHot
109109110110 includes: cbits/im2col.h
111111- cbits/gradient_decent.h
111111+ cbits/gradient_descent.h
112112 cbits/pad.h
113113 c-sources: cbits/im2col.c
114114- cbits/gradient_decent.c
114114+ cbits/gradient_descent.c
115115 cbits/pad.c
116116117117 cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
+1-1
src/Grenade/Core/Network.hs
···139139 = (GNil, o)
140140141141142142--- | Apply one step of stochastic gradient decent across the network.
142142+-- | Apply one step of stochastic gradient descent across the network.
143143applyUpdate :: LearningParameters
144144 -> Network layers shapes
145145 -> Gradients layers
···11{-# LANGUAGE ForeignFunctionInterface #-}
22module Grenade.Layers.Internal.Update (
33- decendMatrix
44- , decendVector
33+ descendMatrix
44+ , descendVector
55 ) where
6677import Data.Maybe ( fromJust )
···17171818import System.IO.Unsafe ( unsafePerformIO )
19192020-decendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
2121-decendMatrix rate momentum regulariser weights gradient lastUpdate =
2020+descendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
2121+descendMatrix rate momentum regulariser weights gradient lastUpdate =
2222 let (rows, cols) = size weights
2323 len = rows * cols
2424 -- Most gradients come in in ColumnMajor,
···2929 weights' = flatten . tr . extract $ weights
3030 gradient' = flatten . tr . extract $ gradient
3131 lastUpdate' = flatten . tr . extract $ lastUpdate
3232- (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
3232+ (vw, vm) = descendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
33333434 -- Note that it's ColumnMajor, as we did a transpose before
3535 -- using the internal vectors.
···3737 mm = U.matrixFromVector U.ColumnMajor rows cols vm
3838 in (fromJust . create $ mw, fromJust . create $ mm)
39394040-decendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
4141-decendVector rate momentum regulariser weights gradient lastUpdate =
4040+descendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
4141+descendVector rate momentum regulariser weights gradient lastUpdate =
4242 let len = size weights
4343 weights' = extract weights
4444 gradient' = extract gradient
4545 lastUpdate' = extract lastUpdate
4646- (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
4646+ (vw, vm) = descendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
4747 in (fromJust $ create vw, fromJust $ create vm)
48484949-decendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
5050-decendUnsafe len rate momentum regulariser weights gradient lastUpdate =
4949+descendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
5050+descendUnsafe len rate momentum regulariser weights gradient lastUpdate =
5151 unsafePerformIO $ do
5252 outWPtr <- mallocForeignPtrArray len
5353 outMPtr <- mallocForeignPtrArray len
···6060 withForeignPtr lPtr $ \lPtr' ->
6161 withForeignPtr outWPtr $ \outWPtr' ->
6262 withForeignPtr outMPtr $ \outMPtr' ->
6363- decend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'
6363+ descend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'
64646565 return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)
66666767foreign import ccall unsafe
6868- decend_cpu
6868+ descend_cpu
6969 :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()
7070
+2-2
src/Grenade/Recurrent/Layers/LSTM.hs
···8787 -- Utility function for updating with the momentum, gradients, and weights.
8888 u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
8989 u e (e -> weights) (e -> momentum) (e -> gradient) =
9090- decendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
9090+ descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
91919292 v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
9393 v e (e -> weights) (e -> momentum) (e -> gradient) =
9494- decendVector learningRate learningMomentum learningRegulariser weights gradient momentum
9494+ descendVector learningRate learningMomentum learningRegulariser weights gradient momentum
95959696 -- There's a lot of updates here, so to try and minimise the number of data copies
9797 -- we'll create a mutable bucket for each.