diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Count.hs | 834 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 1 | ||||
| -rw-r--r-- | src/AST/Sparse.hs | 3 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 6 | 
4 files changed, 742 insertions, 102 deletions
| diff --git a/src/AST/Count.hs b/src/AST/Count.hs index ca4d7ab..8cd0192 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -3,6 +3,7 @@  {-# LANGUAGE DerivingStrategies #-}  {-# LANGUAGE DerivingVia #-}  {-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-} @@ -10,17 +11,31 @@  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-}  module AST.Count where -import Data.Functor.Const +import Data.Functor.Product +import Data.Some +import Data.Type.Equality  import GHC.Generics (Generic, Generically(..)) +import Array  import AST  import AST.Env  import Data +-- | The monoid operation combines assuming that /both/ branches are taken. +class Monoid a => Occurrence a where +  -- | One of the two branches is taken +  (<||>) :: a -> a -> a +  -- | This code is executed many times +  scaleMany :: a -> a + +  data Count = Zero | One | Many    deriving (Show, Eq, Ord) @@ -30,6 +45,10 @@ instance Semigroup Count where    _ <> _ = Many  instance Monoid Count where    mempty = Zero +instance Occurrence Count where +  (<||>) = max +  scaleMany Zero = Zero +  scaleMany _ = Many  data Occ = Occ { _occLexical :: Count                 , _occRuntime :: Count } @@ -40,120 +59,737 @@ instance Show Occ where    showsPrec d (Occ l r) = showParen (d > 10) $      showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) +instance Occurrence Occ where +  Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) +  scaleMany (Occ l c) = Occ l (scaleMany c) --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l Zero) = Occ l Zero -scaleMany (Occ l _) = Occ l Many -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = -  getConst . occCountGeneral -    (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) -    (\(Const o) -> Const o) -    (\(Const o1) (Const o2) -> Const (o1 <||> o2)) -    (\(Const o) -> Const (scaleMany o)) +data Substruc t t' where +  SsFull :: Substruc t t +  SsNone :: Substruc t TNil +  SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') +  SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') +  SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') +  SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') +  SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a')  -- ^ union of usages of all array elements +  SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') + +pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' +pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) +  where SsPair' = SsPair +{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} + +instance Semigroup (Some (Substruc t)) where +  Some SsFull <> _ = Some SsFull +  _ <> Some SsFull = Some SsFull +  Some SsNone <> s = s +  s <> Some SsNone = s +  Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) +  Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) +  Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) +  Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) +  Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) +  Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) +instance Monoid (Some (Substruc t)) where +  mempty = Some SsNone + +instance TestEquality (Substruc t) where +  testEquality SsFull s = isFull s +  testEquality s SsFull = sym <$> isFull s +  testEquality SsNone SsNone = Just Refl +  testEquality SsNone _ = Nothing +  testEquality _ SsNone = Nothing +  testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing +  testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing +  testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing +  testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing +  testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing +  testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + +isFull :: Substruc t t' -> Maybe (t :~: t') +isFull SsFull = Just Refl +isFull SsNone = Nothing  -- TODO: nil? +isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing + +applySubstruc :: Substruc t t' -> STy t -> STy t' +applySubstruc SsFull t = t +applySubstruc SsNone _ = STNil +applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) +applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) +applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) + +applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' +applySubstrucM SsFull t = t +applySubstrucM SsNone _ = SMTNil +applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) +applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) +applySubstrucM _ t = case t of {} + +data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) +               | a ~ b => ExMapId + +fromExMap :: ExMap a b -> Ex env a -> Ex env b +fromExMap (ExMap f) = f +fromExMap ExMapId = id + +simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' +simplifySubstruc STNil SsNone = SsFull + +simplifySubstruc _ SsFull = SsFull +simplifySubstruc _ SsNone = SsNone +simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) +simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) +simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) + +-- simplifySubstruc' :: Substruc t t' +--                  -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r +-- simplifySubstruc' SsFull k = k SsFull ExMapId +-- simplifySubstruc' SsNone k = k SsNone ExMapId +-- simplifySubstruc' (SsPair s1 s2) k = +--   simplifySubstruc' s1 $ \s1' f1 -> +--   simplifySubstruc' s2 $ \s2' f2 -> +--   case (s1', s2') of +--     (SsFull, SsFull) -> +--       k SsFull (case (f1, f2) of +--                   (ExMapId, ExMapId) -> ExMapId +--                   _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> +--                                       EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) +--     (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) +--     _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) +-- simplifySubstruc' _ _ = _ + +-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) +-- ssUnpair SsFull = (SsFull, SsFull) +-- ssUnpair SsNone = (SsNone, SsNone) +-- ssUnpair (SsPair a b) = (a, b) + +-- ssUnleft :: Substruc (TEither a b) -> Substruc a +-- ssUnleft SsFull = SsFull +-- ssUnleft SsNone = SsNone +-- ssUnleft (SsEither a _) = a + +-- ssUnright :: Substruc (TEither a b) -> Substruc b +-- ssUnright SsFull = SsFull +-- ssUnright SsNone = SsNone +-- ssUnright (SsEither _ b) = b + +-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a +-- ssUnlleft SsFull = SsFull +-- ssUnlleft SsNone = SsNone +-- ssUnlleft (SsLEither a _) = a + +-- ssUnlright :: Substruc (TLEither a b) -> Substruc b +-- ssUnlright SsFull = SsFull +-- ssUnlright SsNone = SsNone +-- ssUnlright (SsLEither _ b) = b + +-- ssUnjust :: Substruc (TMaybe a) -> Substruc a +-- ssUnjust SsFull = SsFull +-- ssUnjust SsNone = SsNone +-- ssUnjust (SsMaybe a) = a + +-- ssUnarr :: Substruc (TArr n a) -> Substruc a +-- ssUnarr SsFull = SsFull +-- ssUnarr SsNone = SsNone +-- ssUnarr (SsArr a) = a + +-- ssUnaccum :: Substruc (TAccum a) -> Substruc a +-- ssUnaccum SsFull = SsFull +-- ssUnaccum SsNone = SsNone +-- ssUnaccum (SsAccum a) = a + + +type family MapEmpty env where +  MapEmpty '[] = '[] +  MapEmpty (t : env) = TNil : MapEmpty env + +data OccEnv a env env' where +  OccEnd :: OccEnv a env (MapEmpty env)  -- not necessarily top! +  OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') + +instance Semigroup a => Semigroup (Some (OccEnv a env)) where +  Some OccEnd <> e = e +  e <> Some OccEnd = e +  Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) + +instance Semigroup a => Monoid (Some (OccEnv a env)) where +  mempty = Some OccEnd + +onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) +onehotOccEnv IZ v s = Some (OccPush OccEnd v s) +onehotOccEnv (IS i) v s +  | Some env' <- onehotOccEnv i v s +  = Some (OccPush env' mempty SsNone) + +instance Occurrence a => Occurrence (Some (OccEnv a env)) where +  Some OccEnd <||> e = e +  e <||> Some OccEnd = e +  Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) + +  scaleMany (Some OccEnd) = Some OccEnd +  scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) + +occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') +occEnvPop (OccPush e _ s) = (e, s) +occEnvPop OccEnd = (OccEnd, SsNone) + +occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r +occEnvPop' (OccPush e _ s) k = k e s +occEnvPop' OccEnd k = k OccEnd SsNone +occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) +occEnvPopSome (Some (OccPush e _ _)) = Some e +occEnvPopSome (Some OccEnd) = Some OccEnd -data OccEnv env where -  OccEnd :: OccEnv env  -- not necessarily top! -  OccPush :: OccEnv env -> Occ -> OccEnv (t : env) +occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) +occEnvPrj OccEnd _ = mempty +occEnvPrj (OccPush _ o s) IZ = (o, Some s) +occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i -instance Semigroup (OccEnv env) where -  OccEnd <> e = e -  e <> OccEnd = e -  OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') +occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) +occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) +occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) +occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) +occEnvPrjS (OccPush e _ _) (IS i) +  | Some (Pair s' i') <- occEnvPrjS e i +  = Some (Pair s' (IS i')) -instance Monoid (OccEnv env) where -  mempty = OccEnd +projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small +projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of +  _ | Just Refl <- testEquality topsbig topssmall -> ex -onehotOccEnv :: Idx env t -> Occ -> OccEnv env -onehotOccEnv IZ v = OccPush OccEnd v -onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty +  (SsFull, SsFull) -> ex +  (SsNone, SsNone) -> ex +  (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" +  (_, SsNone) -> +    case typeOf ex of +      STNil -> ex +      _ -> use ex $ ENil ext -(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env -OccEnd <||>! e = e -e <||>! OccEnd = e -OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') +  (SsPair s1 s2, SsPair s1' s2') -> +    eunPair ex $ \_ e1 e2 -> +      EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) +  (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex +  (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex -scaleManyOccEnv :: OccEnv env -> OccEnv env -scaleManyOccEnv OccEnd = OccEnd -scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) +  (SsEither s1 s2, SsEither s1' s2') +    | STEither t1 t2 <- typeOf ex -> +    let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) +        e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) +    in ecase ex +         (EInl ext (typeOf e2) e1) +         (EInr ext (typeOf e1) e2) +  (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex +  (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex -occEnvPop :: OccEnv (t : env) -> OccEnv env -occEnvPop (OccPush o _) = o -occEnvPop OccEnd = OccEnd +  (SsLEither s1 s2, SsLEither s1' s2') +    | STLEither t1 t2 <- typeOf ex -> +    let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) +        e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) +    in elcase ex +         (ELNil ext (typeOf e1) (typeOf e2)) +         (ELInl ext (typeOf e2) e1) +         (ELInr ext (typeOf e1) e2) +  (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex +  (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex -occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv +  (SsMaybe s1, SsMaybe s1') +    | STMaybe t1 <- typeOf ex -> +    let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) +    in emaybe ex +         (ENothing ext (typeOf e1)) +         (EJust ext e1) +  (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex +  (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex -occCountGeneral :: forall r env t x. -                   (forall env'. Monoid (r env')) -                => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env')  -- ^ one-hot -                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush -                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation -                -> (forall env'. r env' -> r env')  -- ^ scale-many -                -> Expr x env t -> r env -occCountGeneral onehot unpush alter many = go WId +  (SsArr s1, SsArr s2) +    | STArr n t <- typeOf ex -> +    elet ex $ +      EBuild ext n (EShape ext (evar IZ)) $ +        projectSmallerSubstruc s1 s2 +            (EIdx ext (EVar ext (STArr n t) (IS IZ)) +                      (EVar ext (tTup (sreplicate n tIx)) IZ)) +  (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex +  (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex + +  (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" +  (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex +  (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex + + +-- | A boolean for each entry in the environment, with the ability to uniformly +-- mask the top part above a certain index. +data EnvMask env where +  EMRest :: Bool -> EnvMask env +  EMPush :: EnvMask env -> Bool -> EnvMask (t : env) + +envMaskPrj :: EnvMask env -> Idx env t -> Bool +envMaskPrj (EMRest b) _ = b +envMaskPrj (_ `EMPush` b) IZ = b +envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx ex +  | Some env <- occCountAll ex +  = fst (occEnvPrj env idx) + +occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll ex = occCountX SsFull ex $ \env _ -> Some env + +pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)    where -    go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' -    go w = \case -      EVar _ _ i -> onehot w i (Occ One One) -      ELet _ rhs body -> re rhs <> re1 body -      EPair _ a b -> re a <> re b -      EFst _ e -> re e -      ESnd _ e -> re e -      ENil _ -> mempty -      EInl _ _ e -> re e -      EInr _ _ e -> re e -      ECase _ e a b -> re e <> (re1 a `alter` re1 b) -      ENothing _ _ -> mempty -      EJust _ e -> re e -      EMaybe _ a b e -> re a <> re1 b <> re e -      ELNil _ _ _ -> mempty -      ELInl _ _ e -> re e -      ELInr _ _ e -> re e -      ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c) -      EConstArr{} -> mempty -      EBuild _ _ a b -> re a <> many (re1 b) -      EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c -      ESum1Inner _ e -> re e -      EUnit _ e -> re e -      EReplicate1Inner _ a b -> re a <> re b -      EMaximum1Inner _ e -> re e -      EMinimum1Inner _ e -> re e -      EConst{} -> mempty -      EIdx0 _ e -> re e -      EIdx1 _ a b -> re a <> re b -      EIdx _ a b -> re a <> re b -      EShape _ e -> re e -      EOp _ _ e -> re e -      ECustom _ _ _ _ _ _ _ a b -> re a <> re b -      ERecompute _ e -> re e -      EWith _ _ a b -> re a <> re1 b -      EAccum _ _ _ a _ b e -> re a <> re b <> re e -      EZero _ _ e -> re e -      EDeepZero _ _ e -> re e -      EPlus _ _ a b -> re a <> re b -      EOneHot _ _ _ a b -> re a <> re b -      EError{} -> mempty -      where -        re :: Monoid (r env') => Expr x env' t'' -> r env' -        re = go w +    fullOccEnv :: SList f env -> OccEnv () env env +    fullOccEnv SNil = OccEnd +    fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull + +-- * s: how much of the result is required +-- * k: is passed the actual environment usage of this expression, including +--   occurrence counts. The callback reconstructs a new expression in an +--   updated "response" environment. The response must be at least as large as +--   the computed usages. +occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t +          -> (forall env'. OccEnv Occ env env' +                           -- response OccEnv must be at least as large as the OccEnv returned above +                        -> (forall env''. OccEnv () env env'' -> Ex env'' t') +                        -> r) +          -> r +occCountX initialS topexpr k = case topexpr of +  EVar _ t i -> +    withSome (onehotOccEnv i (Occ One One) s) $ \env -> +    k env $ \env' -> +    withSome (occEnvPrjS env' i) $ \(Pair s' i') -> +      projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') +  ELet _ rhs body -> +    occCountX s body $ \envB mkbody -> +    occEnvPop' envB $ \envB' s1 -> +    occCountX s1 rhs $ \envR mkrhs -> +    withSome (Some envB' <> Some envR) $ \env -> +    k env $ \env' -> +      ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) +  EPair _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mka env') $ use (mkb env') $ ENil ext +      SsPair' s1 s2 -> +        occCountX s1 a $ \env1 mka -> +        occCountX s2 b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          EPair ext (mka env') (mkb env') +  EFst _ e -> +    occCountX (SsPair s SsNone) e $ \env1 mke -> +    k env1 $ \env' -> +      EFst ext (mke env') +  ESnd _ e -> +    occCountX (SsPair SsNone s) e $ \env1 mke -> +    k env1 $ \env' -> +      ESnd ext (mke env') +  ENil _ -> +    case s of +      SsFull -> k OccEnd (\_ -> ENil ext) +      SsNone -> k OccEnd (\_ -> ENil ext) +  EInl _ t e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env1 mke -> +        k env1 $ \env' -> +          use (mke env') $ ENil ext +      SsEither s1 s2 -> +        occCountX s1 e $ \env1 mke -> +        k env1 $ \env' -> +          EInl ext (applySubstruc s2 t) (mke env') +      SsFull -> occCountX (SsEither SsFull SsFull) topexpr k +  EInr _ t e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env1 mke -> +        k env1 $ \env' -> +          use (mke env') $ ENil ext +      SsEither s1 s2 -> +        occCountX s2 e $ \env1 mke -> +        k env1 $ \env' -> +          EInr ext (applySubstruc s1 t) (mke env') +      SsFull -> occCountX (SsEither SsFull SsFull) topexpr k +  ECase _ e a b -> +    occCountX s a $ \env1' mka -> +    occCountX s b $ \env2' mkb -> +    occEnvPop' env1' $ \env1 s1 -> +    occEnvPop' env2' $ \env2 s2 -> +    occCountX (SsEither s1 s2) e $ \env0 mke -> +    withSome (Some env0 <> Some env1) $ \env01 -> +    withSome (Some env01 <> Some env2) $ \env -> +    k env $ \env' -> +      ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) +  ENothing _ t -> +    case s of +      SsNone -> k OccEnd (\_ -> ENil ext) +      SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) +      SsFull -> occCountX (SsMaybe SsFull) topexpr k +  EJust _ e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env1 mke -> +        k env1 $ \env' -> +          use (mke env') $ ENil ext +      SsMaybe s' -> +        occCountX s' e $ \env1 mke -> +        k env1 $ \env' -> +          EJust ext (mke env') +      SsFull -> occCountX (SsMaybe SsFull) topexpr k +  EMaybe _ a b e -> +    occCountX s a $ \env1 mka -> +    occCountX s b $ \env2' mkb -> +    occEnvPop' env2' $ \env2 s2 -> +    occCountX (SsMaybe s2) e $ \env0 mke -> +    withSome (Some env0 <> Some env1) $ \env01 -> +    withSome (Some env01 <> Some env2) $ \env -> +    k env $ \env' -> +      EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') +  ELNil _ t1 t2 -> +    case s of +      SsNone -> k OccEnd (\_ -> ENil ext) +      SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) +      SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k +  ELInl _ t e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env1 mke -> +        k env1 $ \env' -> +          use (mke env') $ ENil ext +      SsLEither s1 s2 -> +        occCountX s1 e $ \env1 mke -> +        k env1 $ \env' -> +          ELInl ext (applySubstruc s2 t) (mke env') +      SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k +  ELInr _ t e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env1 mke -> +        k env1 $ \env' -> +          use (mke env') $ ENil ext +      SsLEither s1 s2 -> +        occCountX s2 e $ \env1 mke -> +        k env1 $ \env' -> +          ELInr ext (applySubstruc s1 t) (mke env') +      SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k +  ELCase _ e a b c -> +    occCountX s a $ \env1 mka -> +    occCountX s b $ \env2' mkb -> +    occCountX s c $ \env3' mkc -> +    occEnvPop' env2' $ \env2 s1 -> +    occEnvPop' env3' $ \env3 s2 -> +    occCountX (SsLEither s1 s2) e $ \env0 mke -> +    withSome (Some env0 <> Some env1) $ \env01 -> +    withSome (Some env01 <> Some env2) $ \env012 -> +    withSome (Some env012 <> Some env3) $ \env -> +    k env $ \env' -> +      ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) + +  EConstArr _ n t x -> +    case s of +      SsNone -> k OccEnd (\_ -> ENil ext) +      SsArr SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) +      SsArr SsFull -> k OccEnd (\_ -> EConstArr ext n t x) +      SsFull -> occCountX (SsArr SsFull) topexpr k + +  EBuild _ n a b -> +    case s of +      SsNone -> +        occCountX SsFull a $ \env1 mka -> +        occCountX SsNone b $ \env2' mkb -> +        occEnvPop' env2' $ \env2 s2 -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (EBuild ext n (mka env') $ +                 use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +                        weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ +                   ENil ext) $ +            ENil ext +      SsArr s' -> +        occCountX SsFull a $ \env1 mka -> +        occCountX s' b $ \env2' mkb -> +        occEnvPop' env2' $ \env2 s2 -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          EBuild ext n (mka env') $ +            elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +              weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) +      SsFull -> occCountX (SsArr SsFull) topexpr k + +  EFold1Inner _ commut a b c -> +    occCountX SsFull a $ \env1'' mka -> +    occEnvPop' env1'' $ \env1' s2 -> +    occEnvPop' env1' $ \env1 s1 -> +    let s0 = case s of +               SsNone -> Some SsNone +               SsArr s' -> Some s' +               SsFull -> Some SsFull in +    withSome (Some s1 <> Some s2 <> s0) $ \sElt -> +    occCountX sElt b $ \env2 mkb -> +    occCountX (SsArr sElt) c  $ \env3 mkc -> +    withSome (Some env1 <> Some env2 <> Some env3) $ \env -> +    k env $ \env' -> +      let expr = EFold1Inner ext commut +                    (projectSmallerSubstruc SsFull sElt $ +                      mka (OccPush (OccPush env' () sElt) () sElt)) +                    (mkb env') (mkc env') in +      case s of +        SsNone -> use expr $ ENil ext +        SsArr s' -> projectSmallerSubstruc (SsArr sElt) (SsArr s') expr +        SsFull -> case testEquality sElt SsFull of +                    Just Refl -> expr +                    Nothing -> error "unreachable" + +  ESum1Inner _ e +    | STArr (SS n) _ <- typeOf e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env mke -> +        k env $ \env' -> +          use (mke env') $ ENil ext +      SsArr SsNone -> +        occCountX (SsArr SsNone) e $ \env mke -> +        k env $ \env' -> +          elet (mke env') $ +            EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) +      SsArr SsFull -> +        occCountX (SsArr SsFull) e $ \env mke -> +        k env $ \env' -> +          ESum1Inner ext (mke env') +      SsFull -> occCountX (SsArr SsFull) topexpr k + +  EUnit _ e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env mke -> +        k env $ \env' -> +          use (mke env') $ ENil ext +      SsArr s' -> +        occCountX s' e $ \env mke -> +        k env $ \env' -> +          EUnit ext (mke env') +      SsFull -> occCountX (SsArr SsFull) topexpr k + +  EReplicate1Inner _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mka env') $ use (mkb env') $ ENil ext +      SsArr s' -> +        occCountX SsFull a $ \env1 mka -> +        occCountX (SsArr s') b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          EReplicate1Inner ext (mka env') (mkb env') +      SsFull -> occCountX (SsArr SsFull) topexpr k + +  {- +  EMaximum1Inner _ e -> +    let (e', env) = re (SsArr (ssUnarr s)) e +    in (EMaximum1Inner s e', env) +  EMinimum1Inner _ e -> +    let (e', env) = re (SsArr (ssUnarr s)) e +    in (EMinimum1Inner s e', env) +  -} -        re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' -        re1 = unpush . go (WSink .> w) +  EConst _ t x -> +    k OccEnd $ \_ -> +      case s of +        SsNone -> ENil ext +        SsFull -> EConst ext t x + +  EIdx0 _ e -> +    occCountX (SsArr s) e $ \env1 mke -> +    k env1 $ \env' -> +      EIdx0 ext (mke env') + +  EIdx1 _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mka env') $ use (mkb env') $ ENil ext +      SsArr s' -> +        occCountX (SsArr s') a $ \env1 mka -> +        occCountX SsFull b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          EIdx1 ext (mka env') (mkb env') +      SsFull -> occCountX (SsArr SsFull) topexpr k + +  EIdx _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mka env') $ use (mkb env') $ ENil ext +      _ -> +        occCountX (SsArr s) a $ \env1 mka -> +        occCountX SsFull b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          EIdx ext (mka env') (mkb env') + +  EShape _ e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env1 mke -> +        k env1 $ \env' -> +          use (mke env') $ ENil ext +      _ -> +        occCountX (SsArr SsNone) e $ \env1 mke -> +        k env1 $ \env' -> +          projectSmallerSubstruc SsFull s $ EShape ext (mke env') + +  EOp _ op e -> +    case s of +      SsNone -> +        occCountX SsNone e $ \env1 mke -> +        k env1 $ \env' -> +          use (mke env') $ ENil ext +      _ -> +        occCountX SsFull e $ \env1 mke -> +        k env1 $ \env' -> +          projectSmallerSubstruc SsFull s $ EOp ext op (mke env') + +  {- +  ECustom _ t1 t2 t3 e1 e2 e3 a b -> +    let (e1', _) = occCountX SsFull e1 +        (e2', _) = occCountX SsFull e2 +        (e3', _) = occCountX SsFull e3 +        (a', env1) = re SsFull a  -- let's be pessimistic here for safety +        (b', env2) = re SsFull b +    in (ECustom SsFull t1 t2 t3 e1' e2' e3' a' b', env1 <> env2) +  -} + +  ERecompute _ e -> +    occCountX s e $ \env1 mke -> +    k env1 $ \env' -> +      ERecompute ext (mke env') + +  EWith _ t a b -> +    case s of +      SsNone ->  -- TODO: simplifier should remove accumulations to an unused with, and then remove the with +        occCountX SsNone b $ \env2' mkb -> +        occEnvPop' env2' $ \env2 s1 -> +        withSome (case s1 of +                    SsFull -> Some SsFull +                    SsAccum s' -> Some s' +                    SsNone -> Some SsNone) $ \s1' -> +        occCountX s1' a $ \env1 mka -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ +            ENil ext +      SsPair sB sA -> +        occCountX sB b $ \env2' mkb -> +        occEnvPop' env2' $ \env2 s1 -> +        let s1' = case s1 of +                    SsFull -> Some SsFull +                    SsAccum s' -> Some s' +                    SsNone -> Some SsNone in +        withSome (Some sA <> s1') $ \sA' -> +        occCountX sA' a $ \env1 mka -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ +            EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) +      SsFull -> occCountX (SsPair SsFull SsFull) topexpr k + +  EAccum _ t p a sp b e -> +    -- TODO: do better! +    occCountX SsFull a $ \env1 mka -> +    occCountX SsFull b $ \env2 mkb -> +    occCountX SsFull e $ \env3 mke -> +    withSome (Some env1 <> Some env2) $ \env12 -> +    withSome (Some env12 <> Some env3) $ \env -> +    k env $ \env' -> +      case s of {SsFull -> id; SsNone -> id} $ +        EAccum ext t p (mka env') sp (mkb env') (mke env') + +  EZero _ t e -> +    occCountX (subZeroInfo s) e $ \env1 mke -> +    k env1 $ \env' -> +      EZero ext (applySubstrucM s t) (mke env') +    where +      subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) +      subZeroInfo SsFull = SsFull +      subZeroInfo SsNone = SsNone +      subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) +      subZeroInfo SsEither{} = error "Either is not a monoid" +      subZeroInfo SsLEither{} = SsNone +      subZeroInfo SsMaybe{} = SsNone +      subZeroInfo (SsArr s') = SsArr (subZeroInfo s') +      subZeroInfo SsAccum{} = error "Accum is not a monoid" + +  EDeepZero _ t e -> +    occCountX (subDeepZeroInfo s) e $ \env1 mke -> +    k env1 $ \env' -> +      EDeepZero ext (applySubstrucM s t) (mke env') +    where +      subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) +      subDeepZeroInfo SsFull = SsFull +      subDeepZeroInfo SsNone = SsNone +      subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) +      subDeepZeroInfo SsEither{} = error "Either is not a monoid" +      subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) +      subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') +      subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') +      subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" + +  EPlus _ t a b -> +    occCountX s a $ \env1 mka -> +    occCountX s b $ \env2 mkb -> +    withSome (Some env1 <> Some env2) $ \env -> +    k env $ \env' -> +      EPlus ext (applySubstrucM s t) (mka env') (mkb env') + +  EOneHot _ t p a b -> +    occCountX SsFull a $ \env1 mka -> +    occCountX SsFull b $ \env2 mkb ->  -- TODO: do better +    withSome (Some env1 <> Some env2) $ \env -> +    k env $ \env' -> +      projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') + +  EError _ t msg -> +    k OccEnd $ \_ -> EError ext (applySubstruc s t) msg + +  _ -> error "occCountX: TODO unimplemented" +  where +    s = simplifySubstruc (typeOf topexpr) initialS -deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil OccEnd k = k SETop -deleteUnused (_ `SCons` env) OccEnd k = -  deleteUnused env OccEnd $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = -  deleteUnused env occenv $ \sub -> +deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil (Some OccEnd) k = k SETop +deleteUnused (_ `SCons` env) (Some OccEnd) k = +  deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = +  deleteUnused env (Some occenv) $ \sub ->      case count of Zero -> k (SENo sub)                    _    -> k (SEYesR sub) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index fef9686..9018602 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -6,6 +6,7 @@  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-}  module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs index 93258b7..2a29799 100644 --- a/src/AST/Sparse.hs +++ b/src/AST/Sparse.hs @@ -85,9 +85,6 @@ withInj2 (Inj f) (Inj g) k = Inj (k f g)  withInj2 Noinj _ _ = Noinj  withInj2 _ Noinj _ = Noinj -use :: Ex env a -> Ex env b -> Ex env b -use a b = elet a $ weakenExpr WSink b -  -- | This function produces quadratically-sized code in the presence of nested  -- dynamic sparsity. TODO can this be improved?  sparsePlusS diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index d882e28..3a97fd1 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -19,6 +19,7 @@ module AST.Weaken (module AST.Weaken, Append) where  import Data.Bifunctor (first)  import Data.Functor.Const +import Data.GADT.Compare  import Data.Kind (Type)  import Data @@ -31,6 +32,11 @@ data Idx env t where    IS :: Idx env t -> Idx (a : env) t  deriving instance Show (Idx env t) +instance GEq (Idx env) where +  geq IZ IZ = Just Refl +  geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl +  geq _ _ = Nothing +  splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)  splitIdx SNil i = Right i  splitIdx (SCons _ _) IZ = Left IZ | 
