aboutsummaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs32
1 files changed, 5 insertions, 27 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 2c710a1..b320ead 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -5,11 +5,14 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
+
+{-# OPTIONS -Wno-unused-imports #-}
module Example where
import Array
import AST
import AST.Pretty
+import AST.UnMonoid
import CHAD
import CHAD.Top
import ForwardAD
@@ -30,11 +33,6 @@ bin op a b = EOp ext op (EPair ext a b)
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
-descr1 :: Storage a -> Storage b
- -> Descr [TScal TF32, TScal TF32] [b, a]
-descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
- where t = STScal STF32
-
-- x y |- x * y + x
--
-- let x3 = (x1, x2)
@@ -82,25 +80,12 @@ ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 = fromNamed $ lambda #x $ lambda #y $ body $
if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)
-senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
-senv5 = knownEnv
-
-descr5 :: Storage a -> Storage b
- -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
-descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b)
-
-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = fromNamed $ lambda #x $ lambda #y $ body $
case_ #x (#a :-> #a * #y)
(#b :-> #b * (#y + 1))
-senv6 :: SList STy [TScal TI64, TScal TF32]
-senv6 = knownEnv
-
-descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
-descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
-
-- x:R n:I |- let a = unit x
-- b = build1 n (\i. let c = idx0 a in c * c)
-- in idx0 (b ! 3)
@@ -110,12 +95,6 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $
let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
#b ! pair nil 3
-senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
-senv7 = knownEnv
-
-descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
-descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)
-
-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
@@ -182,9 +161,8 @@ neuralGo =
simplifyN 20 $
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
- (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
- (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
- _ -> undefined
+ (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
+ (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
in trace (ppExpr knownEnv revderiv) $
(primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))