diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-28 22:40:41 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-28 22:40:41 +0100 | 
| commit | c06b4bd71a94601d467b509a26c08020d1fbd794 (patch) | |
| tree | b16981c769231ef4af2c3ec5f002a01f857d95c6 /src | |
| parent | a3ba3bdc5c2f9606a0b98cdf53183841cca07eac (diff) | |
Pass around an accumMap (but it's empty still)
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 45 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 9 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 5 | ||||
| -rw-r--r-- | src/CHAD.hs | 73 | ||||
| -rw-r--r-- | src/CHAD/Top.hs | 6 | ||||
| -rw-r--r-- | src/Data/VarMap.hs | 93 | 
6 files changed, 185 insertions, 46 deletions
| @@ -246,6 +246,42 @@ extOf = \case    EOneHot x _ _ _ _ -> x    EError x _ _ -> x +mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt f = \case +  EVar x t i -> EVar (f x) t i +  ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body) +  EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b) +  EFst x e -> EFst (f x) (mapExt f e) +  ESnd x e -> ESnd (f x) (mapExt f e) +  ENil x -> ENil (f x) +  EInl x t e -> EInl (f x) t (mapExt f e) +  EInr x t e -> EInr (f x) t (mapExt f e) +  ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b) +  ENothing x t -> ENothing (f x) t +  EJust x e -> EJust (f x) (mapExt f e) +  EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) +  EConstArr x n t a -> EConstArr (f x) n t a +  EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) +  EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) +  ESum1Inner x e -> ESum1Inner (f x) (mapExt f e) +  EUnit x e -> EUnit (f x) (mapExt f e) +  EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b) +  EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e) +  EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e) +  EConst x t v -> EConst (f x) t v +  EIdx0 x e -> EIdx0 (f x) (mapExt f e) +  EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b) +  EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es) +  EShape x e -> EShape (f x) (mapExt f e) +  EOp x op e -> EOp (f x) op (mapExt f e) +  ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) +  EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) +  EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) +  EZero x t -> EZero (f x) t +  EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) +  EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) +  EError x t s -> EError (f x) t s +  subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t  subst1 repl = subst $ \x t -> \case IZ -> repl                                      IS i -> EVar x t i @@ -302,15 +338,6 @@ subst' f w = \case  weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t  weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -slistIdx :: SList f list -> Idx list t -> f t -slistIdx (SCons x _) IZ = x -slistIdx (SCons _ list) (IS i) = slistIdx list i -slistIdx SNil i = case i of {} - -idx2int :: Idx env t -> Int -idx2int IZ = 0 -idx2int (IS n) = 1 + idx2int n -  class KnownScalTy t where knownScalTy :: SScalTy t  instance KnownScalTy TI32 where knownScalTy = STI32  instance KnownScalTy TI64 where knownScalTy = STI64 diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index dbb37f7..bd2c244 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -36,6 +36,15 @@ splitIdx SNil i = Right i  splitIdx (SCons _ _) IZ = Left IZ  splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) +slistIdx :: SList f list -> Idx list t -> f t +slistIdx (SCons x _) IZ = x +slistIdx (SCons _ list) (IS i) = slistIdx list i +slistIdx SNil i = case i of {} + +idx2int :: Idx env t -> Int +idx2int IZ = 0 +idx2int (IS n) = 1 + idx2int n +  data env :> env' where    WId :: env :> env    WSink :: forall t env. env :> (t : env) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 095d0fa..54f7cd2 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -3,7 +3,9 @@  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ScopedTypeVariables #-}  module Analysis.Identity ( +  ValId(..),    identityAnalysis, +  identityAnalysis',  ) where  import Data.Foldable (toList) @@ -50,6 +52,9 @@ identityAnalysis env term = runIdGen 0 $ do    env' <- slistMapA genIds env    snd <$> idana env' term +identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t +identityAnalysis' env term = snd (runIdGen 0 (idana env term)) +  idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)  idana env expr = case expr of    EVar _ t i -> do diff --git a/src/CHAD.hs b/src/CHAD.hs index be308cd..6a4d5f5 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -36,6 +36,7 @@ import Data.Type.Bool (If)  import Data.Type.Equality (type (==))  import GHC.Stack (HasCallStack) +import Analysis.Identity (ValId(..))  import AST  import AST.Bindings  import AST.Count @@ -45,6 +46,8 @@ import CHAD.Accum  import CHAD.EnvDescr  import CHAD.Types  import Data +import qualified Data.VarMap as VarMap +import Data.VarMap (VarMap)  import Lemmas @@ -558,9 +561,9 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =  drev :: forall env sto t.          (?config :: CHADConfig) -     => Descr env sto -     -> Ex env t -> Ret env sto t -drev des = \case +     => Descr env sto -> VarMap Int env +     -> Expr ValId env t -> Ret env sto t +drev des accumMap = \case    EVar _ t i ->      case conv2Idx des i of        Idx2Ac accI -> @@ -584,10 +587,10 @@ drev des = \case              (subenvNone (select SMerge des))              (ENil ext) -  ELet _ (rhs :: Ex _ a) body -    | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs +  ELet _ (rhs :: Expr _ _ a) body +    | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs      , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge -    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body +    , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body      , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0      , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)      , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> @@ -613,7 +616,7 @@ drev des = \case    EPair _ a b      | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -        <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil +        <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil      , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->      subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->      Ret binds @@ -632,7 +635,7 @@ drev des = \case             (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))    EFst _ e -    | Ret e0 subtape e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e      , STPair t1 t2 <- typeOf e ->      Ret e0          subtape @@ -642,7 +645,7 @@ drev des = \case             weakenExpr (WCopy WSink) e2)    ESnd _ e -    | Ret e0 subtape e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e      , STPair t1 t2 <- typeOf e ->      Ret e0          subtape @@ -654,7 +657,7 @@ drev des = \case    ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)    EInl _ t2 e -    | Ret e0 subtape e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->      Ret e0          subtape          (EInl ext (d1 t2) e1) @@ -667,7 +670,7 @@ drev des = \case             (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))    EInr _ t1 e -    | Ret e0 subtape e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->      Ret e0          subtape          (EInr ext (d1 t1) e1) @@ -679,13 +682,13 @@ drev des = \case                (weakenExpr (WCopy (wSinks' @[_,_])) e2))             (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) -  ECase _ e (a :: Ex _ t) b +  ECase _ e (a :: Expr _ _ t) b      | STEither t1 t2 <- typeOf e -    , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e +    , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e      , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge      , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge -    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a -    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b +    , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a +    , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b      , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))      , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))      , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) @@ -762,7 +765,7 @@ drev des = \case          (ENil ext)    EOp _ op e -    | Ret e0 subtape e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->      case d2op op of        Linear d2opfun ->          Ret e0 @@ -783,15 +786,15 @@ drev des = \case    ECustom _ _ _ storety _ pr du a b      -- allowed to ignore a2 because 'a' is the part of the input that is inactive      | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) -        <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil -> +        <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil ->      Ret (binds `BPush` (typeOf a1, a1)                 `BPush` (typeOf b1, weakenExpr WSink b1) -               `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr) +               `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))                 `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))          (SEYes (SENo (SENo (SENo subtape))))          (EFst ext (EVar ext (typeOf pr) (IS IZ)))          bsub -        (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $ +        (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $             weakenExpr (WCopy (WSink .> WSink)) b2)    EError _ t s -> @@ -808,8 +811,8 @@ drev des = \case          (subenvNone (select SMerge des))          (ENil ext) -  EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) -    | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she  -- allowed to ignore she2 here because she has a discrete result +  EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) +    | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she  -- allowed to ignore she2 here because she has a discrete result      , let eltty = typeOf orige      , shty :: STy shty <- tTup (sreplicate ndim tIx)      , Refl <- indexTupD1Id ndim -> @@ -817,7 +820,7 @@ drev des = \case      let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in      subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->      accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> -    case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> +    case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->      case assertSubenvEmpty sub of { Refl ->      let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in      let collectexpr = bindingsCollect e0 subtapeE in @@ -881,7 +884,7 @@ drev des = \case      }}    EUnit _ e -    | Ret e0 subtape e1 sub e2 <- drev des e -> +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->      Ret e0          subtape          (EUnit ext e1) @@ -895,7 +898,7 @@ drev des = \case    EReplicate1Inner _ en e      -- We're allowed to ignore en2 here because the output of 'ei' is discrete.      | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) -        <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil +        <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil      , let STArr ndim eltty = typeOf e ->      Ret binds          subtape @@ -911,7 +914,7 @@ drev des = \case            (EVar ext (d2 (STArr (SS ndim) eltty)) IZ))    EIdx0 _ e -    | Ret e0 subtape e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e      , STArr _ t <- typeOf e ->      Ret e0          subtape @@ -925,7 +928,7 @@ drev des = \case    EIdx1 _ e ei      -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.      | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) -        <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil +        <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil      , STArr (SS n) eltty <- typeOf e ->      Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)                 `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) @@ -942,7 +945,7 @@ drev des = \case    EIdx _ e ei      -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.      | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) -        <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil +        <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil      , STArr n eltty <- typeOf e      , Refl <- indexTupD1Id n      , let tIxN = tTup (sreplicate n tIx)  -> @@ -962,7 +965,7 @@ drev des = \case    EShape _ e      -- Allowed to ignore e2 here because the output of EShape is discrete,      -- hence we'd be passing a zero cotangent to e2 anyway. -    | Ret e0 subtape e1 _ _ <- drev des e +    | Ret e0 subtape e1 _ _ <- drev des accumMap e      , STArr n _ <- typeOf e      , Refl <- indexTupD1Id n ->      Ret e0 @@ -972,7 +975,7 @@ drev des = \case          (ENil ext)    ESum1Inner _ e -    | Ret e0 subtape e1 sub e2 <- drev des e +    | Ret e0 subtape e1 sub e2 <- drev des accumMap e      , STArr (SS n) t <- typeOf e ->      Ret (e0 `BPush` (STArr (SS n) t, e1)              `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) @@ -1010,9 +1013,9 @@ drev des = \case      deriv_extremum :: ScalIsNumeric t' ~ True                     => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) -                   -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) +                   -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))      deriv_extremum extremum e -      | Ret e0 subtape e1 sub e2 <- drev des e +      | Ret e0 subtape e1 sub e2 <- drev des accumMap e        , at@(STArr (SS n) t@(STScal st)) <- typeOf e        , let at' = STArr n t        , let tIxN = tTup (sreplicate (SS n) tIx) = @@ -1052,11 +1055,11 @@ deriving instance Show (RetScoped env0 sto a s t)  drevScoped :: forall a s env sto t.                (?config :: CHADConfig) -           => Descr env sto -> STy a -> Storage s -           -> Ex (a : env) t +           => Descr env sto -> VarMap Int env -> STy a -> Storage s +           -> Expr ValId (a : env) t             -> RetScoped env sto a s t -drevScoped des argty argsto expr -  | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr +drevScoped des accumMap argty argsto expr +  | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr    = case argsto of        SMerge ->          case sub of diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index d058132..ced7550 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -10,6 +10,7 @@  {-# LANGUAGE TypeOperators #-}  module CHAD.Top where +import Analysis.Identity  import AST  import AST.Weaken.Auto  import CHAD @@ -17,6 +18,7 @@ import CHAD.Accum  import CHAD.EnvDescr  import CHAD.Types  import Data +import qualified Data.VarMap as VarMap  type family MergeEnv env where @@ -85,7 +87,7 @@ chad config env (term :: Ex env t)                                                  &. #tl (d1e env))                                                 (#d :++: #acenv :++: #tl)                                                 (#acenv :++: #d :++: #tl)) $ -                            freezeRet descr (drev descr term)) $ +                            freezeRet descr (drev descr VarMap.empty (identityAnalysis env term))) $                EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))                          (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))                                                          (ESnd ext (EFst ext (EVar ext tvar IZ))))) @@ -93,7 +95,7 @@ chad config env (term :: Ex env t)    | False <- chcArgArrayAccum config    , Refl <- mergeEnvNoAccum env    , Refl <- mergeEnvOnlyMerge env -  = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) +  = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (identityAnalysis env term))  chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))  chad' config env term diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs new file mode 100644 index 0000000..16c2d27 --- /dev/null +++ b/src/Data/VarMap.hs @@ -0,0 +1,93 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module Data.VarMap ( +  VarMap, +  empty, +  insert, +  delete, +  TypedIdx(..), +  lookup, +  sink1, +  unsink1, +  subMap, +) where + +import Prelude hiding (lookup) + +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (mapMaybe) +import Data.Some +import qualified Data.Vector.Storable as VS +import Unsafe.Coerce + +import AST.Env +import AST.Types +import AST.Weaken + + +type role VarMap _ nominal  -- ensure that 'env' is not phantom +data VarMap k (env :: [Ty]) = +  VarMap Int  -- ^ Global offset; must be added to any value in the map in order to get the proper index +         Int  -- ^ Time since last cleanup +         (Map k (Some STy, Int)) +deriving instance Show k => Show (VarMap k env) + +empty :: VarMap k env +empty = VarMap 0 0 Map.empty + +insert :: Ord k => k -> STy t -> Idx env t -> VarMap k env -> VarMap k env +insert k ty idx (VarMap off interval mp) = +  maybeCleanup $ VarMap off (interval + 1) (Map.insert k (Some ty, idx2int idx - off) mp) + +delete :: Ord k => k -> VarMap k env -> VarMap k env +delete k (VarMap off interval mp) = +  maybeCleanup $ VarMap off (interval + 1) (Map.delete k mp) + +data TypedIdx env t = TypedIdx (STy t) (Idx env t) +  deriving (Show) + +lookup :: Ord k => k -> VarMap k env -> Maybe (Some (TypedIdx env)) +lookup k (VarMap off _ mp) = do +  (Some ty, i) <- Map.lookup k mp +  idx <- unsafeInt2idx (i + off) +  return (Some (TypedIdx ty idx)) + +sink1 :: VarMap k env -> VarMap k (t : env) +sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp + +unsink1 :: VarMap k (t : env) -> VarMap k env +unsink1 (VarMap off interval mp) = VarMap (off - 1) interval mp + +subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' +subMap subenv = +  let bools = let loop :: Subenv env env' -> [Bool] +                  loop SETop = [] +                  loop (SEYes sub) = True : loop sub +                  loop (SENo sub) = False : loop sub +              in VS.fromList $ loop subenv +      newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools +      modify off (k, (ty, i)) +        | i + off < 0 = Nothing +        | i + off >= VS.length bools = error "VarMap.subMap: found negative indices in map" +        | bools VS.! (i + off) = Just (k, (ty, newIndices VS.! (i + off))) +        | otherwise = Nothing +  in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) + +maybeCleanup :: VarMap k env -> VarMap k env +maybeCleanup (VarMap off interval mp) +  | let sz = Map.size mp +  , sz > 0, 2 * interval >= 3 * sz +  = VarMap off 0 (Map.filter (\(_, i) -> i + off >= 0) mp) +maybeCleanup vm = vm + +unsafeInt2idx :: Int -> Maybe (Idx env t) +unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n) +  where +    go :: Int -> Idx env t +    go 0 = unsafeCoerce IZ +    go n = unsafeCoerce (IS (go (n-1))) | 
