diff options
Diffstat (limited to 'src/Example.hs')
-rw-r--r-- | src/Example.hs | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/src/Example.hs b/src/Example.hs index 86264e1..424351c 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} module Example where import AST @@ -114,6 +115,9 @@ senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) +-- x:R n:I |- let a = unit x +-- b = build1 n (\i. let c = idx0 a in c * c) +-- in idx0 (b ! 3) ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) ex6 = ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $ @@ -122,3 +126,47 @@ ex6 = bin (OMul STF32) (EVar ext (STScal STF32) IZ) (EVar ext (STScal STF32) IZ)) $ (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))) + +type R = TScal TF32 + +senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] +senv7 = + let tR = STScal STF32 + tpair = STPair tR tR + in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil + +descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"] +descr7 = + let tR = STScal STF32 + tpair = STPair tR tR + in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge) + +-- A "neural network" except it's just scalars, not matrices. +-- ps:((((), (R,R)), (R,R)), (R,R)) x:R +-- |- let p1 = snd ps +-- p1' = fst ps +-- x1 = fst p1 * x + snd p1 +-- p2 = snd p1' +-- p2' = fst p1' +-- x2 = fst p2 * x + snd p2 +-- p3 = snd p2' +-- p3' = fst p2' +-- x3 = fst p3 * x + snd p3 +-- in x3 +ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R +ex7 = + let tR = STScal STF32 + tpair = STPair tR tR + + layer :: STy p -> Idx env p -> Idx env R -> Ex env R + layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp = + ELet ext (ESnd ext (EVar ext parst pars)) $ + ELet ext (EFst ext (EVar ext parst (IS pars))) $ + ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ))) + (EVar ext tR (IS (IS inp)))) + (ESnd ext (EVar ext tpair (IS IZ)))) $ + layer t (IS IZ) IZ + layer STNil _ inp = EVar ext tR inp + layer _ _ _ = error "Invalid layer inputs" + + in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ |