aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs81
-rw-r--r--src/AST/Types.hs28
2 files changed, 64 insertions, 45 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 5666289..bec5a9d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -561,23 +561,7 @@ occCountX initialS topexpr k = case topexpr of
Just Refl -> expr
Nothing -> error "unreachable"
- ESum1Inner _ e
- | STArr (SS n) _ <- typeOf e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env mke ->
- k env $ \env' ->
- use (mke env') $ ENil ext
- SsArr SsNone ->
- occCountX (SsArr SsNone) e $ \env mke ->
- k env $ \env' ->
- elet (mke env') $
- EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
- SsArr SsFull ->
- occCountX (SsArr SsFull) e $ \env mke ->
- k env $ \env' ->
- ESum1Inner ext (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
+ ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
EUnit _ e ->
case s of
@@ -607,14 +591,8 @@ occCountX initialS topexpr k = case topexpr of
EReplicate1Inner ext (mka env') (mkb env')
SsFull -> occCountX (SsArr SsFull) topexpr k
- {-
- EMaximum1Inner _ e ->
- let (e', env) = re (SsArr (ssUnarr s)) e
- in (EMaximum1Inner s e', env)
- EMinimum1Inner _ e ->
- let (e', env) = re (SsArr (ssUnarr s)) e
- in (EMinimum1Inner s e', env)
- -}
+ EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
EConst _ t x ->
k OccEnd $ \_ ->
@@ -680,15 +658,26 @@ occCountX initialS topexpr k = case topexpr of
k env1 $ \env' ->
projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
- {-
- ECustom _ t1 t2 t3 e1 e2 e3 a b ->
- let (e1', _) = occCountX SsFull e1
- (e2', _) = occCountX SsFull e2
- (e3', _) = occCountX SsFull e3
- (a', env1) = re SsFull a -- let's be pessimistic here for safety
- (b', env2) = re SsFull b
- in (ECustom SsFull t1 t2 t3 e1' e2' e3' a' b', env1 <> env2)
- -}
+ ECustom _ t1 t2 t3 e1 e2 e3 a b
+ | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 ->
+ error "Accumulators not allowed in input/output/tape of an ECustom"
+ | otherwise ->
+ case s of
+ SsNone ->
+ -- Allowed to ignore e1/e2/e3 here because no accumulators are
+ -- communicated, and hence no relevant effects exist
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ s' -> -- Let's be pessimistic for safety
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s' $
+ ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env')
ERecompute _ e ->
occCountX s e $ \env1 mke ->
@@ -781,11 +770,31 @@ occCountX initialS topexpr k = case topexpr of
EError _ t msg ->
k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
-
- _ -> error "occCountX: TODO unimplemented"
where
s = simplifySubstruc (typeOf topexpr) initialS
+ handleReduction :: t ~ TArr n (TScal t2)
+ => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
+ -> Expr x env (TArr (S n) (TScal t2))
+ -> r
+ handleReduction reduce e
+ | STArr (SS n) _ <- typeOf e =
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr SsNone ->
+ occCountX (SsArr SsNone) e $ \env mke ->
+ k env $ \env' ->
+ elet (mke env') $
+ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
+ SsArr SsFull ->
+ occCountX (SsArr SsFull) e $ \env mke ->
+ k env $ \env' ->
+ reduce (mke env')
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil (Some OccEnd) k = k SETop
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index 42bfb92..4ddcb50 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -171,15 +171,25 @@ type family ScalIsIntegral t where
ScalIsIntegral TBool = False
-- | Returns true for arrays /and/ accumulators.
-hasArrays :: STy t' -> Bool
-hasArrays STNil = False
-hasArrays (STPair a b) = hasArrays a || hasArrays b
-hasArrays (STEither a b) = hasArrays a || hasArrays b
-hasArrays (STLEither a b) = hasArrays a || hasArrays b
-hasArrays (STMaybe t) = hasArrays t
-hasArrays STArr{} = True
-hasArrays STScal{} = False
-hasArrays STAccum{} = True
+typeHasArrays :: STy t' -> Bool
+typeHasArrays STNil = False
+typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STMaybe t) = typeHasArrays t
+typeHasArrays STArr{} = True
+typeHasArrays STScal{} = False
+typeHasArrays STAccum{} = True
+
+typeHasAccums :: STy t' -> Bool
+typeHasAccums STNil = False
+typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STMaybe t) = typeHasAccums t
+typeHasAccums STArr{} = False
+typeHasAccums STScal{} = False
+typeHasAccums STAccum{} = True
type family Tup env where
Tup '[] = TNil