summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs209
1 files changed, 108 insertions, 101 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 007ffe3..087a26e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -30,10 +30,12 @@ module CHAD (
import Data.Bifunctor (first, second)
import Data.Functor.Const
import Data.Kind (Type)
+import GHC.Stack (HasCallStack)
import GHC.TypeLits (Symbol)
import AST
import AST.Count
+import AST.Env
import AST.Weaken.Auto
import Data
import Lemmas
@@ -422,14 +424,6 @@ plusSparse t a b adder =
(EVar ext t (IS IZ))
(weakenExpr (WCopy (WCopy WSink)) adder)))
-type family Tup env where
- Tup '[] = TNil
- Tup (t : ts) = TPair (Tup ts) t
-
-tTup :: SList STy env -> STy (Tup env)
-tTup SNil = STNil
-tTup (SCons t ts) = STPair (tTup ts) t
-
zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
@@ -437,18 +431,20 @@ zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
accumPromote :: forall dt env sto proxy r.
proxy dt
-> Descr env sto
- -> OccEnv env
-> (forall stoRepl envPro.
- Descr env stoRepl
+ (Select env stoRepl "merge" ~ '[])
+ => Descr env stoRepl
-- ^ A revised environment description that switches
-- arrays (used in the OccEnv) that are currently on
- -- "merge" storage, to "accum" storage.
- -> Subenv (Select env sto "merge") (Select env stoRepl "merge")
- -- ^ The new storage has fewer "merge"-storage entries.
+ -- "merge" storage, to "accum" storage. Any other "merge"
+ -- entries are deleted.
-> SList STy envPro
-- ^ New entries on top of the original dual environment,
-- that house the accumulators for the promoted arrays in
-- the original environment.
+ -> Subenv (Select env sto "merge") envPro
+ -- ^ The promoted entries were merge entries in the
+ -- original environment.
-> (forall shbinds.
SList STy shbinds
-> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
@@ -458,16 +454,15 @@ accumPromote :: forall dt env sto proxy r.
-- extended with some accumulators.
-> r)
-> r
-accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId)
-accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId)
-accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
- accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf ->
- case (t, sto, occ) of
+accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId)
+accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf ->
+ case sto of
-- Accumulators are left as-is
- (_, SAccum, _) ->
+ SAccum ->
k (storepl `DPush` (t, SAccum))
- mergesub
envpro
+ prosub
(\shbinds ->
autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
@@ -477,34 +472,29 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
- -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro
- (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero ->
- k (storepl `DPush` (t, SAccum))
- (SENo mergesub)
- (STArr arrn arrt `SCons` envpro)
- (\(shbinds :: SList _ shbinds) ->
- let shbindsC = slistMap (\_ -> Const ()) shbinds
- in
- -- wf:
- -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WCopy wf:
- -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WPICK: ^ THESE TWO ||
- -- goal: | ARE EQUAL ||
- -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- WCopy (wf shbinds)
- .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
- (WId @(D2AcE (Select env1 stoRepl "accum"))))
-
- -- Used "merge" values must be an array, so reject everything else. (TODO: generalise this)
- (_, SMerge, Occ _ c)
- | c > Zero ->
- error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t
- | otherwise ->
- k (storepl `DPush` (t, SMerge))
- (SEYes mergesub)
- envpro
- wf
+ SMerge -> case t of
+ -- Arrays with "merge" storage are promoted to an accumulator in envPro
+ STArr (arrn :: SNat arrn) (arrt :: STy arrt) ->
+ k (storepl `DPush` (t, SAccum))
+ (STArr arrn arrt `SCons` envpro)
+ (SEYes prosub)
+ (\(shbinds :: SList _ shbinds) ->
+ let shbindsC = slistMap (\_ -> Const ()) shbinds
+ in
+ -- wf:
+ -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WCopy wf:
+ -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WPICK: ^ THESE TWO ||
+ -- goal: | ARE EQUAL ||
+ -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ WCopy (wf shbinds)
+ .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ (WId @(D2AcE (Select env1 stoRepl "accum"))))
+
+ -- "merge" values must be an array, so reject everything else. (TODO: generalise this)
+ _ ->
+ error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t
-- where
-- containsTArr :: STy t' -> Bool
-- containsTArr = \case
@@ -537,14 +527,6 @@ uninvertTup (t `SCons` list) tcore e =
(ESnd ext (EVar ext recT IZ))
(ESnd ext (EFst ext (EVar ext recT IZ))))
--- | @env'@ is a subset of @env@: each element of @env@ is either included in
--- @env'@ ('SEYes') or not included in @env'@ ('SENo').
-data Subenv env env' where
- SETop :: Subenv '[] '[]
- SEYes :: Subenv env env' -> Subenv (t : env) (t : env')
- SENo :: Subenv env env' -> Subenv (t : env) env'
-deriving instance Show (Subenv env env')
-
data Ret env0 sto t =
forall shbinds env0Merge.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
@@ -566,24 +548,6 @@ data Rets env0 sto env list =
(SList (RetPair env0 sto env shbinds) list)
deriving instance Show (Rets env0 sto env list)
-subList :: SList f env -> Subenv env env' -> SList f env'
-subList SNil SETop = SNil
-subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
-subList (SCons _ xs) (SENo sub) = subList xs sub
-
-subenvAll :: SList f env -> Subenv env env
-subenvAll SNil = SETop
-subenvAll (SCons _ env) = SEYes (subenvAll env)
-
-subenvNone :: SList f env -> Subenv env '[]
-subenvNone SNil = SETop
-subenvNone (SCons _ env) = SENo (subenvNone env)
-
-subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
-subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
-subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
-subenvOnehot SNil i = case i of {}
-
subenvPlus :: SList STy env
-> Subenv env env1 -> Subenv env env2
-> (forall env3. Subenv env env3
@@ -631,7 +595,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e =
in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
-assertSubenvEmpty :: Subenv env env' -> env' :~: '[]
+assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
@@ -748,6 +712,10 @@ data Descr env sto where
DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
deriving instance Show (Descr env sto)
+descrList :: Descr env sto -> SList STy env
+descrList DTop = SNil
+descrList (des `DPush` (t, _)) = t `SCons` descrList des
+
select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
select _ DTop = SNil
select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
@@ -755,6 +723,26 @@ select s@SMerge (DPush des (_, SAccum)) = select s des
select s@SAccum (DPush des (_, SMerge)) = select s des
select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
+-- | This could have more precise typing on the output storage.
+subDescr :: Descr env sto -> Subenv env env'
+ -> (forall sto'. Descr env' sto'
+ -> Subenv (Select env sto "merge") (Select env' sto' "merge")
+ -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ -> Subenv (D1E env) (D1E env')
+ -> r)
+ -> r
+subDescr DTop SETop k = k DTop SETop SETop SETop
+subDescr (des `DPush` (t, sto)) (SEYes sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
+ SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
+subDescr (des `DPush` (_, sto)) (SENo sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
+ SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+
sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
@@ -990,16 +978,18 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- EBuild1 _ ne e
- | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne
- , let eltty = typeOf e ->
- accumPromote eltty des (occEnvPop (occCountAll e)) $ \vdes proSub envPro wPro ->
- case drev (vdes `DPush` (tIx, SMerge)) e of { Ret e0 e1 sub e2 ->
+ EBuild1 _ ne (orige :: Ex _ eltty)
+ | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result
+ , let eltty = typeOf orige ->
+ deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
+ let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
+ case drev (prodes `DPush` (tIx, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
- case assertSubenvEmpty proSub of { Refl ->
- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 in
+ let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in
Ret (bconcat (ne0 `BPush` (tIx, ne1))
- (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))
+ (fst (weakenBindings weakenExpr (WCopy (wSinksAnd (bindingsBinds ne0) (wUndoSubenv subD1eUsed))) ve0)))
(EBuild1 ext
(weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)
&. #binds (tIx `SCons` bindingsBinds ne0)
@@ -1007,7 +997,7 @@ drev des = \case
#binds
((#ve0 :++: #binds) :++: #tl))
(EVar ext tIx IZ))
- (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of
+ (subst (\_ t i -> case splitIdx @(TIx : D1E env') (bindingsBinds e0) i of
Left ibind ->
let ibind' =
autoWeak (#ix (auto1 @TIx)
@@ -1020,9 +1010,9 @@ drev des = \case
in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind')
(EVar ext tIx IZ))
Right IZ -> EVar ext tIx IZ -- build lambda index argument
- Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv)))
+ Right (IS ienv) -> EVar ext t (IS (wSinksAnd (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) (wUndoSubenv subD1eUsed) @> ienv)))
e1))
- nsub
+ (subenvCompose subMergeUsed proSub)
(ELet ext
(uninvertTup (d2e envPro) (STArr (SS SZ) STNil) $
makeAccumulators @_ @_ @(TArr (S Z) TNil) envPro $
@@ -1035,8 +1025,21 @@ drev des = \case
#binds
(#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl))
(EVar ext tIx IZ))
- -- TODO: use vectoriseExpr
- (_ $
+ (ELet ext (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (d2 eltty))
+ (IS (wSinks @(TArr (S Z) (D2 eltty) : Append (Append (Vectorise (S Z) e_binds) (TIx : ne_binds)) (D2AcE (Select env sto "accum")))
+ (d2ace envPro)
+ @> IZ)))
+ (EVar ext tIx IZ))) $
+ weakenExpr (autoWeak (#i (auto1 @TIx)
+ &. #dpro (d2ace envPro)
+ &. #d (d2 eltty `SCons` SNil)
+ &. #darr (STArr (SS SZ) (d2 eltty) `SCons` SNil)
+ &. #n (auto1 @TIx)
+ &. #vbinds (bindingsBinds ve0)
+ &. #ne0 (bindingsBinds ne0)
+ &. #tl (d2ace (select SAccum des)))
+ (#i :++: (#dpro :++: #d) :++: #vbinds :++: #tl)
+ (#d :++: #i :++: #dpro :++: #darr :++: (#vbinds :++: #n :++: #ne0) :++: #tl)) $
vectoriseExpr (sappend (d2ace envPro) (d2 eltty `SCons` SNil)) (bindingsBinds e0) (d2ace (select SAccum des)) $
weakenExpr (autoWeak (#dpro (d2ace envPro)
&. #d (d2 eltty `SCons` SNil)
@@ -1044,19 +1047,12 @@ drev des = \case
&. #tl (d2ace (select SAccum des)))
(#dpro :++: #d :++: #binds :++: #tl)
((#dpro :++: #d) :++: #binds :++: #tl)) $
- weakenExpr (wPro (bindingsBinds e0)) e2)) $
+ weakenExpr (wCopies (d2ace envPro) (WCopy @(D2 eltty) (wCopies (bindingsBinds e0) (wUndoSubenv subAccumUsed)))) $
+ weakenExpr (wPro (bindingsBinds e0)) $
+ e2)) $
ELet ext (ENil ext) $
- weakenExpr (autoWeak (#nil (auto1 @TNil)
- &. #d (auto1 @(D2 t))
- &. #nilarr (auto1 @(TArr (S Z) TNil))
- &. #ve0 (bindingsBinds ve0)
- &. #n (auto1 @TIx)
- &. #binds (bindingsBinds ne0)
- &. #tl (d2ace (select SAccum des)))
- (#nil :++: #binds :++: #tl)
- (#nil :++: #nilarr :++: #d :++: (#ve0 :++: #n :++: #binds) :++: #tl))
- ne2)
- }}}
+ ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ)))
+ }}
EUnit _ e
| Ret e0 e1 sub e2 <- drev des e ->
@@ -1075,9 +1071,20 @@ drev des = \case
(ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $
weakenExpr (WCopy WSink) e2)
+ EIdx1 _ e ei
+ -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
+ | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
+ <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ ->
+ Ret binds
+ (EIdx1 ext e1 ei1)
+ sub
+ (_ e2)
+
-- These should be the next to be implemented, I think
- EIdx1{} -> err_unsupported "EIdx1"
EFold1{} -> err_unsupported "EFold1"
+ EShape{} -> err_unsupported "EShape"
+ EReplicate{} -> err_unsupported "EReplicate"
EIdx{} -> err_unsupported "EIdx"
EBuild{} -> err_unsupported "EBuild"