{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
module Example where

import AST
import AST.Pretty
import CHAD
import Data
import Simplify


-- ppExpr senv5 $ simplifyN 20 $
--   let d = descr5 SMerge SMerge
--   in freezeRet d (drev d ex5) (EConst ext STF32 1.0)


-- | Apply a primitive operator whose argument is a pair: the two operand
-- expressions are paired up with 'EPair' and handed to 'EOp'.
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op lhs rhs = EOp ext op pairArg
  where
    pairArg = EPair ext lhs rhs

-- | Type witness for a context of two 'TF32' scalars (the context of
-- 'ex1' .. 'ex4').
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = SCons tR (SCons tR SNil)
  where
    tR = STScal STF32

-- | Descriptor for the two-scalar context, pairing each variable's type with
-- the given storage. The first storage argument is pushed first, so it ends up
-- last in the type-level storage list.
descr1 :: Storage a -> Storage b -> Descr [TScal TF32, TScal TF32] [b, a]
descr1 sa sb = DPush (DPush DTop (tR, sa)) (tR, sb)
  where
    tR = STScal STF32

-- | @x y |- x * y + x@
--
-- Variables are referenced by index: 'IZ' is the most recently bound
-- variable (here @y@) and each 'IS' skips one binder.
--
-- Derivative program for this term:
--
-- > let x3 = (x1, x2)
-- >     x4 = ((*) x3, x1)
-- > in ( (+) x4
-- >    , let x5 = 1.0
-- >          x6 = Inr (x5, x5)
-- >      in case x6 of
-- >           Inl x7 -> return ()
-- >           Inr x8 ->
-- >             let x9 = fst x8
-- >                 x10 = Inr (snd x3 * x9, fst x3 * x9)
-- >             in case x10 of
-- >                  Inl x11 -> return ()
-- >                  Inr x12 ->
-- >                    let x13 = fst x12
-- >                    in one "v1" x13 >>= \x14 ->
-- >                       let x15 = snd x12
-- >                       in one "v2" x15 >>= \x16 ->
-- >                          let x17 = snd x8
-- >                          in one "v1" x17)
--
-- and its simplified form:
--
-- > ( (x1 * x2) + x1
-- > , let x5 = 1.0
-- >   in do one "v1" (x2 * x5)
-- >         one "v2" (x1 * x5)
-- >         one "v1" x5)
ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex1 = bin (OAdd STF32) (bin (OMul STF32) x y) x
  where
    x = rvar (IS IZ)
    y = rvar IZ
    rvar :: Idx env (TScal TF32) -> Ex env (TScal TF32)
    rvar = EVar ext (STScal STF32)

-- | @x y |- let z = x + y in z * (z + x)@
--
-- Under the let, @z@ is at 'IZ', @y@ at @IS IZ@ and @x@ at @IS (IS IZ)@.
ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex2 =
  ELet ext (bin (OAdd STF32) (rvar (IS IZ)) (rvar IZ)) $
    bin (OMul STF32) (rvar IZ) (bin (OAdd STF32) (rvar IZ) (rvar (IS (IS IZ))))
  where
    rvar :: Idx env (TScal TF32) -> Ex env (TScal TF32)
    rvar = EVar ext (STScal STF32)

-- | @x y |- if x < y then 2 * x else 3 + x@
--
-- Each case branch binds one variable, so inside the branches @x@ sits at
-- @IS (IS IZ)@.
ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex3 =
  ECase ext (EOp ext OIf (bin (OLt STF32) (rvar (IS IZ)) (rvar IZ)))
    (bin (OMul STF32) (EConst ext STF32 2.0) (rvar (IS (IS IZ))))
    (bin (OAdd STF32) (EConst ext STF32 3.0) (rvar (IS (IS IZ))))
  where
    rvar :: Idx env (TScal TF32) -> Ex env (TScal TF32)
    rvar = EVar ext (STScal STF32)

-- | @x y |- if x < y then 2 * x + y * y else 3 + x@
--
-- Inside the branches @x@ is at @IS (IS IZ)@ and @y@ at @IS IZ@.
ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32)
ex4 =
  ECase ext (EOp ext OIf (bin (OLt STF32) (rvar (IS IZ)) (rvar IZ)))
    (bin (OAdd STF32)
       (bin (OMul STF32) (EConst ext STF32 2.0) (rvar (IS (IS IZ))))
       (bin (OMul STF32) (rvar (IS IZ)) (rvar (IS IZ))))
    (bin (OAdd STF32) (EConst ext STF32 3.0) (rvar (IS (IS IZ))))
  where
    rvar :: Idx env (TScal TF32) -> Ex env (TScal TF32)
    rvar = EVar ext (STScal STF32)

-- | Type witness for the context of 'ex5': a scalar @y@ (innermost) and a
-- sum-typed @x@.
senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)]
senv5 = SCons tR (SCons (STEither tR tR) SNil)
  where
    tR = STScal STF32

-- | Descriptor for the context of 'ex5': the sum-typed variable is pushed
-- first with storage @sa@, then the scalar with storage @sb@.
descr5 :: Storage a -> Storage b -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a]
descr5 sa sb = DPush (DPush DTop (STEither tR tR, sa)) (tR, sb)
  where
    tR = STScal STF32

-- | @x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)}@
--
-- In each branch the payload (@a@ or @b@) is at 'IZ' and @y@ at @IS IZ@.
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 =
  ECase ext (EVar ext (STEither tR tR) (IS IZ))
    (bin (OMul STF32) (rvar IZ) (rvar (IS IZ)))
    (bin (OMul STF32) (rvar IZ) (bin (OAdd STF32) (rvar (IS IZ)) (EConst ext STF32 1.0)))
  where
    tR = STScal STF32
    rvar :: Idx env (TScal TF32) -> Ex env (TScal TF32)
    rvar = EVar ext tR

-- | Type witness for the context of 'ex6': an 'TI64' @n@ (innermost) and a
-- 'TF32' @x@.
senv6 :: SList STy [TScal TI64, TScal TF32]
senv6 = SCons (STScal STI64) (SCons (STScal STF32) SNil)

-- | Descriptor for the context of 'ex6', with both variables using 'SMerge'.
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DPush (DPush DTop (STScal STF32, SMerge)) (STScal STI64, SMerge)
-- | @x:R n:I |- ...@, exercising 'EUnit' / 'EIdx0' and 'EBuild1' / 'EIdx1':
--
-- > let a = unit x
-- >     b = build1 n (\i. let c = idx0 a in c * c)
-- > in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 =
  -- a = unit x        (context: n at IZ, x at IS IZ)
  ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $
    -- b = build1 n (\i. ...)        (a is now at IZ, n at IS IZ)
    ELet ext
      (EBuild1 ext (EVar ext tIx (IS IZ)) $
         -- inside the build body, i is at IZ and a at IS IZ
         ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $
           bin (OMul STF32) (EVar ext (STScal STF32) IZ) (EVar ext (STScal STF32) IZ)) $
      -- idx0 (b ! 3), with b at IZ
      EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))

-- | Shorthand for the 'TF32' scalar type, used by the 'ex7' definitions.
type R = TScal TF32

-- | Type witness for the context of 'ex7': the scalar input @x@ (innermost)
-- and the nested parameter tuple @ps@.
senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
senv7 =
  let tR = STScal STF32
      tpair = STPair tR tR
  in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil

-- | Descriptor for the context of 'ex7': the parameter tuple is pushed first,
-- then the scalar input, both with 'SMerge'.
descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
descr7 =
  let tR = STScal STF32
      tpair = STPair tR tR
  in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge)
          `DPush` (tR, SMerge)

-- | A "neural network" except it's just scalars, not matrices: three chained
-- affine layers, where each layer's output is the next layer's input.
--
-- > ps:((((), (R,R)), (R,R)), (R,R))  x:R
-- > |- let p1  = snd ps
-- >        p1' = fst ps
-- >        x1  = fst p1 * x + snd p1
-- >        p2  = snd p1'
-- >        p2' = fst p1'
-- >        x2  = fst p2 * x1 + snd p2
-- >        p3  = snd p2'
-- >        p3' = fst p2'
-- >        x3  = fst p3 * x2 + snd p3
-- >    in x3
ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
ex7 =
  let tR = STScal STF32
      tpair = STPair tR tR

      -- @layer parst pars inp@ applies every layer in the parameter tuple at
      -- index @pars@ (whose type is @parst@) to the input at index @inp@.
      -- The 'snd' of the tuple holds this layer's (weight, bias) pair; the
      -- 'fst' holds the remaining layers, ending in 'TNil'.
      layer :: STy p -> Idx env p -> Idx env R -> Ex env R
      layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp =
        -- p = snd params
        ELet ext (ESnd ext (EVar ext parst pars)) $
        -- p' = fst params        (params index shifted past p)
        ELet ext (EFst ext (EVar ext parst (IS pars))) $
        -- x' = fst p * x + snd p  (p is at IS IZ; x is shifted past p and p')
        ELet ext
          (bin (OAdd STF32)
             (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ))) (EVar ext tR (IS (IS inp))))
             (ESnd ext (EVar ext tpair (IS IZ)))) $
        -- continue with the remaining layers p' (IS IZ) on input x' (IZ)
        layer t (IS IZ) IZ
      layer STNil _ inp = EVar ext tR inp
      -- any other parameter shape is not a valid stack of layers
      layer _ _ _ = error "Invalid layer inputs"
  in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ