{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
module Example where

import AST
import AST.Pretty
import CHAD
import Data
import Language
import Simplify


-- Example GHCi invocation: build the descriptor for ex5 with SMerge storage
-- for both variables, run drev on ex5, apply freezeRet with the constant
-- 1.0, simplify (simplifyN 20 -- presumably up to 20 passes), and
-- pretty-print it in the environment senv5:
--
--   ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)

-- Conventions for the examples below: each exN binds its environment
-- variables with 'lambda'. The environment type list is ordered
-- innermost-first, i.e. the variable bound by the LAST lambda is the head of
-- the list. senvN spells out that environment as a list of type singletons,
-- and descrN pairs each environment entry with a Storage choice in the same
-- order (the last DPush corresponds to the head of the list).

-- Apply a binary primitive operation to two expressions by pairing them into
-- the single tuple argument that EOp takes.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)

senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil

-- Descriptor for ex1: 'a' is the storage for x (first lambda), 'b' the
-- storage for y (second lambda); hence the index list is [b, a].
descr1 :: Storage a -> Storage b -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
  where t = STScal STF32

-- x y |- x * y + x
--
-- NOTE(review): the two listings below look like the raw and the simplified
-- derivative program for this example; confirm against current drev output.
--
-- let x3 = (x1, x2)
--     x4 = ((*) x3, x1)
-- in ( (+) x4
--    , let x5 = 1.0
--          x6 = Inr (x5, x5)
--      in case x6 of
--           Inl x7 -> return ()
--           Inr x8 ->
--             let x9 = fst x8
--                 x10 = Inr (snd x3 * x9, fst x3 * x9)
--             in case x10 of
--                  Inl x11 -> return ()
--                  Inr x12 ->
--                    let x13 = fst x12
--                    in one "v1" x13 >>= \x14 ->
--                         let x15 = snd x12
--                         in one "v2" x15 >>= \x16 ->
--                              let x17 = snd x8
--                              in one "v1" x17)
--
-- ( (x1 * x2) + x1
-- , let x5 = 1.0
--   in do one "v1" (x2 * x5)
--         one "v2" (x1 * x5)
--         one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
  x * y + x

-- x y |- let z = x + y in z * (z + x)
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
  let_ (x + y) $ \z ->
    z * (z + x)

-- x y |- if x < y then 2 * x else 3 * x
--
-- (The else-branch multiplies; contrast ex4, whose else-branch adds.)
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
  if_ (x .< y) (2 * x) (3 * x)

-- x y |- if x < y then 2 * x + y * y else 3 + x
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
  if_ (x .< y) (2 * x + y * y) (3 + x)

senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil

-- Descriptor for ex5: 'a' is the storage for x : R+R (first lambda), 'b' the
-- storage for y : R (second lambda).
descr5 :: Storage a -> Storage b
       -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a)
                  `DPush` (STScal STF32, b)

-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 = scopeCheck $ lambda $ \x -> lambda $ \y -> body $
  case_ x (\a -> a * y)
          (\b -> b * (y + 1))

senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil

-- Descriptor for ex6 with SMerge storage for both x : R and n : I.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DTop `DPush` (STScal STF32, SMerge)
              `DPush` (STScal STI64, SMerge)

-- x:R n:I |- let a = unit x
--                b = build1 n (\i. let c = idx0 a in c * c)
--            in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 = scopeCheck $ lambda $ \x -> lambda $ \n -> body $
  let_ (unit x) $ \a ->
  let_ (build1 n (\_ -> let_ (idx0 a) $ \c -> c * c)) $ \b ->
    idx0 (b .! 3)

-- Shorthand for the TF32 scalar type, used in ex7's environment.
type R = TScal TF32

senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 =
  let tR = STScal STF32
      tpair = STPair tR tR
  in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil

-- Descriptor for ex7 with SMerge storage for both the parameter tuple and
-- the input.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 =
  let tR = STScal STF32
      tpair = STPair tR tR
  in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge)
          `DPush` (tR, SMerge)

-- A "neural network" except it's just scalars, not matrices.
-- ps:((((), (R,R)), (R,R)), (R,R))  x:R
-- |- let p1  = snd ps
--        p1' = fst ps
--        x1  = fst p1 * x + snd p1
--        p2  = snd p1'
--        p2' = fst p1'
--        x2  = fst p2 * x1 + snd p2
--        p3  = snd p2'
--        p3' = fst p2'
--        x3  = fst p3 * x2 + snd p3
--    in x3
--
-- Each layer computes (fst p * v + snd p), where v is the PREVIOUS layer's
-- output: x for the first layer, then x1, then x2. The parameter tuple is
-- consumed from the outside in: the outermost snd is the first layer's pair,
-- and the innermost () terminates the recursion.
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 = scopeCheck $ lambda $ \pars123 -> lambda $ \input -> body $
  let tR = STScal STF32
      tpair = STPair tR tR

      -- Recurse over the type witness of the left-nested parameter tuple:
      -- take the outermost (R, R) pair, apply it to the current value, and
      -- continue on the remaining parameters with the new value.
      layer :: STy p -> SExpr p -> SExpr R -> SExpr R
      layer (STPair t (STPair (STScal STF32) (STScal STF32))) pars inp
        -- NOTE(review): the Dict match presumably brings type-class evidence
        -- for 't' into scope, as needed by fst_/snd_ -- confirm in Data/AST.
        | Dict <- styKnown t =
            let_ (snd_ pars) $ \par ->
            let_ (fst_ pars) $ \restpars ->
            let_ (fst_ par * inp + snd_ par) $ \res ->
              layer t restpars res
      layer STNil _ inp = inp
      layer _ _ _ = error "Invalid layer inputs"
  in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) pars123 input