aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-07-23 21:09:59 +0200
committerTom Smeding <tom@tomsmeding.com>2025-07-23 21:09:59 +0200
commitf6b5850f949eb671f0c7038db6dff80ca23ed946 (patch)
tree9cbf14f3fe9512044b744549a4bad85a2a6a3cc0
parent888bf4ed19afb2970a5f449fc285d3ef217baed8 (diff)
WIP pruneExpr in AST.Countfancy-count
-rw-r--r--src/AST.hs29
-rw-r--r--src/AST/Count.hs834
-rw-r--r--src/AST/Pretty.hs1
-rw-r--r--src/AST/Sparse.hs3
-rw-r--r--src/AST/Weaken.hs6
-rw-r--r--src/CHAD.hs2
-rw-r--r--src/Simplify.hs11
7 files changed, 772 insertions, 114 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 5aab4fc..b8bee1b 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -21,6 +21,7 @@ module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) w
import Data.Functor.Const
import Data.Functor.Identity
+import Data.Int (Int64)
import Data.Kind (Type)
import Array
@@ -447,6 +448,16 @@ envKnown :: SList STy env -> Dict (KnownEnv env)
envKnown SNil = Dict
envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
+cheapExpr :: Expr x env t -> Bool
+cheapExpr = \case
+ EVar{} -> True
+ ENil{} -> True
+ EConst{} -> True
+ EFst _ e -> cheapExpr e
+ ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
+ _ -> False
+
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)
@@ -516,6 +527,10 @@ eshapeEmpty (SS n) e =
(EConst ext STI64 0)))
(eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx))
+eshapeConst ShNil = ENil ext
+eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n))
+
-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t)
-- ezeroD2 t ezi = EZero ext (d2M t) ezi
@@ -527,6 +542,7 @@ eshapeEmpty (SS n) e =
eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r
eunPair (EPair _ e1 e2) k = k WId e1 e2
+eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e)
eunPair e k =
elet e $
k WSink
@@ -546,13 +562,24 @@ elet rhs body
| Dict <- styKnown (typeOf rhs)
= ELet ext rhs body
+-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost)
+use :: Ex env a -> Ex env b -> Ex env b
+use a b = elet a $ weakenExpr WSink b
+
emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b
emaybe e a b
| STMaybe t <- typeOf e
, Dict <- styKnown t
= EMaybe ext a b e
-elcase :: Ex env (TLEither a b) -> Ex env c -> (KnownTy a => Ex (a : env) c) -> (KnownTy b => Ex (b : env) c) -> Ex env c
+ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
+ecase e a b
+ | STEither t1 t2 <- typeOf e
+ , Dict <- styKnown t1
+ , Dict <- styKnown t2
+ = ECase ext e a b
+
+elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c
elcase e a b c
| STLEither t1 t2 <- typeOf e
, Dict <- styKnown t1
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index ca4d7ab..8cd0192 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
@@ -10,17 +11,31 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE PatternSynonyms #-}
module AST.Count where
-import Data.Functor.Const
+import Data.Functor.Product
+import Data.Some
+import Data.Type.Equality
import GHC.Generics (Generic, Generically(..))
+import Array
import AST
import AST.Env
import Data
+-- | The monoid operation combines assuming that /both/ branches are taken.
+class Monoid a => Occurrence a where
+ -- | One of the two branches is taken
+ (<||>) :: a -> a -> a
+ -- | This code is executed many times
+ scaleMany :: a -> a
+
+
data Count = Zero | One | Many
deriving (Show, Eq, Ord)
@@ -30,6 +45,10 @@ instance Semigroup Count where
_ <> _ = Many
instance Monoid Count where
mempty = Zero
+instance Occurrence Count where
+ (<||>) = max
+ scaleMany Zero = Zero
+ scaleMany _ = Many
data Occ = Occ { _occLexical :: Count
, _occRuntime :: Count }
@@ -40,120 +59,737 @@ instance Show Occ where
showsPrec d (Occ l r) = showParen (d > 10) $
showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
--- | One of the two branches is taken
-(<||>) :: Occ -> Occ -> Occ
-Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+instance Occurrence Occ where
+ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2)
+ scaleMany (Occ l c) = Occ l (scaleMany c)
--- | This code is executed many times
-scaleMany :: Occ -> Occ
-scaleMany (Occ l Zero) = Occ l Zero
-scaleMany (Occ l _) = Occ l Many
-occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx =
- getConst . occCountGeneral
- (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
- (\(Const o) -> Const o)
- (\(Const o1) (Const o2) -> Const (o1 <||> o2))
- (\(Const o) -> Const (scaleMany o))
+data Substruc t t' where
+ SsFull :: Substruc t t
+ SsNone :: Substruc t TNil
+ SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b')
+ SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b')
+ SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b')
+ SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a')
+ SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements
+ SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a')
+
+pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t'
+pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2)
+ where SsPair' = SsPair
+{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-}
+
+instance Semigroup (Some (Substruc t)) where
+ Some SsFull <> _ = Some SsFull
+ _ <> Some SsFull = Some SsFull
+ Some SsNone <> s = s
+ s <> Some SsNone = s
+ Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2)
+ Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2)
+ Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2)
+ Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2)
+ Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2)
+ Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2)
+instance Monoid (Some (Substruc t)) where
+ mempty = Some SsNone
+
+instance TestEquality (Substruc t) where
+ testEquality SsFull s = isFull s
+ testEquality s SsFull = sym <$> isFull s
+ testEquality SsNone SsNone = Just Refl
+ testEquality SsNone _ = Nothing
+ testEquality _ SsNone = Nothing
+ testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+
+isFull :: Substruc t t' -> Maybe (t :~: t')
+isFull SsFull = Just Refl
+isFull SsNone = Nothing -- TODO: nil?
+isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+
+applySubstruc :: Substruc t t' -> STy t -> STy t'
+applySubstruc SsFull t = t
+applySubstruc SsNone _ = STNil
+applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t)
+applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t)
+applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t)
+
+applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t'
+applySubstrucM SsFull t = t
+applySubstrucM SsNone _ = SMTNil
+applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t)
+applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t)
+applySubstrucM _ t = case t of {}
+
+data ExMap a b = ExMap (forall env. Ex env a -> Ex env b)
+ | a ~ b => ExMapId
+
+fromExMap :: ExMap a b -> Ex env a -> Ex env b
+fromExMap (ExMap f) = f
+fromExMap ExMapId = id
+
+simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t'
+simplifySubstruc STNil SsNone = SsFull
+
+simplifySubstruc _ SsFull = SsFull
+simplifySubstruc _ SsNone = SsNone
+simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s)
+simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s)
+simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s)
+
+-- simplifySubstruc' :: Substruc t t'
+-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r
+-- simplifySubstruc' SsFull k = k SsFull ExMapId
+-- simplifySubstruc' SsNone k = k SsNone ExMapId
+-- simplifySubstruc' (SsPair s1 s2) k =
+-- simplifySubstruc' s1 $ \s1' f1 ->
+-- simplifySubstruc' s2 $ \s2' f2 ->
+-- case (s1', s2') of
+-- (SsFull, SsFull) ->
+-- k SsFull (case (f1, f2) of
+-- (ExMapId, ExMapId) -> ExMapId
+-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 ->
+-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2)))
+-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext))))
+-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ)))))
+-- simplifySubstruc' _ _ = _
+
+-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b)
+-- ssUnpair SsFull = (SsFull, SsFull)
+-- ssUnpair SsNone = (SsNone, SsNone)
+-- ssUnpair (SsPair a b) = (a, b)
+
+-- ssUnleft :: Substruc (TEither a b) -> Substruc a
+-- ssUnleft SsFull = SsFull
+-- ssUnleft SsNone = SsNone
+-- ssUnleft (SsEither a _) = a
+
+-- ssUnright :: Substruc (TEither a b) -> Substruc b
+-- ssUnright SsFull = SsFull
+-- ssUnright SsNone = SsNone
+-- ssUnright (SsEither _ b) = b
+
+-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a
+-- ssUnlleft SsFull = SsFull
+-- ssUnlleft SsNone = SsNone
+-- ssUnlleft (SsLEither a _) = a
+
+-- ssUnlright :: Substruc (TLEither a b) -> Substruc b
+-- ssUnlright SsFull = SsFull
+-- ssUnlright SsNone = SsNone
+-- ssUnlright (SsLEither _ b) = b
+
+-- ssUnjust :: Substruc (TMaybe a) -> Substruc a
+-- ssUnjust SsFull = SsFull
+-- ssUnjust SsNone = SsNone
+-- ssUnjust (SsMaybe a) = a
+
+-- ssUnarr :: Substruc (TArr n a) -> Substruc a
+-- ssUnarr SsFull = SsFull
+-- ssUnarr SsNone = SsNone
+-- ssUnarr (SsArr a) = a
+
+-- ssUnaccum :: Substruc (TAccum a) -> Substruc a
+-- ssUnaccum SsFull = SsFull
+-- ssUnaccum SsNone = SsNone
+-- ssUnaccum (SsAccum a) = a
+
+
+type family MapEmpty env where
+ MapEmpty '[] = '[]
+ MapEmpty (t : env) = TNil : MapEmpty env
+
+data OccEnv a env env' where
+ OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top!
+ OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env')
+
+instance Semigroup a => Semigroup (Some (OccEnv a env)) where
+ Some OccEnd <> e = e
+ e <> Some OccEnd = e
+ Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2)
+
+instance Semigroup a => Monoid (Some (OccEnv a env)) where
+ mempty = Some OccEnd
+
+onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env)
+onehotOccEnv IZ v s = Some (OccPush OccEnd v s)
+onehotOccEnv (IS i) v s
+ | Some env' <- onehotOccEnv i v s
+ = Some (OccPush env' mempty SsNone)
+
+instance Occurrence a => Occurrence (Some (OccEnv a env)) where
+ Some OccEnd <||> e = e
+ e <||> Some OccEnd = e
+ Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2)
+
+ scaleMany (Some OccEnd) = Some OccEnd
+ scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s)
+
+occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t')
+occEnvPop (OccPush e _ s) = (e, s)
+occEnvPop OccEnd = (OccEnd, SsNone)
+
+occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r
+occEnvPop' (OccPush e _ s) k = k e s
+occEnvPop' OccEnd k = k OccEnd SsNone
+occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env)
+occEnvPopSome (Some (OccPush e _ _)) = Some e
+occEnvPopSome (Some OccEnd) = Some OccEnd
-data OccEnv env where
- OccEnd :: OccEnv env -- not necessarily top!
- OccPush :: OccEnv env -> Occ -> OccEnv (t : env)
+occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t))
+occEnvPrj OccEnd _ = mempty
+occEnvPrj (OccPush _ o s) IZ = (o, Some s)
+occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i
-instance Semigroup (OccEnv env) where
- OccEnd <> e = e
- e <> OccEnd = e
- OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')
+occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env'))
+occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ)
+occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i'))
+occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ)
+occEnvPrjS (OccPush e _ _) (IS i)
+ | Some (Pair s' i') <- occEnvPrjS e i
+ = Some (Pair s' (IS i'))
-instance Monoid (OccEnv env) where
- mempty = OccEnd
+projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small
+projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of
+ _ | Just Refl <- testEquality topsbig topssmall -> ex
-onehotOccEnv :: Idx env t -> Occ -> OccEnv env
-onehotOccEnv IZ v = OccPush OccEnd v
-onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty
+ (SsFull, SsFull) -> ex
+ (SsNone, SsNone) -> ex
+ (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller"
+ (_, SsNone) ->
+ case typeOf ex of
+ STNil -> ex
+ _ -> use ex $ ENil ext
-(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
-OccEnd <||>! e = e
-e <||>! OccEnd = e
-OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')
+ (SsPair s1 s2, SsPair s1' s2') ->
+ eunPair ex $ \_ e1 e2 ->
+ EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2)
+ (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex
+ (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex
-scaleManyOccEnv :: OccEnv env -> OccEnv env
-scaleManyOccEnv OccEnd = OccEnd
-scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)
+ (SsEither s1 s2, SsEither s1' s2')
+ | STEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in ecase ex
+ (EInl ext (typeOf e2) e1)
+ (EInr ext (typeOf e1) e2)
+ (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex
+ (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex
-occEnvPop :: OccEnv (t : env) -> OccEnv env
-occEnvPop (OccPush o _) = o
-occEnvPop OccEnd = OccEnd
+ (SsLEither s1 s2, SsLEither s1' s2')
+ | STLEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in elcase ex
+ (ELNil ext (typeOf e1) (typeOf e2))
+ (ELInl ext (typeOf e2) e1)
+ (ELInr ext (typeOf e1) e2)
+ (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex
+ (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex
-occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv
+ (SsMaybe s1, SsMaybe s1')
+ | STMaybe t1 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ in emaybe ex
+ (ENothing ext (typeOf e1))
+ (EJust ext e1)
+ (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex
+ (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex
-occCountGeneral :: forall r env t x.
- (forall env'. Monoid (r env'))
- => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot
- -> (forall env' a. r (a : env') -> r env') -- ^ unpush
- -> (forall env'. r env' -> r env' -> r env') -- ^ alternation
- -> (forall env'. r env' -> r env') -- ^ scale-many
- -> Expr x env t -> r env
-occCountGeneral onehot unpush alter many = go WId
+ (SsArr s1, SsArr s2)
+ | STArr n t <- typeOf ex ->
+ elet ex $
+ EBuild ext n (EShape ext (evar IZ)) $
+ projectSmallerSubstruc s1 s2
+ (EIdx ext (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex
+ (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex
+
+ (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum"
+ (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex
+ (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex
+
+
+-- | A boolean for each entry in the environment, with the ability to uniformly
+-- mask the top part above a certain index.
+data EnvMask env where
+ EMRest :: Bool -> EnvMask env
+ EMPush :: EnvMask env -> Bool -> EnvMask (t : env)
+
+envMaskPrj :: EnvMask env -> Idx env t -> Bool
+envMaskPrj (EMRest b) _ = b
+envMaskPrj (_ `EMPush` b) IZ = b
+envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i
+
+occCount :: Idx env a -> Expr x env t -> Occ
+occCount idx ex
+ | Some env <- occCountAll ex
+ = fst (occEnvPrj env idx)
+
+occCountAll :: Expr x env t -> Some (OccEnv Occ env)
+occCountAll ex = occCountX SsFull ex $ \env _ -> Some env
+
+pruneExpr :: SList f env -> Expr x env t -> Ex env t
+pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
where
- go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
- go w = \case
- EVar _ _ i -> onehot w i (Occ One One)
- ELet _ rhs body -> re rhs <> re1 body
- EPair _ a b -> re a <> re b
- EFst _ e -> re e
- ESnd _ e -> re e
- ENil _ -> mempty
- EInl _ _ e -> re e
- EInr _ _ e -> re e
- ECase _ e a b -> re e <> (re1 a `alter` re1 b)
- ENothing _ _ -> mempty
- EJust _ e -> re e
- EMaybe _ a b e -> re a <> re1 b <> re e
- ELNil _ _ _ -> mempty
- ELInl _ _ e -> re e
- ELInr _ _ e -> re e
- ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c)
- EConstArr{} -> mempty
- EBuild _ _ a b -> re a <> many (re1 b)
- EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
- ESum1Inner _ e -> re e
- EUnit _ e -> re e
- EReplicate1Inner _ a b -> re a <> re b
- EMaximum1Inner _ e -> re e
- EMinimum1Inner _ e -> re e
- EConst{} -> mempty
- EIdx0 _ e -> re e
- EIdx1 _ a b -> re a <> re b
- EIdx _ a b -> re a <> re b
- EShape _ e -> re e
- EOp _ _ e -> re e
- ECustom _ _ _ _ _ _ _ a b -> re a <> re b
- ERecompute _ e -> re e
- EWith _ _ a b -> re a <> re1 b
- EAccum _ _ _ a _ b e -> re a <> re b <> re e
- EZero _ _ e -> re e
- EDeepZero _ _ e -> re e
- EPlus _ _ a b -> re a <> re b
- EOneHot _ _ _ a b -> re a <> re b
- EError{} -> mempty
- where
- re :: Monoid (r env') => Expr x env' t'' -> r env'
- re = go w
+ fullOccEnv :: SList f env -> OccEnv () env env
+ fullOccEnv SNil = OccEnd
+ fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull
+
+-- * s: how much of the result is required
+-- * k: is passed the actual environment usage of this expression, including
+-- occurrence counts. The callback reconstructs a new expression in an
+-- updated "response" environment. The response must be at least as large as
+-- the computed usages.
+occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t
+ -> (forall env'. OccEnv Occ env env'
+ -- response OccEnv must be at least as large as the OccEnv returned above
+ -> (forall env''. OccEnv () env env'' -> Ex env'' t')
+ -> r)
+ -> r
+occCountX initialS topexpr k = case topexpr of
+ EVar _ t i ->
+ withSome (onehotOccEnv i (Occ One One) s) $ \env ->
+ k env $ \env' ->
+ withSome (occEnvPrjS env' i) $ \(Pair s' i') ->
+ projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i')
+ ELet _ rhs body ->
+ occCountX s body $ \envB mkbody ->
+ occEnvPop' envB $ \envB' s1 ->
+ occCountX s1 rhs $ \envR mkrhs ->
+ withSome (Some envB' <> Some envR) $ \env ->
+ k env $ \env' ->
+ ELet ext (mkrhs env') (mkbody (OccPush env' () s1))
+ EPair _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsPair' s1 s2 ->
+ occCountX s1 a $ \env1 mka ->
+ occCountX s2 b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPair ext (mka env') (mkb env')
+ EFst _ e ->
+ occCountX (SsPair s SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EFst ext (mke env')
+ ESnd _ e ->
+ occCountX (SsPair SsNone s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ ESnd ext (mke env')
+ ENil _ ->
+ case s of
+ SsFull -> k OccEnd (\_ -> ENil ext)
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ EInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ EInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ ECase _ e a b ->
+ occCountX s a $ \env1' mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env1' $ \env1 s1 ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> Some env1) $ \env01 ->
+ withSome (Some env01 <> Some env2) $ \env ->
+ k env $ \env' ->
+ ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2))
+ ENothing _ t ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t))
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EJust _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsMaybe s' ->
+ occCountX s' e $ \env1 mke ->
+ k env1 $ \env' ->
+ EJust ext (mke env')
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EMaybe _ a b e ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsMaybe s2) e $ \env0 mke ->
+ withSome (Some env0 <> Some env1) $ \env01 ->
+ withSome (Some env01 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env')
+ ELNil _ t1 t2 ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2))
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELCase _ e a b c ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occCountX s c $ \env3' mkc ->
+ occEnvPop' env2' $ \env2 s1 ->
+ occEnvPop' env3' $ \env3 s2 ->
+ occCountX (SsLEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> Some env1) $ \env01 ->
+ withSome (Some env01 <> Some env2) $ \env012 ->
+ withSome (Some env012 <> Some env3) $ \env ->
+ k env $ \env' ->
+ ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2))
+
+ EConstArr _ n t x ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
+ SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
+ EBuild _ n a b ->
+ case s of
+ SsNone ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsNone b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s2 ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EBuild ext n (mka env') $
+ use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
+ ENil ext) $
+ ENil ext
+ SsArr s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX s' b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s2 ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EBuild ext n (mka env') $
+ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
+ EFold1Inner _ commut a b c ->
+ occCountX SsFull a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s2 ->
+ occEnvPop' env1' $ \env1 s1 ->
+ let s0 = case s of
+ SsNone -> Some SsNone
+ SsArr s' -> Some s'
+ SsFull -> Some SsFull in
+ withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
+ occCountX sElt b $ \env2 mkb ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ let expr = EFold1Inner ext commut
+ (projectSmallerSubstruc SsFull sElt $
+ mka (OccPush (OccPush env' () sElt) () sElt))
+ (mkb env') (mkc env') in
+ case s of
+ SsNone -> use expr $ ENil ext
+ SsArr s' -> projectSmallerSubstruc (SsArr sElt) (SsArr s') expr
+ SsFull -> case testEquality sElt SsFull of
+ Just Refl -> expr
+ Nothing -> error "unreachable"
+
+ ESum1Inner _ e
+ | STArr (SS n) _ <- typeOf e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr SsNone ->
+ occCountX (SsArr SsNone) e $ \env mke ->
+ k env $ \env' ->
+ elet (mke env') $
+ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
+ SsArr SsFull ->
+ occCountX (SsArr SsFull) e $ \env mke ->
+ k env $ \env' ->
+ ESum1Inner ext (mke env')
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
+ EUnit _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr s' ->
+ occCountX s' e $ \env mke ->
+ k env $ \env' ->
+ EUnit ext (mke env')
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
+ EReplicate1Inner _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX (SsArr s') b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReplicate1Inner ext (mka env') (mkb env')
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
+ {-
+ EMaximum1Inner _ e ->
+ let (e', env) = re (SsArr (ssUnarr s)) e
+ in (EMaximum1Inner s e', env)
+ EMinimum1Inner _ e ->
+ let (e', env) = re (SsArr (ssUnarr s)) e
+ in (EMinimum1Inner s e', env)
+ -}
- re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
- re1 = unpush . go (WSink .> w)
+ EConst _ t x ->
+ k OccEnd $ \_ ->
+ case s of
+ SsNone -> ENil ext
+ SsFull -> EConst ext t x
+
+ EIdx0 _ e ->
+ occCountX (SsArr s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EIdx0 ext (mke env')
+
+ EIdx1 _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr s' ->
+ occCountX (SsArr s') a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx1 ext (mka env') (mkb env')
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
+ EIdx _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ _ ->
+ occCountX (SsArr s) a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx ext (mka env') (mkb env')
+
+ EShape _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX (SsArr SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EShape ext (mke env')
+
+ EOp _ op e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
+
+ {-
+ ECustom _ t1 t2 t3 e1 e2 e3 a b ->
+ let (e1', _) = occCountX SsFull e1
+ (e2', _) = occCountX SsFull e2
+ (e3', _) = occCountX SsFull e3
+ (a', env1) = re SsFull a -- let's be pessimistic here for safety
+ (b', env2) = re SsFull b
+ in (ECustom SsFull t1 t2 t3 e1' e2' e3' a' b', env1 <> env2)
+ -}
+
+ ERecompute _ e ->
+ occCountX s e $ \env1 mke ->
+ k env1 $ \env' ->
+ ERecompute ext (mke env')
+
+ EWith _ t a b ->
+ case s of
+ SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with
+ occCountX SsNone b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ withSome (case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone) $ \s1' ->
+ occCountX s1' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $
+ ENil ext
+ SsPair sB sA ->
+ occCountX sB b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ let s1' = case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone in
+ withSome (Some sA <> s1') $ \sA' ->
+ occCountX sA' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $
+ EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA')))
+ SsFull -> occCountX (SsPair SsFull SsFull) topexpr k
+
+ EAccum _ t p a sp b e ->
+ -- TODO: do better!
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ occCountX SsFull e $ \env3 mke ->
+ withSome (Some env1 <> Some env2) $ \env12 ->
+ withSome (Some env12 <> Some env3) $ \env ->
+ k env $ \env' ->
+ case s of {SsFull -> id; SsNone -> id} $
+ EAccum ext t p (mka env') sp (mkb env') (mke env')
+
+ EZero _ t e ->
+ occCountX (subZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EZero ext (applySubstrucM s t) (mke env')
+ where
+ subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2)
+ subZeroInfo SsFull = SsFull
+ subZeroInfo SsNone = SsNone
+ subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2)
+ subZeroInfo SsEither{} = error "Either is not a monoid"
+ subZeroInfo SsLEither{} = SsNone
+ subZeroInfo SsMaybe{} = SsNone
+ subZeroInfo (SsArr s') = SsArr (subZeroInfo s')
+ subZeroInfo SsAccum{} = error "Accum is not a monoid"
+
+ EDeepZero _ t e ->
+ occCountX (subDeepZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EDeepZero ext (applySubstrucM s t) (mke env')
+ where
+ subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2)
+ subDeepZeroInfo SsFull = SsFull
+ subDeepZeroInfo SsNone = SsNone
+ subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo SsEither{} = error "Either is not a monoid"
+ subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s')
+ subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s')
+ subDeepZeroInfo SsAccum{} = error "Accum is not a monoid"
+
+ EPlus _ t a b ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPlus ext (applySubstrucM s t) (mka env') (mkb env')
+
+ EOneHot _ t p a b ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb -> -- TODO: do better
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env')
+
+ EError _ t msg ->
+ k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
+
+ _ -> error "occCountX: TODO unimplemented"
+ where
+ s = simplifySubstruc (typeOf topexpr) initialS
-deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
-deleteUnused SNil OccEnd k = k SETop
-deleteUnused (_ `SCons` env) OccEnd k =
- deleteUnused env OccEnd $ \sub -> k (SENo sub)
-deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
- deleteUnused env occenv $ \sub ->
+deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
+deleteUnused SNil (Some OccEnd) k = k SETop
+deleteUnused (_ `SCons` env) (Some OccEnd) k =
+ deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub)
+deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k =
+ deleteUnused env (Some occenv) $ \sub ->
case count of Zero -> k (SENo sub)
_ -> k (SEYesR sub)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index fef9686..9018602 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -6,6 +6,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
index 93258b7..2a29799 100644
--- a/src/AST/Sparse.hs
+++ b/src/AST/Sparse.hs
@@ -85,9 +85,6 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g)
withInj2 Noinj _ _ = Noinj
withInj2 _ Noinj _ = Noinj
-use :: Ex env a -> Ex env b -> Ex env b
-use a b = elet a $ weakenExpr WSink b
-
-- | This function produces quadratically-sized code in the presence of nested
-- dynamic sparsity. TODO can this be improved?
sparsePlusS
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index d882e28..3a97fd1 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -19,6 +19,7 @@ module AST.Weaken (module AST.Weaken, Append) where
import Data.Bifunctor (first)
import Data.Functor.Const
+import Data.GADT.Compare
import Data.Kind (Type)
import Data
@@ -31,6 +32,11 @@ data Idx env t where
IS :: Idx env t -> Idx (a : env) t
deriving instance Show (Idx env t)
+instance GEq (Idx env) where
+ geq IZ IZ = Just Refl
+ geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl
+ geq _ _ = Nothing
+
splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
splitIdx SNil i = Right i
splitIdx (SCons _ _) IZ = Left IZ
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 143376a..dcb10aa 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1055,7 +1055,7 @@ drev des accumMap sd = \case
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
- deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
+ deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 74b6601..28b6d60 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -200,6 +200,7 @@ simplify'Rec = \case
EMaybe ext (ESnd ext e1) (ESnd ext e2) e3
-- TODO: more array indexing
+ EIdx _ (EBuild _ _ _ e1) e2 -> acted $ simplify' $ elet e2 e1
EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
@@ -336,16 +337,6 @@ simplify'Rec = \case
EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
EError _ t s -> pure $ EError ext t s
-cheapExpr :: Expr x env t -> Bool
-cheapExpr = \case
- EVar{} -> True
- ENil{} -> True
- EConst{} -> True
- EFst _ e -> cheapExpr e
- ESnd _ e -> cheapExpr e
- EUnit _ e -> cheapExpr e
- _ -> False
-
-- | This can be made more precise by tracking (and not counting) adds on
-- locally eliminated accumulators.
hasAdds :: Expr x env t -> Bool