diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/AST/Count.hs | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/AST/Count.hs')
| -rw-r--r-- | src/AST/Count.hs | 930 |
1 files changed, 0 insertions, 930 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs deleted file mode 100644 index ac8634e..0000000 --- a/src/AST/Count.hs +++ /dev/null @@ -1,930 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE PatternSynonyms #-} -module AST.Count where - -import Data.Functor.Product -import Data.Some -import Data.Type.Equality -import GHC.Generics (Generic, Generically(..)) - -import Array -import AST -import AST.Env -import Data - - --- | The monoid operation combines assuming that /both/ branches are taken. -class Monoid a => Occurrence a where - -- | One of the two branches is taken - (<||>) :: a -> a -> a - -- | This code is executed many times - scaleMany :: a -> a - - -data Count = Zero | One | Many - deriving (Show, Eq, Ord) - -instance Semigroup Count where - Zero <> n = n - n <> Zero = n - _ <> _ = Many -instance Monoid Count where - mempty = Zero -instance Occurrence Count where - (<||>) = max - scaleMany Zero = Zero - scaleMany _ = Many - -data Occ = Occ { _occLexical :: Count - , _occRuntime :: Count } - deriving (Eq, Generic) - deriving (Semigroup, Monoid) via Generically Occ - -instance Show Occ where - showsPrec d (Occ l r) = showParen (d > 10) $ - showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r - -instance Occurrence Occ where - Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) - scaleMany (Occ l c) = Occ l (scaleMany c) - - -data Substruc t t' where - -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below - SsFull :: Substruc t t - SsNone :: Substruc t TNil - SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') - SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') - SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') - SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') - SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements - SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') - -pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' -pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) - where SsPair' = SsPair -{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} - -pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' -pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) - where SsArr' = SsArr -{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} - -instance Semigroup (Some (Substruc t)) where - Some SsFull <> _ = Some SsFull - _ <> Some SsFull = Some SsFull - Some SsNone <> s = s - s <> Some SsNone = s - Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) - Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) - Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) - Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) - Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) - Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) -instance Monoid (Some (Substruc t)) where - mempty = Some SsNone - -instance TestEquality (Substruc t) where - testEquality SsFull s = isFull s - testEquality s SsFull = sym <$> isFull s - testEquality SsNone SsNone = Just Refl - testEquality SsNone _ = Nothing - testEquality _ SsNone = Nothing - testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - -isFull :: Substruc t t' -> Maybe (t :~: t') -isFull SsFull = Just Refl -isFull SsNone = Nothing -- TODO: nil? -isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing -isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing -isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing - -applySubstruc :: Substruc t t' -> STy t -> STy t' -applySubstruc SsFull t = t -applySubstruc SsNone _ = STNil -applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) -applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) -applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) - -applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' -applySubstrucM SsFull t = t -applySubstrucM SsNone _ = SMTNil -applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) -applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) -applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) -applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) -applySubstrucM _ t = case t of {} - -data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) - | a ~ b => ExMapId - -fromExMap :: ExMap a b -> Ex env a -> Ex env b -fromExMap (ExMap f) = f -fromExMap ExMapId = id - -simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' -simplifySubstruc STNil SsNone = SsFull - -simplifySubstruc _ SsFull = SsFull -simplifySubstruc _ SsNone = SsNone -simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) -simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) -simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) - --- simplifySubstruc' :: Substruc t t' --- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r --- simplifySubstruc' SsFull k = k SsFull ExMapId --- simplifySubstruc' SsNone k = k SsNone ExMapId --- simplifySubstruc' (SsPair s1 s2) k = --- simplifySubstruc' s1 $ \s1' f1 -> --- simplifySubstruc' s2 $ \s2' f2 -> --- case (s1', s2') of --- (SsFull, SsFull) -> --- k SsFull (case (f1, f2) of --- (ExMapId, ExMapId) -> ExMapId --- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> --- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) --- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) --- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) --- simplifySubstruc' _ _ = _ - --- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) --- ssUnpair SsFull = (SsFull, SsFull) --- ssUnpair SsNone = (SsNone, SsNone) --- ssUnpair (SsPair a b) = (a, b) - --- ssUnleft :: Substruc (TEither a b) -> Substruc a --- ssUnleft SsFull = SsFull --- ssUnleft SsNone = SsNone --- ssUnleft (SsEither a _) = a - --- ssUnright :: Substruc (TEither a b) -> Substruc b --- ssUnright SsFull = SsFull --- ssUnright SsNone = SsNone --- ssUnright (SsEither _ b) = b - --- ssUnlleft :: Substruc (TLEither a b) -> Substruc a --- ssUnlleft SsFull = SsFull --- ssUnlleft SsNone = SsNone --- ssUnlleft (SsLEither a _) = a - --- ssUnlright :: Substruc (TLEither a b) -> Substruc b --- ssUnlright SsFull = SsFull --- ssUnlright SsNone = SsNone --- ssUnlright (SsLEither _ b) = b - --- ssUnjust :: Substruc (TMaybe a) -> Substruc a --- ssUnjust SsFull = SsFull --- ssUnjust SsNone = SsNone --- ssUnjust (SsMaybe a) = a - --- ssUnarr :: Substruc (TArr n a) -> Substruc a --- ssUnarr SsFull = SsFull --- ssUnarr SsNone = SsNone --- ssUnarr (SsArr a) = a - --- ssUnaccum :: Substruc (TAccum a) -> Substruc a --- ssUnaccum SsFull = SsFull --- ssUnaccum SsNone = SsNone --- ssUnaccum (SsAccum a) = a - - -type family MapEmpty env where - MapEmpty '[] = '[] - MapEmpty (t : env) = TNil : MapEmpty env - -data OccEnv a env env' where - OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! - OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') - -instance Semigroup a => Semigroup (Some (OccEnv a env)) where - Some OccEnd <> e = e - e <> Some OccEnd = e - Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) - -instance Semigroup a => Monoid (Some (OccEnv a env)) where - mempty = Some OccEnd - -instance Occurrence a => Occurrence (Some (OccEnv a env)) where - Some OccEnd <||> e = e - e <||> Some OccEnd = e - Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) - - scaleMany (Some OccEnd) = Some OccEnd - scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) - -onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) -onehotOccEnv IZ v s = Some (OccPush OccEnd v s) -onehotOccEnv (IS i) v s - | Some env' <- onehotOccEnv i v s - = Some (OccPush env' mempty SsNone) - -occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') -occEnvPop (OccPush e _ s) = (e, s) -occEnvPop OccEnd = (OccEnd, SsNone) - -occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r -occEnvPop' (OccPush e _ s) k = k e s -occEnvPop' OccEnd k = k OccEnd SsNone - -occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) -occEnvPopSome (Some (OccPush e _ _)) = Some e -occEnvPopSome (Some OccEnd) = Some OccEnd - -occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) -occEnvPrj OccEnd _ = mempty -occEnvPrj (OccPush _ o s) IZ = (o, Some s) -occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i - -occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) -occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) -occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) -occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) -occEnvPrjS (OccPush e _ _) (IS i) - | Some (Pair s' i') <- occEnvPrjS e i - = Some (Pair s' (IS i')) - -projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small -projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of - _ | Just Refl <- testEquality topsbig topssmall -> ex - - (SsFull, SsFull) -> ex - (SsNone, SsNone) -> ex - (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" - (_, SsNone) -> - case typeOf ex of - STNil -> ex - _ -> use ex $ ENil ext - - (SsPair s1 s2, SsPair s1' s2') -> - eunPair ex $ \_ e1 e2 -> - EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) - (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex - (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex - - (SsEither s1 s2, SsEither s1' s2') - | STEither t1 t2 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) - in ecase ex - (EInl ext (typeOf e2) e1) - (EInr ext (typeOf e1) e2) - (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex - (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex - - (SsLEither s1 s2, SsLEither s1' s2') - | STLEither t1 t2 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) - in elcase ex - (ELNil ext (typeOf e1) (typeOf e2)) - (ELInl ext (typeOf e2) e1) - (ELInr ext (typeOf e1) e2) - (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex - (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex - - (SsMaybe s1, SsMaybe s1') - | STMaybe t1 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - in emaybe ex - (ENothing ext (typeOf e1)) - (EJust ext e1) - (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex - (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex - - (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex - (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex - (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex - - (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" - (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex - (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex - - --- | A boolean for each entry in the environment, with the ability to uniformly --- mask the top part above a certain index. -data EnvMask env where - EMRest :: Bool -> EnvMask env - EMPush :: EnvMask env -> Bool -> EnvMask (t : env) - -envMaskPrj :: EnvMask env -> Idx env t -> Bool -envMaskPrj (EMRest b) _ = b -envMaskPrj (_ `EMPush` b) IZ = b -envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx ex - | Some env <- occCountAll ex - = fst (occEnvPrj env idx) - -occCountAll :: Expr x env t -> Some (OccEnv Occ env) -occCountAll ex = occCountX SsFull ex $ \env _ -> Some env - -pruneExpr :: SList f env -> Expr x env t -> Ex env t -pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) - where - fullOccEnv :: SList f env -> OccEnv () env env - fullOccEnv SNil = OccEnd - fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull - --- In one traversal, count occurrences of variables and determine what parts of --- expressions are actually used. These two results are computed independently: --- even if (almost) nothing of a particular term is actually used, variable --- references in that term still count as usual. --- --- In @occCountX s t k@: --- * s: how much of the result of this term is required --- * t: the term to analyse --- * k: is passed the actual environment usage of this expression, including --- occurrence counts. The callback reconstructs a new expression in an --- updated "response" environment. The response must be at least as large as --- the computed usages. -occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t - -> (forall env'. OccEnv Occ env env' - -- response OccEnv must be at least as large as the OccEnv returned above - -> (forall env''. OccEnv () env env'' -> Ex env'' t') - -> r) - -> r -occCountX initialS topexpr k = case topexpr of - EVar _ t i -> - withSome (onehotOccEnv i (Occ One One) s) $ \env -> - k env $ \env' -> - withSome (occEnvPrjS env' i) $ \(Pair s' i') -> - projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') - ELet _ rhs body -> - occCountX s body $ \envB mkbody -> - occEnvPop' envB $ \envB' s1 -> - occCountX s1 rhs $ \envR mkrhs -> - withSome (Some envB' <> Some envR) $ \env -> - k env $ \env' -> - ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) - EPair _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsPair' s1 s2 -> - occCountX s1 a $ \env1 mka -> - occCountX s2 b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EPair ext (mka env') (mkb env') - EFst _ e -> - occCountX (SsPair s SsNone) e $ \env1 mke -> - k env1 $ \env' -> - EFst ext (mke env') - ESnd _ e -> - occCountX (SsPair SsNone s) e $ \env1 mke -> - k env1 $ \env' -> - ESnd ext (mke env') - ENil _ -> - case s of - SsFull -> k OccEnd (\_ -> ENil ext) - SsNone -> k OccEnd (\_ -> ENil ext) - EInl _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsEither s1 s2 -> - occCountX s1 e $ \env1 mke -> - k env1 $ \env' -> - EInl ext (applySubstruc s2 t) (mke env') - SsFull -> occCountX (SsEither SsFull SsFull) topexpr k - EInr _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsEither s1 s2 -> - occCountX s2 e $ \env1 mke -> - k env1 $ \env' -> - EInr ext (applySubstruc s1 t) (mke env') - SsFull -> occCountX (SsEither SsFull SsFull) topexpr k - ECase _ e a b -> - occCountX s a $ \env1' mka -> - occCountX s b $ \env2' mkb -> - occEnvPop' env1' $ \env1 s1 -> - occEnvPop' env2' $ \env2 s2 -> - occCountX (SsEither s1 s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> - k env $ \env' -> - ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) - ENothing _ t -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) - SsFull -> occCountX (SsMaybe SsFull) topexpr k - EJust _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsMaybe s' -> - occCountX s' e $ \env1 mke -> - k env1 $ \env' -> - EJust ext (mke env') - SsFull -> occCountX (SsMaybe SsFull) topexpr k - EMaybe _ a b e -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s2 -> - occCountX (SsMaybe s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> - k env $ \env' -> - EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') - ELNil _ t1 t2 -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELInl _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsLEither s1 s2 -> - occCountX s1 e $ \env1 mke -> - k env1 $ \env' -> - ELInl ext (applySubstruc s2 t) (mke env') - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELInr _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsLEither s1 s2 -> - occCountX s2 e $ \env1 mke -> - k env1 $ \env' -> - ELInr ext (applySubstruc s1 t) (mke env') - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELCase _ e a b c -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2' mkb -> - occCountX s c $ \env3' mkc -> - occEnvPop' env2' $ \env2 s1 -> - occEnvPop' env3' $ \env3 s2 -> - occCountX (SsLEither s1 s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> - k env $ \env' -> - ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) - - EConstArr _ n t x -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) - SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) - - EBuild _ n a b -> - case s of - SsNone -> - occCountX SsFull a $ \env1 mka -> - occCountX SsNone b $ \env2'' mkb -> - occEnvPop' env2'' $ \env2' s2 -> - withSome (Some env1 <> scaleMany (Some env2')) $ \env -> - k env $ \env' -> - use (EBuild ext n (mka env') $ - use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ - ENil ext) $ - ENil ext - SsArr' s' -> - occCountX SsFull a $ \env1 mka -> - occCountX s' b $ \env2'' mkb -> - occEnvPop' env2'' $ \env2' s2 -> - withSome (Some env1 <> scaleMany (Some env2')) $ \env -> - k env $ \env' -> - EBuild ext n (mka env') $ - elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) - - EMap _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1 -> - occCountX (SsArr s1) b $ \env2 mkb -> - withSome (scaleMany (Some env1') <> Some env2) $ \env -> - k env $ \env' -> - use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ - ENil ext - SsArr' s' -> - occCountX s' a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1 -> - occCountX (SsArr s1) b $ \env2 mkb -> - withSome (scaleMany (Some env1') <> Some env2) $ \env -> - k env $ \env' -> - EMap ext (mka (OccPush env' () s1)) (mkb env') - - EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1' -> - let s1 = case s1' of - SsNone -> Some SsNone - SsPair' s1'a s1'b -> Some s1'a <> Some s1'b - s0 = case s of - SsNone -> Some SsNone - SsArr' s' -> Some s' in - withSome (s1 <> s0) $ \sElt -> - occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsArr sElt) s $ - EFold1Inner ext commut - (projectSmallerSubstruc SsFull sElt $ - mka (OccPush env' () (SsPair sElt sElt))) - (mkb env') (mkc env') - - ESum1Inner _ e -> handleReduction (ESum1Inner ext) e - - EUnit _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr' s' -> - occCountX s' e $ \env mke -> - k env $ \env' -> - EUnit ext (mke env') - - EReplicate1Inner _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' s' -> - occCountX SsFull a $ \env1 mka -> - occCountX (SsArr s') b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EReplicate1Inner ext (mka env') (mkb env') - - EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e - EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e - - EReshape _ n esh e -> - case s of - SsNone -> - occCountX SsNone esh $ \env1 mkesh -> - occCountX SsNone e $ \env2 mke -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkesh env') $ use (mke env') $ ENil ext - SsArr' s' -> - occCountX SsFull esh $ \env1 mkesh -> - occCountX (SsArr s') e $ \env2 mke -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EReshape ext n (mkesh env') (mke env') - - EZip _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' SsNone -> - occCountX (SsArr SsNone) a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkb env') $ mka env' - SsArr' (SsPair' SsNone s2) -> - occCountX SsNone a $ \env1 mka -> - occCountX (SsArr s2) b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ - emap (EPair ext (ENil ext) (evar IZ)) (mkb env') - SsArr' (SsPair' s1 SsNone) -> - occCountX (SsArr s1) a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkb env') $ - emap (EPair ext (evar IZ) (ENil ext)) (mka env') - SsArr' (SsPair' s1 s2) -> - occCountX (SsArr s1) a $ \env1 mka -> - occCountX (SsArr s2) b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EZip ext (mka env') (mkb env') - - EFold1InnerD1 _ cm e1 e2 e3 -> - case s of - -- If nothing is necessary, we can execute a fold and then proceed to ignore it - SsNone -> - let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) - (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) - in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex - -- If we don't need the stores, still a fold suffices - SsPair' sP SsNone -> - let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) - (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) - in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) - -- If for whatever reason the additional stores themselves are - -- unnecessary but the shape of the array is, then oblige - SsPair' sP (SsArr' SsNone) -> - let STArr sn _ = typeOf e3 - foldex = - elet (mapExt (\_ -> ext) e3) $ - EPair ext - (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) - (mapExt (\_ -> ext) (weakenExpr WSink e2)) - (evar IZ)) - in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> - k env1 $ \env' -> - eunPair (mkfoldex env') $ \_ eshape earr -> - EPair ext earr (EBuild ext sn eshape (ENil ext)) - -- If at least some of the additional stores are required, we need to keep this a mapAccum - SsPair' _ (SsArr' sB) -> - -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> - occEnvPop' env1_1' $ \env1' _ -> - occCountX SsFull e2 $ \env2 mkb -> - occCountX SsFull e3 $ \env3 mkc -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) - (mkb env') (mkc env') - - EFold1InnerD2 _ cm ef ebog ed -> - -- TODO: propagate usage of duals - occCountX SsFull ef $ \env1_2' mkef -> - occEnvPop' env1_2' $ \env1_1' _ -> - occEnvPop' env1_1' $ \env1' sB -> - occCountX (SsArr sB) ebog $ \env2 mkebog -> - occCountX SsFull ed $ \env3 mked -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s $ - EFold1InnerD2 ext cm - (mkef (OccPush (OccPush env' () sB) () SsFull)) - (mkebog env') (mked env') - - EConst _ t x -> - k OccEnd $ \_ -> - case s of - SsNone -> ENil ext - SsFull -> EConst ext t x - - EIdx0 _ e -> - occCountX (SsArr s) e $ \env1 mke -> - k env1 $ \env' -> - EIdx0 ext (mke env') - - EIdx1 _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' s' -> - occCountX (SsArr s') a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EIdx1 ext (mka env') (mkb env') - - EIdx _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - _ -> - occCountX (SsArr s) a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EIdx ext (mka env') (mkb env') - - EShape _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - _ -> - occCountX (SsArr SsNone) e $ \env1 mke -> - k env1 $ \env' -> - projectSmallerSubstruc SsFull s $ EShape ext (mke env') - - EOp _ op e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - _ -> - occCountX SsFull e $ \env1 mke -> - k env1 $ \env' -> - projectSmallerSubstruc SsFull s $ EOp ext op (mke env') - - ECustom _ t1 t2 t3 e1 e2 e3 a b - | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> - error "Accumulators not allowed in input/output/tape of an ECustom" - | otherwise -> - case s of - SsNone -> - -- Allowed to ignore e1/e2/e3 here because no accumulators are - -- communicated, and hence no relevant effects exist - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - s' -> -- Let's be pessimistic for safety - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s' $ - ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') - - ERecompute _ e -> - occCountX s e $ \env1 mke -> - k env1 $ \env' -> - ERecompute ext (mke env') - - EWith _ t a b -> - case s of - SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with - occCountX SsNone b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s1 -> - withSome (case s1 of - SsFull -> Some SsFull - SsAccum s' -> Some s' - SsNone -> Some SsNone) $ \s1' -> - occCountX s1' a $ \env1 mka -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ - ENil ext - SsPair sB sA -> - occCountX sB b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s1 -> - let s1' = case s1 of - SsFull -> Some SsFull - SsAccum s' -> Some s' - SsNone -> Some SsNone in - withSome (Some sA <> s1') $ \sA' -> - occCountX sA' a $ \env1 mka -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ - EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) - SsFull -> occCountX (SsPair SsFull SsFull) topexpr k - - EAccum _ t p a sp b e -> - -- TODO: do better! - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - occCountX SsFull e $ \env3 mke -> - withSome (Some env1 <> Some env2) $ \env12 -> - withSome (Some env12 <> Some env3) $ \env -> - k env $ \env' -> - case s of {SsFull -> id; SsNone -> id} $ - EAccum ext t p (mka env') sp (mkb env') (mke env') - - EZero _ t e -> - occCountX (subZeroInfo s) e $ \env1 mke -> - k env1 $ \env' -> - EZero ext (applySubstrucM s t) (mke env') - where - subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) - subZeroInfo SsFull = SsFull - subZeroInfo SsNone = SsNone - subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) - subZeroInfo SsEither{} = error "Either is not a monoid" - subZeroInfo SsLEither{} = SsNone - subZeroInfo SsMaybe{} = SsNone - subZeroInfo (SsArr s') = SsArr (subZeroInfo s') - subZeroInfo SsAccum{} = error "Accum is not a monoid" - - EDeepZero _ t e -> - occCountX (subDeepZeroInfo s) e $ \env1 mke -> - k env1 $ \env' -> - EDeepZero ext (applySubstrucM s t) (mke env') - where - subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) - subDeepZeroInfo SsFull = SsFull - subDeepZeroInfo SsNone = SsNone - subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) - subDeepZeroInfo SsEither{} = error "Either is not a monoid" - subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) - subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') - subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') - subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" - - EPlus _ t a b -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EPlus ext (applySubstrucM s t) (mka env') (mkb env') - - EOneHot _ t p a b -> - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> -- TODO: do better - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') - - EError _ t msg -> - k OccEnd $ \_ -> EError ext (applySubstruc s t) msg - where - s = simplifySubstruc (typeOf topexpr) initialS - - handleReduction :: t ~ TArr n (TScal t2) - => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) - -> Expr x env (TArr (S n) (TScal t2)) - -> r - handleReduction reduce e - | STArr (SS n) _ <- typeOf e = - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr' SsNone -> - occCountX (SsArr SsNone) e $ \env mke -> - k env $ \env' -> - elet (mke env') $ - EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr' SsFull -> - occCountX (SsArr SsFull) e $ \env mke -> - k env $ \env' -> - reduce (mke env') - - -deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil (Some OccEnd) k = k SETop -deleteUnused (_ `SCons` env) (Some OccEnd) k = - deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = - deleteUnused env (Some occenv) $ \sub -> - case count of Zero -> k (SENo sub) - _ -> k (SEYesR sub) - -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t -unsafeWeakenWithSubenv = \sub -> - subst (\x t i -> case sinkViaSubenv i sub of - Just i' -> EVar x t i' - Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") - where - sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYesR _) = Just IZ - sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub - sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub |
