summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs32
1 files changed, 22 insertions, 10 deletions
diff --git a/src/Example.hs b/src/Example.hs
index e2f1be9..6701e38 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -11,10 +11,14 @@ import AST
import AST.Pretty
import CHAD
import Data
+import ForwardAD
import Interpreter
import Language
import Simplify
+import Debug.Trace
+import Example.Format
+
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
@@ -175,18 +179,26 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
#x3 ! nil
-neuralGo :: (Float
- ,(((((), Either () (Array N2 Float, Array N1 Float))
- ,Either () (Array N2 Float, Array N1 Float))
- ,Array N1 Float)
- ,Array N1 Float))
+type NeuralGrad = ((Array N2 Float, Array N1 Float)
+ ,(Array N2 Float, Array N1 Float)
+ ,Array N1 Float
+ ,Array N1 Float)
+
+neuralGo :: (Float -- primal
+ ,NeuralGrad -- gradient using CHAD
+ ,NeuralGrad) -- gradient using dual-numbers forward AD
neuralGo =
let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
input = arrayFromList (ShNil `ShCons` 2) [1,1]
- in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $
- simplifyN 20 $
- freezeRet mergeDescr
- (drev mergeDescr neural)
- (EConst ext STF32 1.0)
+ argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
+ revderiv =
+ simplifyN 20 $
+ freezeRet mergeDescr
+ (drev mergeDescr neural)
+ (EConst ext STF32 1.0)
+ (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv
+ (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
+ in trace (formatter (ppExpr knownEnv revderiv)) $
+ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))