summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Bindings.hs11
-rw-r--r--src/AST/Count.hs1
-rw-r--r--src/AST/Pretty.hs4
-rw-r--r--src/AST/SplitLets.hs1
-rw-r--r--src/AST/UnMonoid.hs1
5 files changed, 18 insertions, 0 deletions
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
index 3d99afe..745a93b 100644
--- a/src/AST/Bindings.hs
+++ b/src/AST/Bindings.hs
@@ -16,6 +16,7 @@
module AST.Bindings where
import AST
+import AST.Env
import Data
import Lemmas
@@ -62,3 +63,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+
+collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env'
+collectBindings = \env -> fst . go env WId
+ where
+ go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
+ go _ _ SETop = (BTop, WId)
+ go (ty `SCons` env) w (SEYes sub) =
+ let (bs, w') = go env (WPop w) sub
+ in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
+ go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index feaaa1e..0c682c6 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -132,6 +132,7 @@ occCountGeneral onehot unpush alter many = go WId
EShape _ e -> re e
EOp _ _ e -> re e
ECustom _ _ _ _ _ _ _ a b -> re a <> re b
+ ERecompute _ e -> re e
EWith _ _ a b -> re a <> re1 b
EAccum _ _ _ a b e -> re a <> re b <> re e
EZero _ _ e -> re e
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 6d70ca3..41da656 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -288,6 +288,10 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
+ ERecompute _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
+
EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 1379e35..3c353d4 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -61,6 +61,7 @@ splitLets' = \sub -> \case
EShape x e -> EShape x (splitLets' sub e)
EOp x op e -> EOp x op (splitLets' sub e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
+ ERecompute x e -> ERecompute x (splitLets' sub e)
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
EZero x t ezi -> EZero x t (splitLets' sub ezi)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 3d5f544..ac4d733 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -47,6 +47,7 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
EError _ t s -> EError ext t s