summaryrefslogtreecommitdiff
path: root/src/Example.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Example.hs')
-rw-r--r--src/Example.hs48
1 files changed, 48 insertions, 0 deletions
diff --git a/src/Example.hs b/src/Example.hs
index 86264e1..424351c 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
module Example where
import AST
@@ -114,6 +115,9 @@ senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge)
+-- x:R n:I |- let a = unit x
+-- b = build1 n (\i. let c = idx0 a in c * c)
+-- in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 =
ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $
@@ -122,3 +126,47 @@ ex6 =
bin (OMul STF32) (EVar ext (STScal STF32) IZ)
(EVar ext (STScal STF32) IZ)) $
(EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3)))
+
+type R = TScal TF32
+
+senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
+senv7 =
+ let tR = STScal STF32
+ tpair = STPair tR tR
+ in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil
+
+descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
+descr7 =
+ let tR = STScal STF32
+ tpair = STPair tR tR
+ in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge)
+
+-- A "neural network" except it's just scalars, not matrices.
+-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
+-- |- let p1 = snd ps
+-- p1' = fst ps
+-- x1 = fst p1 * x + snd p1
+-- p2 = snd p1'
+-- p2' = fst p1'
+-- x2 = fst p2 * x + snd p2
+-- p3 = snd p2'
+-- p3' = fst p2'
+-- x3 = fst p3 * x + snd p3
+-- in x3
+ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
+ex7 =
+ let tR = STScal STF32
+ tpair = STPair tR tR
+
+ layer :: STy p -> Idx env p -> Idx env R -> Ex env R
+ layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp =
+ ELet ext (ESnd ext (EVar ext parst pars)) $
+ ELet ext (EFst ext (EVar ext parst (IS pars))) $
+ ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ)))
+ (EVar ext tR (IS (IS inp))))
+ (ESnd ext (EVar ext tpair (IS IZ)))) $
+ layer t (IS IZ) IZ
+ layer STNil _ inp = EVar ext tR inp
+ layer _ _ _ = error "Invalid layer inputs"
+
+ in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ