{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Example where

import Array
import AST
import AST.Pretty
import CHAD
import Data
import ForwardAD
import Interpreter
import Language
import Simplify

import Debug.Trace
import Example.Format


-- Example of pretty-printing a simplified derivative program (here for ex5):
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)

-- Maps every entry of an environment to the "merge" storage tag, preserving
-- the length of the environment.
type family MergeEnv env where
  MergeEnv '[] = '[]
  MergeEnv (t : ts) = "merge" : MergeEnv ts

-- Descriptor for a statically known environment in which every variable is
-- stored using SMerge.
mergeDescr :: KnownEnv env => Descr env (MergeEnv env)
mergeDescr = go knownEnv
  where
    -- Walk the environment singleton and push one (type, SMerge) entry per
    -- variable.
    go :: SList STy env -> Descr env (MergeEnv env)
    go SNil = DTop
    go (t `SCons` env) = go env `DPush` (t, SMerge)

-- Apply an operator that takes a pair to two separate argument expressions.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)

-- Environment singleton for two F32 scalars.
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil

-- Descriptor for the two-scalar environment; the first argument is the
-- storage for the variable pushed first, the second for the one pushed last.
descr1 :: Storage a -> Storage b
       -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
  where t = STScal STF32

-- x y |- x * y + x
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
-- in ( (+) x4
--    , let x5 = 1.0
--          x6 = Inr (x5, x5)
--      in case x6 of
--           Inl x7 -> return ()
--           Inr x8 ->
--             let x9 = fst x8
--                 x10 = Inr (snd x3 * x9, fst x3 * x9)
--             in case x10 of
--                  Inl x11 -> return ()
--                  Inr x12 ->
--                    let x13 = fst x12
--                    in one "v1" x13 >>= \x14 ->
--                         let x15 = snd x12
--                         in one "v2" x15 >>= \x16 ->
--                              let x17 = snd x8
--                              in one "v1" x17)
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = fromNamed $ lambda #x $ lambda #y $ body $
  #x * #y + #x

-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #z (#x + #y) $
    #z * (#z + #x)

-- x y |- if x < y then 2 * x else 3 + x
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = fromNamed $ lambda #x $ lambda #y $ body $
  -- The else-branch is an addition, matching the judgement above (and the
  -- else-branch of ex4).
  if_ (#x .< #y) (2 * #x) (3 + #x)

-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 = fromNamed $ lambda #x $ lambda #y $ body $
  if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)

-- Environment singleton for ex5: an F32 scalar and a sum of two F32 scalars.
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = knownEnv

-- Descriptor for the ex5 environment with caller-chosen storage per variable.
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b)

-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = fromNamed $ lambda #x $ lambda #y $ body $
  case_ #x (#a :-> #a * #y)
           (#b :-> #b * (#y + 1))

-- Environment singleton for ex6: an I64 scalar and an F32 scalar.
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = knownEnv

-- Descriptor for the ex6 environment with both variables stored using SMerge.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)

-- x:R n:I |- let a = unit x
--                b = build1 n (\i. let c = idx0 a in c * c)
--            in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 = fromNamed $ lambda #x $ lambda #n $ body $
  let_ #a (unit #x) $
  let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
    idx0 (#b .! 3)

-- Shorthand for the F64 scalar type used by the examples below.
type R = TScal TF64

-- Environment singleton for ex7: the input scalar and the nested parameter
-- tuple.
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = knownEnv

-- Descriptor for the ex7 environment with both variables stored using SMerge.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)

-- A "neural network" except it's just scalars, not matrices.
-- Each layer consumes the output of the previous one.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
--        p1' = fst ps
--        x1 = fst p1 * x + snd p1
--        p2 = snd p1'
--        p2' = fst p1'
--        x2 = fst p2 * x1 + snd p2
--        p3 = snd p2'
--        p3' = fst p2'
--        x3 = fst p3 * x2 + snd p3
--    in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
  let tR = STScal STF64
      tpair = STPair tR tR

      -- Recurse over the type of the remaining parameter tuple. Each step
      -- peels off the outermost (weight, bias) pair, applies it to the
      -- current #inp, and rebinds #inp and #parstup for the next step.
      layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R)
            => STy p -> NExpr env R
      layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t =
        let_ #par (snd_ #parstup) $
        let_ #restpars (fst_ #parstup) $
        let_ #inp (fst_ #par * #inp + snd_ #par) $
        let_ #parstup #restpars $
          layer t
      -- No parameters left: the current #inp is the result.
      layer STNil = #inp
      layer _ = error "Invalid layer inputs"

  in let_ #parstup #pars123 $
     let_ #inp #input $
       layer (STPair (STPair (STPair STNil tpair) tpair) tpair)

-- Rank-1 and rank-2 array types.
type TVec = TArr (S Z)
type TMat = TArr (S (S Z))

-- Two layers of relu (wei `matmul` x + bias), followed by a dot product with
-- layer3. The environment lists the most recently bound variable first:
-- input, layer3, layer2, layer1.
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
  let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $
        -- prod = wei `matmul` x
        let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
                      #wei ! #idx * #x ! pair nil (snd_ #idx)) $
        -- relu (prod + bias)
        build (SS SZ) (shape #prod) $ #idx :->
          let_ #out (#prod ! #idx + #bias ! #idx) $
          if_ (#out .<= const_ 0) (const_ 0) #out

  in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $
     let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $
     let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
     #x3 ! nil

-- Gradient of 'neural' with respect to (layer1, layer2, layer3, input).
type NeuralGrad = ((Array N2 Double, Array N1 Double)
                  ,(Array N2 Double, Array N1 Double)
                  ,Array N1 Double
                  ,Array N1 Double)

-- Evaluate 'neural' on fixed inputs (all-ones weights and input, zero biases)
-- and compute its gradient in two ways. As a side effect the pretty-printed
-- derivative program is emitted via 'trace' when the result is forced.
neuralGo :: (Double  -- primal
            ,NeuralGrad  -- gradient using CHAD
            ,NeuralGrad)  -- gradient using dual-numbers forward AD
neuralGo =
  let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
      lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
      input = arrayFromList (ShNil `ShCons` 2) [1,1]
      -- Argument order follows the environment of 'neural'.
      argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
      revderiv =
        simplifyN 20 $
          freezeRet mergeDescr
            (drev mergeDescr neural)
            (EConst ext STF64 1.0)
      -- NOTE(review): these lazy patterns are partial (they insist on
      -- 'Right' for the two pair-typed layers); presumably this always holds
      -- for these inputs -- confirm.
      (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen False argument revderiv
      (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
  in trace (formatter (ppExpr knownEnv revderiv)) $
     (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))