aboutsummaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs81
1 files changed, 45 insertions, 36 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 5666289..bec5a9d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -561,23 +561,7 @@ occCountX initialS topexpr k = case topexpr of
Just Refl -> expr
Nothing -> error "unreachable"
- ESum1Inner _ e
- | STArr (SS n) _ <- typeOf e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env mke ->
- k env $ \env' ->
- use (mke env') $ ENil ext
- SsArr SsNone ->
- occCountX (SsArr SsNone) e $ \env mke ->
- k env $ \env' ->
- elet (mke env') $
- EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
- SsArr SsFull ->
- occCountX (SsArr SsFull) e $ \env mke ->
- k env $ \env' ->
- ESum1Inner ext (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
+ ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
EUnit _ e ->
case s of
@@ -607,14 +591,8 @@ occCountX initialS topexpr k = case topexpr of
EReplicate1Inner ext (mka env') (mkb env')
SsFull -> occCountX (SsArr SsFull) topexpr k
- {-
- EMaximum1Inner _ e ->
- let (e', env) = re (SsArr (ssUnarr s)) e
- in (EMaximum1Inner s e', env)
- EMinimum1Inner _ e ->
- let (e', env) = re (SsArr (ssUnarr s)) e
- in (EMinimum1Inner s e', env)
- -}
+ EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
EConst _ t x ->
k OccEnd $ \_ ->
@@ -680,15 +658,26 @@ occCountX initialS topexpr k = case topexpr of
k env1 $ \env' ->
projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
- {-
- ECustom _ t1 t2 t3 e1 e2 e3 a b ->
- let (e1', _) = occCountX SsFull e1
- (e2', _) = occCountX SsFull e2
- (e3', _) = occCountX SsFull e3
- (a', env1) = re SsFull a -- let's be pessimistic here for safety
- (b', env2) = re SsFull b
- in (ECustom SsFull t1 t2 t3 e1' e2' e3' a' b', env1 <> env2)
- -}
+ ECustom _ t1 t2 t3 e1 e2 e3 a b
+ | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 ->
+ error "Accumulators not allowed in input/output/tape of an ECustom"
+ | otherwise ->
+ case s of
+ SsNone ->
+ -- Allowed to ignore e1/e2/e3 here because no accumulators are
+ -- communicated, and hence no relevant effects exist
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ s' -> -- Let's be pessimistic for safety
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s' $
+ ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env')
ERecompute _ e ->
occCountX s e $ \env1 mke ->
@@ -781,11 +770,31 @@ occCountX initialS topexpr k = case topexpr of
EError _ t msg ->
k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
-
- _ -> error "occCountX: TODO unimplemented"
where
s = simplifySubstruc (typeOf topexpr) initialS
+ handleReduction :: t ~ TArr n (TScal t2)
+ => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
+ -> Expr x env (TArr (S n) (TScal t2))
+ -> r
+ handleReduction reduce e
+ | STArr (SS n) _ <- typeOf e =
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr SsNone ->
+ occCountX (SsArr SsNone) e $ \env mke ->
+ k env $ \env' ->
+ elet (mke env') $
+ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
+ SsArr SsFull ->
+ occCountX (SsArr SsFull) e $ \env mke ->
+ k env $ \env' ->
+ reduce (mke env')
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil (Some OccEnd) k = k SETop