diff options
Diffstat (limited to 'src/Example.hs')
-rw-r--r-- | src/Example.hs | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/src/Example.hs b/src/Example.hs index e2f1be9..6701e38 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -11,10 +11,14 @@ import AST import AST.Pretty import CHAD import Data +import ForwardAD import Interpreter import Language import Simplify +import Debug.Trace +import Example.Format + -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) @@ -175,18 +179,26 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda # let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ #x3 ! nil -neuralGo :: (Float - ,(((((), Either () (Array N2 Float, Array N1 Float)) - ,Either () (Array N2 Float, Array N1 Float)) - ,Array N1 Float) - ,Array N1 Float)) +type NeuralGrad = ((Array N2 Float, Array N1 Float) + ,(Array N2 Float, Array N1 Float) + ,Array N1 Float + ,Array N1 Float) + +neuralGo :: (Float -- primal + ,NeuralGrad -- gradient using CHAD + ,NeuralGrad) -- gradient using dual-numbers forward AD neuralGo = let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) lay3 = arrayFromList (ShNil `ShCons` 2) [1,1] input = arrayFromList (ShNil `ShCons` 2) [1,1] - in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $ - simplifyN 20 $ - freezeRet mergeDescr - (drev mergeDescr neural) - (EConst ext STF32 1.0) + argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) + revderiv = + simplifyN 20 $ + freezeRet mergeDescr + (drev mergeDescr neural) + (EConst ext STF32 1.0) + (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv + (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 + in trace (formatter (ppExpr knownEnv revderiv)) $ + (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) |