{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS -Wno-unused-imports #-}
module Example where

import Array
import AST
import AST.Count
import AST.Pretty
import AST.UnMonoid
import CHAD
import CHAD.Top
import CHAD.Types
import ForwardAD
import Interpreter
import Language
import Simplify

import Debug.Trace

import Example.Types


-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)

-- | Reverse-differentiation pipeline applied to a term.
--
-- The stages run in this order:
--
--   1. 'simplifyFix' the input term.
--   2. Differentiate it with 'chad'' under the given 'CHADConfig'.
--   3. Run 'unMonoid', then 'simplifyFix'.
--   4. Run 'pruneExpr', then 'simplifyFix' once more.
--
-- The result expression lives in the environment extended with a @D2 t@
-- variable (presumably the incoming cotangent). Per its type, it returns a
-- pair of a @t@ value and a tuple of @D2E env@ values.
--
-- The pattern guard brings the 'Dict' produced by
-- @styKnown (d2 (typeOf term))@ into scope. Presumably this discharges a
-- known-type constraint on the cotangent type needed by the stages above.
pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
pipeline config term
  | Dict <- styKnown (d2 (typeOf term)) =
      simplifyFix $ pruneExpr knownEnv $
      simplifyFix $ unMonoid $
      chad' config knownEnv $
      simplifyFix $ term

-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures

-- | Run 'pipeline' on a term and print the result with 'pprintExpr'.
pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO ()
pipeline' config term
  | Dict <- styKnown (d2 (typeOf term)) =
      pprintExpr (pipeline config term)

-- | Apply a binary primitive operator to two argument expressions.
-- The arguments are combined with 'EPair' into the operator's single pair input.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)

-- | Singleton description of a two-variable environment of 'TF32' scalars.
-- This is the same environment type used by 'ex1' through 'ex4'.
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil

-- x y |- x * y + x
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
-- in ( (+) x4
--    , let x5 = 1.0
--          x6 = Inr (x5, x5)
--      in case x6 of
--           Inl x7 -> return ()
--           Inr x8 ->
--             let x9 = fst x8
--                 x10 = Inr (snd x3 * x9, fst x3 * x9)
--             in case x10 of
--                  Inl x11 -> return ()
--                  Inr x12 ->
--                    let x13 = fst x12
--                    in one "v1" x13 >>= \x14 ->
--                       let x15 = snd x12
--                       in one "v2" x15 >>= \x16 ->
--                          let x17 = snd x8
--                          in one "v1" x17)
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = fromNamed $ lambda #x $ lambda #y $ body $
  #x * #y + #x

-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #z (#x + #y) $
    #z * (#z + #x)

-- x y |- if x < y then 2 * x else 3 * x
-- (Both branches multiply here; compare 'ex4', whose else-branch adds.)
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = fromNamed $ lambda #x $ lambda #y $ body $
  if_ (#x .< #y) (2 * #x) (3 * #x)

-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 = fromNamed $ lambda #x $ lambda #y $ body $
  if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)

-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = fromNamed $ lambda #x $ lambda #y $ body $
  case_ #x (#a :-> #a * #y)
           (#b :-> #b * (#y + 1))

-- x:R n:I |- let a = unit x
--               b = build1 n (\i. let c = idx0 a in c * c)
--           in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 = fromNamed $ lambda #x $ lambda #n $ body $
  let_ #a (unit #x) $
  let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
    #b ! pair nil 3

-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
--        p1' = fst ps
--        x1 = fst p1 * x + snd p1
--        p2 = snd p1'
--        p2' = fst p1'
--        x2 = fst p2 * x1 + snd p2
--        p3 = snd p2'
--        p3' = fst p2'
--        x3 = fst p3 * x2 + snd p3
--    in x3
--
-- Each layer consumes the previous layer's output: the code rebinds #inp to
-- the layer result before recursing on the remaining parameters.
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
  let tR = STScal STF64
      tpair = STPair tR tR

      -- Peel one (weight, bias) pair off the end of #parstup and apply it to #inp.
      -- Then recurse on the remaining parameters, driven by the singleton type.
      layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R) => STy p -> NExpr env R
      layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
        let_ #par (snd_ #parstup) $
        let_ #restpars (fst_ #parstup) $
        let_ #inp (fst_ #par * #inp + snd_ #par) $
        let_ #parstup #restpars $
          layer t
      layer STNil = #inp
      layer _ = error "Invalid layer inputs"

  in let_ #parstup #pars123 $
     let_ #inp #input $
       layer (STPair (STPair (STPair STNil tpair) tpair) tpair)

-- | A small vector network with three parameter inputs.
--
-- It is built from a local @layer@ function that computes
-- @relu (wei `matmul` x + bias)@ elementwise.
--
-- * @layer1@ and @layer2@ are (matrix, bias vector) pairs, applied in sequence
--   to @input@.
-- * @layer3@ is a vector. The result is the dot product of the second
--   hidden layer's output with @layer3@.
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
  let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
        -- prod = wei `matmul` x
        let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
                      #wei ! #idx * #x ! pair nil (snd_ #idx)) $
        -- relu (prod + bias)
        build (SS SZ) (shape #prod) $ #idx :->
          let_ #out (#prod ! #idx + #bias ! #idx) $
          if_ (#out .<= const_ 0) (const_ 0) #out

  in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
     let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
     let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
       #x3 ! nil

-- | Gradients of 'neural' with respect to its parameters and input.
-- The components are ordered (layer1, layer2, layer3, input).
type NeuralGrad = ((Array N2 Double, Array N1 Double)
                  ,(Array N2 Double, Array N1 Double)
                  ,Array N1 Double
                  ,Array N1 Double)

-- | Evaluate 'neural' on fixed 2x2 inputs and compare two gradient
-- computations.
--
-- The fixed inputs are: all weights 1, biases 0, @layer3@ = [1,1],
-- @input@ = [1,1].
--
-- The first gradient interprets the CHAD-transformed program after
-- @simplifyN 20@. That program is let-bound under a constant 1.0,
-- presumably the output cotangent.
--
-- The second gradient comes from 'drevByFwdInterp'.
--
-- The simplified CHAD expression is emitted via 'trace' as a side effect.
neuralGo :: (Double  -- primal
            ,NeuralGrad  -- gradient using CHAD
            ,NeuralGrad)  -- gradient using dual-numbers forward AD
neuralGo =
  let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
      input = arrayFromList (ShNil `ShCons` 2) [1,1]
      -- Environment values in reverse binding order, matching neural's env list.
      argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
      revderiv =
        simplifyN 20 $
          ELet ext (EConst ext STF64 1.0) $
            chad defaultConfig knownEnv neural
      (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) =
        case interpretOpen False knownEnv argument revderiv of
          (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) ->
            (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
      (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) =
        drevByFwdInterp knownEnv neural argument 1.0
  in trace (ppExpr knownEnv revderiv) $
       (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))

-- The build body uses free variables in a non-linear way, so their primal
-- values are required in the dual of the build. Thus, compositionally, they
-- are stored in the tape from each individual lambda invocation. This results
-- in n copies of y and z, where only one copy would have sufficed.
exUniformFree :: Ex '[R, I64] R
exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $
  let_ #y (#x * 2) $
  let_ #z (#x * 3) $
    idx0 $ sum1i $ build1 #n $ #i :-> #y * #z + toFloat_ #i