From c06b4bd71a94601d467b509a26c08020d1fbd794 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 28 Mar 2025 22:40:41 +0100 Subject: Pass around an accumMap (but it's empty still) --- src/AST.hs | 45 ++++++++++++++++++----- src/AST/Weaken.hs | 9 +++++ src/Analysis/Identity.hs | 5 +++ src/CHAD.hs | 73 +++++++++++++++++++------------------ src/CHAD/Top.hs | 6 ++-- src/Data/VarMap.hs | 93 ++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 185 insertions(+), 46 deletions(-) create mode 100644 src/Data/VarMap.hs (limited to 'src') diff --git a/src/AST.hs b/src/AST.hs index c8377de..652d003 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -246,6 +246,42 @@ extOf = \case EOneHot x _ _ _ _ -> x EError x _ _ -> x +mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt f = \case + EVar x t i -> EVar (f x) t i + ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body) + EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b) + EFst x e -> EFst (f x) (mapExt f e) + ESnd x e -> ESnd (f x) (mapExt f e) + ENil x -> ENil (f x) + EInl x t e -> EInl (f x) t (mapExt f e) + EInr x t e -> EInr (f x) t (mapExt f e) + ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b) + ENothing x t -> ENothing (f x) t + EJust x e -> EJust (f x) (mapExt f e) + EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) + EConstArr x n t a -> EConstArr (f x) n t a + EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) + EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) + ESum1Inner x e -> ESum1Inner (f x) (mapExt f e) + EUnit x e -> EUnit (f x) (mapExt f e) + EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b) + EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e) + EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e) + EConst x t v -> EConst (f x) t v + EIdx0 x e -> EIdx0 (f x) (mapExt f e) + EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b) + EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es) + EShape x e -> EShape (f x) (mapExt f e) + EOp x op e -> EOp (f x) op (mapExt f e) + ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) + EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) + EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) + EZero x t -> EZero (f x) t + EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) + EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) + EError x t s -> EError (f x) t s + subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t subst1 repl = subst $ \x t -> \case IZ -> repl IS i -> EVar x t i @@ -302,15 +338,6 @@ subst' f w = \case weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -slistIdx :: SList f list -> Idx list t -> f t -slistIdx (SCons x _) IZ = x -slistIdx (SCons _ list) (IS i) = slistIdx list i -slistIdx SNil i = case i of {} - -idx2int :: Idx env t -> Int -idx2int IZ = 0 -idx2int (IS n) = 1 + idx2int n - class KnownScalTy t where knownScalTy :: SScalTy t instance KnownScalTy TI32 where knownScalTy = STI32 instance KnownScalTy TI64 where knownScalTy = STI64 diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index dbb37f7..bd2c244 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -36,6 +36,15 @@ splitIdx SNil i = Right i splitIdx (SCons _ _) IZ = Left IZ splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) +slistIdx :: SList f list -> Idx list t -> f t +slistIdx (SCons x _) IZ = x +slistIdx (SCons _ list) (IS i) = slistIdx list i +slistIdx SNil i = case i of {} + +idx2int :: Idx env t -> Int +idx2int IZ = 0 +idx2int (IS n) = 1 + idx2int n + data env :> env' where WId :: env :> env WSink :: forall t env. env :> (t : env) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 095d0fa..54f7cd2 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -3,7 +3,9 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} module Analysis.Identity ( + ValId(..), identityAnalysis, + identityAnalysis', ) where import Data.Foldable (toList) @@ -50,6 +52,9 @@ identityAnalysis env term = runIdGen 0 $ do env' <- slistMapA genIds env snd <$> idana env' term +identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t +identityAnalysis' env term = snd (runIdGen 0 (idana env term)) + idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t) idana env expr = case expr of EVar _ t i -> do diff --git a/src/CHAD.hs b/src/CHAD.hs index be308cd..6a4d5f5 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -36,6 +36,7 @@ import Data.Type.Bool (If) import Data.Type.Equality (type (==)) import GHC.Stack (HasCallStack) +import Analysis.Identity (ValId(..)) import AST import AST.Bindings import AST.Count @@ -45,6 +46,8 @@ import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data +import qualified Data.VarMap as VarMap +import Data.VarMap (VarMap) import Lemmas @@ -558,9 +561,9 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = drev :: forall env sto t. (?config :: CHADConfig) - => Descr env sto - -> Ex env t -> Ret env sto t -drev des = \case + => Descr env sto -> VarMap Int env + -> Expr ValId env t -> Ret env sto t +drev des accumMap = \case EVar _ t i -> case conv2Idx des i of Idx2Ac accI -> @@ -584,10 +587,10 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - ELet _ (rhs :: Ex _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs + ELet _ (rhs :: Expr _ _ a) body + | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> @@ -613,7 +616,7 @@ drev des = \case EPair _ a b | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil + <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> Ret binds @@ -632,7 +635,7 @@ drev des = \case (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STPair t1 t2 <- typeOf e -> Ret e0 subtape @@ -642,7 +645,7 @@ drev des = \case weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STPair t1 t2 <- typeOf e -> Ret e0 subtape @@ -654,7 +657,7 @@ drev des = \case ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> Ret e0 subtape (EInl ext (d1 t2) e1) @@ -667,7 +670,7 @@ drev des = \case (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> Ret e0 subtape (EInr ext (d1 t1) e1) @@ -679,13 +682,13 @@ drev des = \case (weakenExpr (WCopy (wSinks' @[_,_])) e2)) (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) - ECase _ e (a :: Ex _ t) b + ECase _ e (a :: Expr _ _ t) b | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) @@ -762,7 +765,7 @@ drev des = \case (ENil ext) EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> case d2op op of Linear d2opfun -> Ret e0 @@ -783,15 +786,15 @@ drev des = \case ECustom _ _ _ storety _ pr du a b -- allowed to ignore a2 because 'a' is the part of the input that is inactive | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil -> + <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> Ret (binds `BPush` (typeOf a1, a1) `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr) + `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) (SEYes (SENo (SENo (SENo subtape)))) (EFst ext (EVar ext (typeOf pr) (IS IZ))) bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $ + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ weakenExpr (WCopy (WSink .> WSink)) b2) EError _ t s -> @@ -808,8 +811,8 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result + EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) + | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> @@ -817,7 +820,7 @@ drev des = \case let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in @@ -881,7 +884,7 @@ drev des = \case }} EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des e -> + | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> Ret e0 subtape (EUnit ext e1) @@ -895,7 +898,7 @@ drev des = \case EReplicate1Inner _ en e -- We're allowed to ignore en2 here because the output of 'ei' is discrete. | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil + <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil , let STArr ndim eltty = typeOf e -> Ret binds subtape @@ -911,7 +914,7 @@ drev des = \case (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STArr _ t <- typeOf e -> Ret e0 subtape @@ -925,7 +928,7 @@ drev des = \case EIdx1 _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) @@ -942,7 +945,7 @@ drev des = \case EIdx _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil , STArr n eltty <- typeOf e , Refl <- indexTupD1Id n , let tIxN = tTup (sreplicate n tIx) -> @@ -962,7 +965,7 @@ drev des = \case EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des e + | Ret e0 subtape e1 _ _ <- drev des accumMap e , STArr n _ <- typeOf e , Refl <- indexTupD1Id n -> Ret e0 @@ -972,7 +975,7 @@ drev des = \case (ENil ext) ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , STArr (SS n) t <- typeOf e -> Ret (e0 `BPush` (STArr (SS n) t, e1) `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) @@ -1010,9 +1013,9 @@ drev des = \case deriv_extremum :: ScalIsNumeric t' ~ True => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) + -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des e + | Ret e0 subtape e1 sub e2 <- drev des accumMap e , at@(STArr (SS n) t@(STScal st)) <- typeOf e , let at' = STArr n t , let tIxN = tTup (sreplicate (SS n) tIx) = @@ -1052,11 +1055,11 @@ deriving instance Show (RetScoped env0 sto a s t) drevScoped :: forall a s env sto t. (?config :: CHADConfig) - => Descr env sto -> STy a -> Storage s - -> Ex (a : env) t + => Descr env sto -> VarMap Int env -> STy a -> Storage s + -> Expr ValId (a : env) t -> RetScoped env sto a s t -drevScoped des argty argsto expr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr +drevScoped des accumMap argty argsto expr + | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr = case argsto of SMerge -> case sub of diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index d058132..ced7550 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -10,6 +10,7 @@ {-# LANGUAGE TypeOperators #-} module CHAD.Top where +import Analysis.Identity import AST import AST.Weaken.Auto import CHAD @@ -17,6 +18,7 @@ import CHAD.Accum import CHAD.EnvDescr import CHAD.Types import Data +import qualified Data.VarMap as VarMap type family MergeEnv env where @@ -85,7 +87,7 @@ chad config env (term :: Ex env t) &. #tl (d1e env)) (#d :++: #acenv :++: #tl) (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr term)) $ + freezeRet descr (drev descr VarMap.empty (identityAnalysis env term))) $ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) (ESnd ext (EFst ext (EVar ext tvar IZ))))) @@ -93,7 +95,7 @@ chad config env (term :: Ex env t) | False <- chcArgArrayAccum config , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (identityAnalysis env term)) chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) chad' config env term diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs new file mode 100644 index 0000000..16c2d27 --- /dev/null +++ b/src/Data/VarMap.hs @@ -0,0 +1,93 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module Data.VarMap ( + VarMap, + empty, + insert, + delete, + TypedIdx(..), + lookup, + sink1, + unsink1, + subMap, +) where + +import Prelude hiding (lookup) + +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (mapMaybe) +import Data.Some +import qualified Data.Vector.Storable as VS +import Unsafe.Coerce + +import AST.Env +import AST.Types +import AST.Weaken + + +type role VarMap _ nominal -- ensure that 'env' is not phantom +data VarMap k (env :: [Ty]) = + VarMap Int -- ^ Global offset; must be added to any value in the map in order to get the proper index + Int -- ^ Time since last cleanup + (Map k (Some STy, Int)) +deriving instance Show k => Show (VarMap k env) + +empty :: VarMap k env +empty = VarMap 0 0 Map.empty + +insert :: Ord k => k -> STy t -> Idx env t -> VarMap k env -> VarMap k env +insert k ty idx (VarMap off interval mp) = + maybeCleanup $ VarMap off (interval + 1) (Map.insert k (Some ty, idx2int idx - off) mp) + +delete :: Ord k => k -> VarMap k env -> VarMap k env +delete k (VarMap off interval mp) = + maybeCleanup $ VarMap off (interval + 1) (Map.delete k mp) + +data TypedIdx env t = TypedIdx (STy t) (Idx env t) + deriving (Show) + +lookup :: Ord k => k -> VarMap k env -> Maybe (Some (TypedIdx env)) +lookup k (VarMap off _ mp) = do + (Some ty, i) <- Map.lookup k mp + idx <- unsafeInt2idx (i + off) + return (Some (TypedIdx ty idx)) + +sink1 :: VarMap k env -> VarMap k (t : env) +sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp + +unsink1 :: VarMap k (t : env) -> VarMap k env +unsink1 (VarMap off interval mp) = VarMap (off - 1) interval mp + +subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' +subMap subenv = + let bools = let loop :: Subenv env env' -> [Bool] + loop SETop = [] + loop (SEYes sub) = True : loop sub + loop (SENo sub) = False : loop sub + in VS.fromList $ loop subenv + newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools + modify off (k, (ty, i)) + | i + off < 0 = Nothing + | i + off >= VS.length bools = error "VarMap.subMap: found negative indices in map" + | bools VS.! (i + off) = Just (k, (ty, newIndices VS.! (i + off))) + | otherwise = Nothing + in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) + +maybeCleanup :: VarMap k env -> VarMap k env +maybeCleanup (VarMap off interval mp) + | let sz = Map.size mp + , sz > 0, 2 * interval >= 3 * sz + = VarMap off 0 (Map.filter (\(_, i) -> i + off >= 0) mp) +maybeCleanup vm = vm + +unsafeInt2idx :: Int -> Maybe (Idx env t) +unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n) + where + go :: Int -> Idx env t + go 0 = unsafeCoerce IZ + go n = unsafeCoerce (IS (go (n-1))) -- cgit v1.2.3-70-g09d2