about summary refs log tree commit diff
path: root/src/CHAD/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Example.hs')
-rw-r--r-- src/CHAD/Example.hs 197
1 file changed, 197 insertions, 0 deletions
diff --git a/src/CHAD/Example.hs b/src/CHAD/Example.hs
new file mode 100644
index 0000000..884f99a
--- /dev/null
+++ b/src/CHAD/Example.hs
@@ -0,0 +1,197 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
+
+{-# OPTIONS -Wno-unused-imports #-}
+module CHAD.Example where
+
+import Debug.Trace
+
+import CHAD.Array
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.Pretty
+import CHAD.AST.UnMonoid
+import CHAD.Data
+import CHAD.Drev
+import CHAD.Drev.Top
+import CHAD.Drev.Types
+import CHAD.Example.Types
+import CHAD.ForwardAD
+import CHAD.Interpreter
+import CHAD.Language
+import CHAD.Simplify
+
+
+-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
+
+
+-- | Differentiate a term with CHAD and clean up the result.
+--
+-- Stages, innermost first: simplify the input term, run the CHAD code
+-- transformation ('chad''), compile monoid operations away ('unMonoid'),
+-- prune unused bindings ('pruneExpr'), simplifying to a fixpoint after each
+-- stage. (Stage semantics inferred from the imported CHAD.* module names --
+-- confirm against their definitions.)
+--
+-- The result takes the output cotangent (type @D2 t@) as an extra variable
+-- on top of the environment, and returns the primal result paired with the
+-- cotangents of the environment.
+pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+pipeline config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ simplifyFix $ pruneExpr knownEnv $
+ simplifyFix $ unMonoid $
+ simplifyFix $ chad' config knownEnv $
+ simplifyFix $ term
+
+-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures
+-- | Like 'pipeline', but pretty-print the differentiated term to stdout
+-- instead of returning it. Intended for interactive (GHCi) exploration.
+pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO ()
+pipeline' config term
+ | Dict <- styKnown (d2 (typeOf term)) =
+ pprintExpr (pipeline config term)
+
+
+-- | Apply a binary primitive operation to two argument expressions.
+-- An 'SOp' takes its two arguments as a single pair, so the arguments are
+-- tupled up with 'EPair' first.
+bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
+bin op a b = EOp ext op (EPair ext a b)
+
+-- | Singleton value for an environment of two 32-bit float scalars, the
+-- environment shared by 'ex1' through 'ex4'.
+senv1 :: SList STy [TScal TF32, TScal TF32]
+senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
+
+-- x y |- x * y + x
+--
+-- let x3 = (x1, x2)
+-- x4 = ((*) x3, x1)
+-- in ( (+) x4
+-- , let x5 = 1.0
+-- x6 = Inr (x5, x5)
+-- in case x6 of
+-- Inl x7 -> return ()
+-- Inr x8 ->
+-- let x9 = fst x8
+-- x10 = Inr (snd x3 * x9, fst x3 * x9)
+-- in case x10 of
+-- Inl x11 -> return ()
+-- Inr x12 ->
+-- let x13 = fst x12
+-- in one "v1" x13 >>= \x14 ->
+-- let x15 = snd x12
+-- in one "v2" x15 >>= \x16 ->
+-- let x17 = snd x8
+-- in one "v1" x17)
+--
+-- ( (x1 * x2) + x1
+-- , let x5 = 1.0
+-- in do one "v1" (x2 * x5)
+-- one "v2" (x1 * x5)
+-- one "v1" x5)
+-- | Simplest example: @x y |- x * y + x@. The expected CHAD output, before
+-- and after simplification, is sketched in the derivation comment above.
+ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex1 = fromNamed $ lambda #x $ lambda #y $ body $
+ #x * #y + #x
+
+-- x y |- let z = x + y in z * (z + x)
+--
+-- Exercises sharing: the let-bound @z@ is used twice in the body, so its
+-- cotangent contributions must be accumulated.
+ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex2 = fromNamed $ lambda #x $ lambda #y $ body $
+ let_ #z (#x + #y) $
+ #z * (#z + #x)
+
+-- x y |- if x < y then 2 * x else 3 * x
+--
+-- NOTE(review): the comment previously read "3 + x" for the else branch,
+-- but the code multiplies; 'ex4' is the variant whose else branch adds.
+-- Confirm which of the two was intended.
+ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex3 = fromNamed $ lambda #x $ lambda #y $ body $
+ if_ (#x .< #y) (2 * #x) (3 * #x)
+
+-- x y |- if x < y then 2 * x + y * y else 3 + x
+--
+-- Like 'ex3', but with a nonlinear @y * y@ term in the then-branch.
+ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
+ex4 = fromNamed $ lambda #x $ lambda #y $ body $
+ if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)
+
+-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
+--
+-- Exercises sum types: the derivative must case-split the same way the
+-- primal computation does.
+ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
+ex5 = fromNamed $ lambda #x $ lambda #y $ body $
+ case_ #x (#a :-> #a * #y)
+ (#b :-> #b * (#y + 1))
+
+-- x:R n:I |- let a = unit x
+-- b = build1 n (\i. let c = idx0 a in c * c)
+-- in idx0 (b ! 3)
+--
+-- First array example: builds a length-n vector of x^2 values and indexes
+-- element 3. Note that the build body ignores its index variable (#_).
+ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
+ex6 = fromNamed $ lambda #x $ lambda #n $ body $
+ let_ #a (unit #x) $
+ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
+ #b ! pair nil 3
+
+-- A "neural network" except it's just scalars, not matrices.
+-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
+-- |- let p1 = snd ps
+-- p1' = fst ps
+-- x1 = fst p1 * x + snd p1
+-- p2 = snd p1'
+-- p2' = fst p1'
+-- x2 = fst p2 * x + snd p2
+-- p3 = snd p2'
+-- p3' = fst p2'
+-- x3 = fst p3 * x + snd p3
+-- in x3
+ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
+ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
+ let tR = STScal STF64
+ tpair = STPair tR tR
+
+ -- Type-directed recursion over the nested parameter tuple: each step
+ -- peels one (scale, offset) pair off the right of #parstup, updates
+ -- #inp to @scale * inp + offset@, rebinds #parstup to the remaining
+ -- parameters, and recurses on the smaller type; STNil terminates. The
+ -- catch-all equation rejects parameter tuples of any other shape.
+ layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
+ => STy p -> NExpr env R
+ layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
+ let_ #par (snd_ #parstup) $
+ let_ #restpars (fst_ #parstup) $
+ let_ #inp (fst_ #par * #inp + snd_ #par) $
+ let_ #parstup #restpars $
+ layer t
+ layer STNil = #inp
+ layer _ = error "Invalid layer inputs"
+
+ in let_ #parstup #pars123 $
+ let_ #inp #input $
+ layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
+
+-- | A small neural network on actual arrays: two dense layers with ReLU
+-- activation (the inlined @layer@ helper), followed by a dot product with a
+-- final weight vector, yielding a scalar.
+neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
+neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
+ let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
+ -- prod = wei `matmul` x
+ let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
+ #wei ! #idx * #x ! pair nil (snd_ #idx)) $
+ -- relu (prod + bias)
+ build (SS SZ) (shape #prod) $ #idx :->
+ let_ #out (#prod ! #idx + #bias ! #idx) $
+ if_ (#out .<= const_ 0) (const_ 0) #out
+
+ in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
+ let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
+ let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
+ #x3 ! nil
+
+-- | Gradient of 'neural' with respect to its four inputs, in binding order:
+-- layer-1 (weights, bias), layer-2 (weights, bias), the final weight vector
+-- (layer 3), and the input vector (see the tuple built in 'neuralGo').
+type NeuralGrad = ((Array N2 Double, Array N1 Double)
+ ,(Array N2 Double, Array N1 Double)
+ ,Array N1 Double
+ ,Array N1 Double)
+
+-- | Run 'neural' on a fixed 2x2 test input and differentiate it twice: once
+-- with CHAD reverse AD ('chad') and once with dual-numbers forward AD
+-- ('drevByFwdInterp'), so the two gradients can be compared. As a side
+-- effect, 'trace' prints the simplified derivative program to stderr.
+neuralGo :: (Double -- primal
+ ,NeuralGrad -- gradient using CHAD
+ ,NeuralGrad) -- gradient using dual-numbers forward AD
+neuralGo =
+ let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
+ lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
+ input = arrayFromList (ShNil `ShCons` 2) [1,1]
+ -- Environment value list: most recently bound variable (#input) first,
+ -- mirroring the environment type of 'neural'.
+ argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
+ -- Reverse-derivative program, with the output cotangent bound to 1.0.
+ revderiv =
+ simplifyN 20 $
+ ELet ext (EConst ext STF64 1.0) $
+ chad defaultConfig knownEnv neural
+ -- Interpret the CHAD derivative and flatten its nested-tuple result
+ -- (primal, ((((), dlay1), dlay2), dlay3), dinput) into a flat tuple.
+ (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
+ (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
+ -- Reference gradient via forward AD, same reversed binding order.
+ (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
+ in trace (ppExpr knownEnv revderiv) $
+ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
+
+-- The build body uses free variables in a non-linear way, so their primal
+-- values are required in the dual of the build. Thus, compositionally, they
+-- are stored in the tape from each individual lambda invocation. This results
+-- in n copies of y and z, where only one copy would have sufficed.
+-- | Sums @y * z + toFloat i@ over @i < n@, where @y@ and @z@ are
+-- loop-invariant free variables of the build body (see the tape-size note
+-- above).
+exUniformFree :: Ex '[R, I64] R
+exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $
+ let_ #y (#x * 2) $
+ let_ #z (#x * 3) $
+ idx0 $ sum1i $
+ build1 #n $ #i :-> #y * #z + toFloat_ #i