{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Example where

import AST
import AST.Pretty
import CHAD
import Data
import Language
import Simplify


-- Example usage (presumably from GHCi):
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)

-- Maps an environment (a type-level list) to a list of the same length in
-- which every entry is the symbol "merge".
type family MergeEnv env where
  MergeEnv '[] = '[]
  MergeEnv (t : ts) = "merge" : MergeEnv ts

-- Descriptor for any known environment in which every entry is paired with
-- SMerge. It is built by walking the singleton list obtained from knownEnv.
mergeDescr :: KnownEnv env => Descr env (MergeEnv env)
mergeDescr = go knownEnv
  where
    go :: SList STy env -> Descr env (MergeEnv env)
    go SNil = DTop
    go (t `SCons` env) = go env `DPush` (t, SMerge)

-- Apply an operator that takes a pair: the two operands are wrapped in an
-- EPair and handed to EOp.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)

-- Singleton environment consisting of two F32 scalars.
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil

-- Descriptor over the two-scalar environment. The storage pushed last (b)
-- ends up first in the resulting type-level list, hence the index [b, a].
descr1 :: Storage a -> Storage b
       -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
  where t = STScal STF32

-- x y |- x * y + x
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
-- in ( (+) x4
--    , let x5 = 1.0
--          x6 = Inr (x5, x5)
--      in case x6 of
--           Inl x7 -> return ()
--           Inr x8 ->
--             let x9 = fst x8
--                 x10 = Inr (snd x3 * x9, fst x3 * x9)
--             in case x10 of
--                  Inl x11 -> return ()
--                  Inr x12 ->
--                    let x13 = fst x12
--                    in one "v1" x13 >>= \x14 ->
--                         let x15 = snd x12
--                         in one "v2" x15 >>= \x16 ->
--                              let x17 = snd x8
--                              in one "v1" x17)
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = fromNamed $ lambda #x $ lambda #y $ body $
  #x * #y + #x

-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #z (#x + #y) $
    #z * (#z + #x)

-- x y |- if x < y then 2 * x else 3 * x
-- NOTE(review): this comment previously read "else 3 + x", which did not
-- match the code below. ex4 does use "3 + x" in its else branch; confirm
-- that the multiplication here is intended.
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = fromNamed $ lambda #x $ lambda #y $ body $
  if_ (#x .< #y) (2 * #x) (3 * #x)

-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 = fromNamed $ lambda #x $ lambda #y $ body $
  if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x)

-- Singleton environment for ex5, obtained from the KnownEnv machinery.
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = knownEnv

-- Descriptor over ex5's environment. As with descr1, the storage pushed last
-- (b) comes first in the resulting index [b, a].
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b)

-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
--
-- Note that the variable bound last (y) is the head of the type-level
-- environment in the signature.
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = fromNamed $ lambda #x $ lambda #y $ body $
  case_ #x (#a :-> #a * #y)
           (#b :-> #b * (#y + 1))

-- Singleton environment for ex6.
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = knownEnv

-- Descriptor over ex6's environment with SMerge for both entries.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)

-- x:R n:I |- let a = unit x
--                b = build1 n (\i. let c = idx0 a in c * c)
--            in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 = fromNamed $ lambda #x $ lambda #n $ body $
  let_ #a (unit #x) $
  let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $
    idx0 (#b .! 3)

-- Shorthand for the F32 scalar type.
type R = TScal TF32

-- Singleton environment for ex7: an input scalar plus a nested tuple holding
-- three (R, R) parameter pairs.
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 = knownEnv

-- Descriptor over ex7's environment with SMerge for both entries.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge)

-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
-- |- let p1 = snd ps
--        p1' = fst ps
--        x1 = fst p1 * x + snd p1
--        p2 = snd p1'
--        p2' = fst p1'
--        x2 = fst p2 * x1 + snd p2
--        p3 = snd p2'
--        p3' = fst p2'
--        x3 = fst p3 * x2 + snd p3
--    in x3
--
-- Each layer rebinds #inp to its own output, so layer k consumes the result
-- of layer k-1 rather than the original input.
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
  let tR = STScal STF32
      tpair = STPair tR tR

      -- Peels one (weight, bias) pair off the right of the #parstup tuple,
      -- applies it to #inp, and recurses on the remaining tuple type. The
      -- recursion ends at STNil, where the accumulated #inp is the result.
      layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ TScal TF32)
            => STy p -> NExpr env R
      layer (STPair t (STPair (STScal STF32) (STScal STF32))) | Dict <- styKnown t =
        let_ #par (snd_ #parstup) $
        let_ #restpars (fst_ #parstup) $
        let_ #inp (fst_ #par * #inp + snd_ #par) $
        let_ #parstup #restpars $
          layer t
      layer STNil = #inp
      -- Any other tuple shape is not a valid parameter stack.
      layer _ = error "Invalid layer inputs"

  in let_ #parstup #pars123 $
     let_ #inp #input $
       layer (STPair (STPair (STPair STNil tpair) tpair) tpair)

-- Rank-1 and rank-2 array type constructors.
type TVec = TArr (S Z)
type TMat = TArr (S (S Z))

-- Two layers of relu (wei `matmul` x + bias), followed by a dot product of
-- the result with #layer3. The rank-0 array produced by the final sum1i is
-- indexed with nil to obtain the scalar result.
neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
  let -- One layer; reads its weights, bias and input from the variables
      -- named "wei", "bias" and "x" in the enclosing environment.
      layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R)
            => NExpr env (TVec R)
      layer =
        -- prod = wei `matmul` x
        let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
                      #wei ! #idx * #x ! pair nil (snd_ #idx)) $
        -- relu (prod + bias)
        build (SS SZ) (shape #prod) $ #idx :->
          let_ #out (#prod ! #idx + #bias ! #idx) $
            if_ (#out .<= const_ 0) (const_ 0) #out

  in let_ #x1 (let_ #wei (fst_ #layer1) $
               let_ #bias (snd_ #layer1) $
               let_ #x #input $
                 layer) $
     let_ #x2 (let_ #wei (fst_ #layer2) $
               let_ #bias (snd_ #layer2) $
               let_ #x #x1 $
                 layer) $
     let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
       #x3 ! nil