diff options
Diffstat (limited to 'src/Example.hs')
-rw-r--r-- | src/Example.hs | 32 |
1 files changed, 5 insertions, 27 deletions
diff --git a/src/Example.hs b/src/Example.hs index 2c710a1..b320ead 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -5,11 +5,14 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} + +{-# OPTIONS -Wno-unused-imports #-} module Example where import Array import AST import AST.Pretty +import AST.UnMonoid import CHAD import CHAD.Top import ForwardAD @@ -30,11 +33,6 @@ bin op a b = EOp ext op (EPair ext a b) senv1 :: SList STy [TScal TF32, TScal TF32] senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil -descr1 :: Storage a -> Storage b - -> Descr [TScal TF32, TScal TF32] [b, a] -descr1 a b = DTop `DPush` (t, a) `DPush` (t, b) - where t = STScal STF32 - -- x y |- x * y + x -- -- let x3 = (x1, x2) @@ -82,25 +80,12 @@ ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) ex4 = fromNamed $ lambda #x $ lambda #y $ body $ if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) -senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)] -senv5 = knownEnv - -descr5 :: Storage a -> Storage b - -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a] -descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b) - -- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) ex5 = fromNamed $ lambda #x $ lambda #y $ body $ case_ #x (#a :-> #a * #y) (#b :-> #b * (#y + 1)) -senv6 :: SList STy [TScal TI64, TScal TF32] -senv6 = knownEnv - -descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] -descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) - -- x:R n:I |- let a = unit x -- b = build1 n (\i. let c = idx0 a in c * c) -- in idx0 (b ! 3) @@ -110,12 +95,6 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ #b ! pair nil 3 -senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] -senv7 = knownEnv - -descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"] -descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) - -- A "neural network" except it's just scalars, not matrices. -- ps:((((), (R,R)), (R,R)), (R,R)) x:R -- |- let p1 = snd ps @@ -182,9 +161,8 @@ neuralGo = simplifyN 20 $ ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural - (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of - (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - _ -> undefined + (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) |