summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-07 23:11:36 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-07 23:11:36 +0100
commit92ddb2263ae495c229badcc209c76a1252bd2752 (patch)
treed69059d755a04121db23406050a643bf33c5b764 /src/Example.hs
parent401e74939fe2a717852acc4b7a452b222d82274a (diff)
Benchmark
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs16
1 files changed, 3 insertions, 13 deletions
diff --git a/src/Example.hs b/src/Example.hs
index d0405af..1775bb9 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -11,6 +11,7 @@ import Array
import AST
import AST.Pretty
import CHAD
+import CHAD.Top
import Data
import ForwardAD
import Interpreter
@@ -23,16 +24,6 @@ import Example.Format
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
-type family MergeEnv env where
- MergeEnv '[] = '[]
- MergeEnv (t : ts) = "merge" : MergeEnv ts
-
-mergeDescr :: KnownEnv env => Descr env (MergeEnv env)
-mergeDescr = go knownEnv
- where go :: SList STy env -> Descr env (MergeEnv env)
- go SNil = DTop
- go (t `SCons` env) = go env `DPush` (t, SMerge)
-
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)
@@ -195,9 +186,8 @@ neuralGo =
argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
revderiv =
simplifyN 20 $
- freezeRet mergeDescr
- (drev mergeDescr neural)
- (EConst ext STF64 1.0)
+ ELet ext (EConst ext STF64 1.0) $
+ chad knownEnv neural
(primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen False argument revderiv
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
in trace (formatter (ppExpr knownEnv revderiv)) $