diff options
Diffstat (limited to 'src/AST/Count.hs')
| -rw-r--r-- | src/AST/Count.hs | 48 |
1 files changed, 28 insertions, 20 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index bc02417..ac8634e 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -321,13 +321,7 @@ projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex - (SsArr s1, SsArr s2) - | STArr n t <- typeOf ex -> - elet ex $ - EBuild ext n (EShape ext (evar IZ)) $ - projectSmallerSubstruc s1 s2 - (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) + (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex @@ -560,22 +554,23 @@ occCountX initialS topexpr k = case topexpr of EMap ext (mka (OccPush env' () s1)) (mkb env') EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1''' mka -> - withSome (scaleMany (Some env1''')) $ \env1'' -> - occEnvPop' env1'' $ \env1' s2 -> - occEnvPop' env1' $ \env1 s1 -> - let s0 = case s of + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of SsNone -> Some SsNone SsArr' s' -> Some s' in - withSome (Some s1 <> Some s2 <> s0) $ \sElt -> + withSome (s1 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsArr sElt) s $ EFold1Inner ext commut (projectSmallerSubstruc SsFull sElt $ - mka (OccPush (OccPush env' () sElt) () sElt)) + mka (OccPush env' () (SsPair sElt sElt))) (mkb env') (mkc env') ESum1Inner _ e -> handleReduction (ESum1Inner ext) e @@ -638,6 +633,20 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mkb env') $ mka env' + SsArr' (SsPair' SsNone s2) -> + occCountX SsNone a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ + emap (EPair ext (ENil ext) (evar IZ)) (mkb env') + SsArr' (SsPair' s1 SsNone) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ + emap (EPair ext (evar IZ) (ENil ext)) (mka env') SsArr' (SsPair' s1 s2) -> occCountX (SsArr s1) a $ \env1 mka -> occCountX (SsArr s2) b $ \env2 mkb -> @@ -665,7 +674,7 @@ occCountX initialS topexpr k = case topexpr of elet (mapExt (\_ -> ext) e3) $ EPair ext (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) (mapExt (\_ -> ext) (weakenExpr WSink e2)) (evar IZ)) in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> @@ -675,15 +684,14 @@ occCountX initialS topexpr k = case topexpr of -- If at least some of the additional stores are required, we need to keep this a mapAccum SsPair' _ (SsArr' sB) -> -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> - occEnvPop' env1_2' $ \env1_1' _ -> + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> occEnvPop' env1_1' $ \env1' _ -> occCountX SsFull e2 $ \env2 mkb -> occCountX SsFull e3 $ \env3 mkc -> withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) (mkb env') (mkc env') EFold1InnerD2 _ cm ef ebog ed -> |
