summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs45
-rw-r--r--src/AST/Weaken.hs9
-rw-r--r--src/Analysis/Identity.hs5
-rw-r--r--src/CHAD.hs73
-rw-r--r--src/CHAD/Top.hs6
-rw-r--r--src/Compile.hs9
-rw-r--r--src/Compile/Exec.hs3
-rw-r--r--src/Data/VarMap.hs93
8 files changed, 193 insertions, 50 deletions
diff --git a/src/AST.hs b/src/AST.hs
index c8377de..652d003 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -246,6 +246,42 @@ extOf = \case
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
+mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
+mapExt f = \case
+ EVar x t i -> EVar (f x) t i
+ ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body)
+ EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b)
+ EFst x e -> EFst (f x) (mapExt f e)
+ ESnd x e -> ESnd (f x) (mapExt f e)
+ ENil x -> ENil (f x)
+ EInl x t e -> EInl (f x) t (mapExt f e)
+ EInr x t e -> EInr (f x) t (mapExt f e)
+ ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b)
+ ENothing x t -> ENothing (f x) t
+ EJust x e -> EJust (f x) (mapExt f e)
+ EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e)
+ EConstArr x n t a -> EConstArr (f x) n t a
+ EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b)
+ EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c)
+ ESum1Inner x e -> ESum1Inner (f x) (mapExt f e)
+ EUnit x e -> EUnit (f x) (mapExt f e)
+ EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b)
+ EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e)
+ EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e)
+ EConst x t v -> EConst (f x) t v
+ EIdx0 x e -> EIdx0 (f x) (mapExt f e)
+ EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b)
+ EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es)
+ EShape x e -> EShape (f x) (mapExt f e)
+ EOp x op e -> EOp (f x) op (mapExt f e)
+ ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2)
+ EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2)
+ EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3)
+ EZero x t -> EZero (f x) t
+ EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b)
+ EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b)
+ EError x t s -> EError (f x) t s
+
subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t
subst1 repl = subst $ \x t -> \case IZ -> repl
IS i -> EVar x t i
@@ -302,15 +338,6 @@ subst' f w = \case
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-slistIdx :: SList f list -> Idx list t -> f t
-slistIdx (SCons x _) IZ = x
-slistIdx (SCons _ list) (IS i) = slistIdx list i
-slistIdx SNil i = case i of {}
-
-idx2int :: Idx env t -> Int
-idx2int IZ = 0
-idx2int (IS n) = 1 + idx2int n
-
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index dbb37f7..bd2c244 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -36,6 +36,15 @@ splitIdx SNil i = Right i
splitIdx (SCons _ _) IZ = Left IZ
splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
+slistIdx :: SList f list -> Idx list t -> f t
+slistIdx (SCons x _) IZ = x
+slistIdx (SCons _ list) (IS i) = slistIdx list i
+slistIdx SNil i = case i of {}
+
+idx2int :: Idx env t -> Int
+idx2int IZ = 0
+idx2int (IS n) = 1 + idx2int n
+
data env :> env' where
WId :: env :> env
WSink :: forall t env. env :> (t : env)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 095d0fa..54f7cd2 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -3,7 +3,9 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Analysis.Identity (
+ ValId(..),
identityAnalysis,
+ identityAnalysis',
) where
import Data.Foldable (toList)
@@ -50,6 +52,9 @@ identityAnalysis env term = runIdGen 0 $ do
env' <- slistMapA genIds env
snd <$> idana env' term
+identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t
+identityAnalysis' env term = snd (runIdGen 0 (idana env term))
+
idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
idana env expr = case expr of
EVar _ t i -> do
diff --git a/src/CHAD.hs b/src/CHAD.hs
index be308cd..6a4d5f5 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -36,6 +36,7 @@ import Data.Type.Bool (If)
import Data.Type.Equality (type (==))
import GHC.Stack (HasCallStack)
+import Analysis.Identity (ValId(..))
import AST
import AST.Bindings
import AST.Count
@@ -45,6 +46,8 @@ import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
+import qualified Data.VarMap as VarMap
+import Data.VarMap (VarMap)
import Lemmas
@@ -558,9 +561,9 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
drev :: forall env sto t.
(?config :: CHADConfig)
- => Descr env sto
- -> Ex env t -> Ret env sto t
-drev des = \case
+ => Descr env sto -> VarMap Int env
+ -> Expr ValId env t -> Ret env sto t
+drev des accumMap = \case
EVar _ t i ->
case conv2Idx des i of
Idx2Ac accI ->
@@ -584,10 +587,10 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- ELet _ (rhs :: Ex _ a) body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
+ ELet _ (rhs :: Expr _ _ a) body
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs
, ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des (typeOf rhs) storage body
+ , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
, Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
@@ -613,7 +616,7 @@ drev des = \case
EPair _ a b
| Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil
+ <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
@@ -632,7 +635,7 @@ drev des = \case
(EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
EFst _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STPair t1 t2 <- typeOf e ->
Ret e0
subtape
@@ -642,7 +645,7 @@ drev des = \case
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STPair t1 t2 <- typeOf e ->
Ret e0
subtape
@@ -654,7 +657,7 @@ drev des = \case
ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
EInl _ t2 e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
Ret e0
subtape
(EInl ext (d1 t2) e1)
@@ -667,7 +670,7 @@ drev des = \case
(EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))
EInr _ t1 e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
Ret e0
subtape
(EInr ext (d1 t1) e1)
@@ -679,13 +682,13 @@ drev des = \case
(weakenExpr (WCopy (wSinks' @[_,_])) e2))
(EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ))
- ECase _ e (a :: Ex _ t) b
+ ECase _ e (a :: Expr _ _ t) b
| STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e
+ , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e
, ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
, ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
- , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des t1 storage1 a
- , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des t2 storage2 b
+ , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a
+ , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b
, Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA)
@@ -762,7 +765,7 @@ drev des = \case
(ENil ext)
EOp _ op e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
case d2op op of
Linear d2opfun ->
Ret e0
@@ -783,15 +786,15 @@ drev des = \case
ECustom _ _ _ storety _ pr du a b
-- allowed to ignore a2 because 'a' is the part of the input that is inactive
| Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil)
- <- retConcat des $ drev des a `SCons` drev des b `SCons` SNil ->
+ <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil ->
Ret (binds `BPush` (typeOf a1, a1)
`BPush` (typeOf b1, weakenExpr WSink b1)
- `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) pr)
+ `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr))
`BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ)))
(SEYes (SENo (SENo (SENo subtape))))
(EFst ext (EVar ext (typeOf pr) (IS IZ)))
bsub
- (ELet ext (weakenExpr (WCopy (WCopy WClosed)) du) $
+ (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $
weakenExpr (WCopy (WSink .> WSink)) b2)
EError _ t s ->
@@ -808,8 +811,8 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)
- | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result
+ EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty)
+ | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
@@ -817,7 +820,7 @@ drev des = \case
let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
- case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
+ case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
let collectexpr = bindingsCollect e0 subtapeE in
@@ -881,7 +884,7 @@ drev des = \case
}}
EUnit _ e
- | Ret e0 subtape e1 sub e2 <- drev des e ->
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e ->
Ret e0
subtape
(EUnit ext e1)
@@ -895,7 +898,7 @@ drev des = \case
EReplicate1Inner _ en e
-- We're allowed to ignore en2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil)
- <- retConcat des $ drev des en `SCons` drev des e `SCons` SNil
+ <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil
, let STArr ndim eltty = typeOf e ->
Ret binds
subtape
@@ -911,7 +914,7 @@ drev des = \case
(EVar ext (d2 (STArr (SS ndim) eltty)) IZ))
EIdx0 _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STArr _ t <- typeOf e ->
Ret e0
subtape
@@ -925,7 +928,7 @@ drev des = \case
EIdx1 _ e ei
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
, STArr (SS n) eltty <- typeOf e ->
Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)))
@@ -942,7 +945,7 @@ drev des = \case
EIdx _ e ei
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
- <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil
, STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n
, let tIxN = tTup (sreplicate n tIx) ->
@@ -962,7 +965,7 @@ drev des = \case
EShape _ e
-- Allowed to ignore e2 here because the output of EShape is discrete,
-- hence we'd be passing a zero cotangent to e2 anyway.
- | Ret e0 subtape e1 _ _ <- drev des e
+ | Ret e0 subtape e1 _ _ <- drev des accumMap e
, STArr n _ <- typeOf e
, Refl <- indexTupD1Id n ->
Ret e0
@@ -972,7 +975,7 @@ drev des = \case
(ENil ext)
ESum1Inner _ e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, STArr (SS n) t <- typeOf e ->
Ret (e0 `BPush` (STArr (SS n) t, e1)
`BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ)))
@@ -1010,9 +1013,9 @@ drev des = \case
deriv_extremum :: ScalIsNumeric t' ~ True
=> (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
- -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
+ -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
deriv_extremum extremum e
- | Ret e0 subtape e1 sub e2 <- drev des e
+ | Ret e0 subtape e1 sub e2 <- drev des accumMap e
, at@(STArr (SS n) t@(STScal st)) <- typeOf e
, let at' = STArr n t
, let tIxN = tTup (sreplicate (SS n) tIx) =
@@ -1052,11 +1055,11 @@ deriving instance Show (RetScoped env0 sto a s t)
drevScoped :: forall a s env sto t.
(?config :: CHADConfig)
- => Descr env sto -> STy a -> Storage s
- -> Ex (a : env) t
+ => Descr env sto -> VarMap Int env -> STy a -> Storage s
+ -> Expr ValId (a : env) t
-> RetScoped env sto a s t
-drevScoped des argty argsto expr
- | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) expr
+drevScoped des accumMap argty argsto expr
+ | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr
= case argsto of
SMerge ->
case sub of
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index d058132..ced7550 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -10,6 +10,7 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
+import Analysis.Identity
import AST
import AST.Weaken.Auto
import CHAD
@@ -17,6 +18,7 @@ import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
+import qualified Data.VarMap as VarMap
type family MergeEnv env where
@@ -85,7 +87,7 @@ chad config env (term :: Ex env t)
&. #tl (d1e env))
(#d :++: #acenv :++: #tl)
(#acenv :++: #d :++: #tl)) $
- freezeRet descr (drev descr term)) $
+ freezeRet descr (drev descr VarMap.empty (identityAnalysis env term))) $
EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
(reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
(ESnd ext (EFst ext (EVar ext tvar IZ)))))
@@ -93,7 +95,7 @@ chad config env (term :: Ex env t)
| False <- chcArgArrayAccum config
, Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term)
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (identityAnalysis env term))
chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term
diff --git a/src/Compile.hs b/src/Compile.hs
index b4261ca..7bbb043 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1082,9 +1082,12 @@ compile' env = \case
compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String
compileAssign prefix env e = do
e' <- compile' env e
- name <- genName' prefix
- emit $ SVarDecl True (repSTy (typeOf e)) name e'
- return name
+ case e' of
+ CELit name -> return name
+ _ -> do
+ name <- genName' prefix
+ emit $ SVarDecl True (repSTy (typeOf e)) name e'
+ return name
data Increment = Increment | Decrement
deriving (Show)
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
index 9b29486..5f90ea2 100644
--- a/src/Compile/Exec.hs
+++ b/src/Compile/Exec.hs
@@ -40,7 +40,8 @@ buildKernel csource funnames = do
,"-std=c99", "-x", "c"
,"-o", outso, "-"
,"-Wall", "-Wextra"
- ,"-Wno-unused-variable", "-Wno-unused-parameter", "-Wno-unused-function"]
+ ,"-Wno-unused-variable", "-Wno-unused-but-set-variable"
+ ,"-Wno-unused-parameter", "-Wno-unused-function"]
(ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource
-- Print the source before the GCC output.
diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs
new file mode 100644
index 0000000..16c2d27
--- /dev/null
+++ b/src/Data/VarMap.hs
@@ -0,0 +1,93 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+module Data.VarMap (
+ VarMap,
+ empty,
+ insert,
+ delete,
+ TypedIdx(..),
+ lookup,
+ sink1,
+ unsink1,
+ subMap,
+) where
+
+import Prelude hiding (lookup)
+
+import qualified Data.Map.Strict as Map
+import Data.Map.Strict (Map)
+import Data.Maybe (mapMaybe)
+import Data.Some
+import qualified Data.Vector.Storable as VS
+import Unsafe.Coerce
+
+import AST.Env
+import AST.Types
+import AST.Weaken
+
+
+type role VarMap _ nominal -- ensure that 'env' is not phantom
+data VarMap k (env :: [Ty]) =
+ VarMap Int -- ^ Global offset; must be added to any value in the map in order to get the proper index
+ Int -- ^ Time since last cleanup
+ (Map k (Some STy, Int))
+deriving instance Show k => Show (VarMap k env)
+
+empty :: VarMap k env
+empty = VarMap 0 0 Map.empty
+
+insert :: Ord k => k -> STy t -> Idx env t -> VarMap k env -> VarMap k env
+insert k ty idx (VarMap off interval mp) =
+ maybeCleanup $ VarMap off (interval + 1) (Map.insert k (Some ty, idx2int idx - off) mp)
+
+delete :: Ord k => k -> VarMap k env -> VarMap k env
+delete k (VarMap off interval mp) =
+ maybeCleanup $ VarMap off (interval + 1) (Map.delete k mp)
+
+data TypedIdx env t = TypedIdx (STy t) (Idx env t)
+ deriving (Show)
+
+lookup :: Ord k => k -> VarMap k env -> Maybe (Some (TypedIdx env))
+lookup k (VarMap off _ mp) = do
+ (Some ty, i) <- Map.lookup k mp
+ idx <- unsafeInt2idx (i + off)
+ return (Some (TypedIdx ty idx))
+
+sink1 :: VarMap k env -> VarMap k (t : env)
+sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp
+
+unsink1 :: VarMap k (t : env) -> VarMap k env
+unsink1 (VarMap off interval mp) = VarMap (off - 1) interval mp
+
+subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env'
+subMap subenv =
+ let bools = let loop :: Subenv env env' -> [Bool]
+ loop SETop = []
+ loop (SEYes sub) = True : loop sub
+ loop (SENo sub) = False : loop sub
+ in VS.fromList $ loop subenv
+ newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools
+ modify off (k, (ty, i))
+ | i + off < 0 = Nothing
+ | i + off >= VS.length bools = error "VarMap.subMap: found negative indices in map"
+ | bools VS.! (i + off) = Just (k, (ty, newIndices VS.! (i + off)))
+ | otherwise = Nothing
+ in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp)
+
+maybeCleanup :: VarMap k env -> VarMap k env
+maybeCleanup (VarMap off interval mp)
+ | let sz = Map.size mp
+ , sz > 0, 2 * interval >= 3 * sz
+ = VarMap off 0 (Map.filter (\(_, i) -> i + off >= 0) mp)
+maybeCleanup vm = vm
+
+unsafeInt2idx :: Int -> Maybe (Idx env t)
+unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n)
+ where
+ go :: Int -> Idx env t
+ go 0 = unsafeCoerce IZ
+ go n = unsafeCoerce (IS (go (n-1)))