diff options
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r-- | src/AST/Count.hs | 81 |
1 files changed, 45 insertions, 36 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 5666289..bec5a9d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -561,23 +561,7 @@ occCountX initialS topexpr k = case topexpr of Just Refl -> expr Nothing -> error "unreachable" - ESum1Inner _ e - | STArr (SS n) _ <- typeOf e -> - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr SsNone -> - occCountX (SsArr SsNone) e $ \env mke -> - k env $ \env' -> - elet (mke env') $ - EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr SsFull -> - occCountX (SsArr SsFull) e $ \env mke -> - k env $ \env' -> - ESum1Inner ext (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e EUnit _ e -> case s of @@ -607,14 +591,8 @@ occCountX initialS topexpr k = case topexpr of EReplicate1Inner ext (mka env') (mkb env') SsFull -> occCountX (SsArr SsFull) topexpr k - {- - EMaximum1Inner _ e -> - let (e', env) = re (SsArr (ssUnarr s)) e - in (EMaximum1Inner s e', env) - EMinimum1Inner _ e -> - let (e', env) = re (SsArr (ssUnarr s)) e - in (EMinimum1Inner s e', env) - -} + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e EConst _ t x -> k OccEnd $ \_ -> @@ -680,15 +658,26 @@ occCountX initialS topexpr k = case topexpr of k env1 $ \env' -> projectSmallerSubstruc SsFull s $ EOp ext op (mke env') - {- - ECustom _ t1 t2 t3 e1 e2 e3 a b -> - let (e1', _) = occCountX SsFull e1 - (e2', _) = occCountX SsFull e2 - (e3', _) = occCountX SsFull e3 - (a', env1) = re SsFull a -- let's be pessimistic here for safety - (b', env2) = re SsFull b - in (ECustom SsFull t1 t2 t3 e1' e2' e3' a' b', env1 <> env2) - -} + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') ERecompute _ e -> occCountX s e $ \env1 mke -> @@ -781,11 +770,31 @@ occCountX initialS topexpr k = case topexpr of EError _ t msg -> k OccEnd $ \_ -> EError ext (applySubstruc s t) msg - - _ -> error "occCountX: TODO unimplemented" where s = simplifySubstruc (typeOf topexpr) initialS + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') + SsFull -> occCountX (SsArr SsFull) topexpr k + deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r deleteUnused SNil (Some OccEnd) k = k SETop |