diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 209 |
1 files changed, 108 insertions, 101 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 007ffe3..087a26e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -30,10 +30,12 @@ module CHAD ( import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) +import GHC.Stack (HasCallStack) import GHC.TypeLits (Symbol) import AST import AST.Count +import AST.Env import AST.Weaken.Auto import Data import Lemmas @@ -422,14 +424,6 @@ plusSparse t a b adder = (EVar ext t (IS IZ)) (weakenExpr (WCopy (WCopy WSink)) adder))) -type family Tup env where - Tup '[] = TNil - Tup (t : ts) = TPair (Tup ts) t - -tTup :: SList STy env -> STy (Tup env) -tTup SNil = STNil -tTup (SCons t ts) = STPair (tTup ts) t - zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) @@ -437,18 +431,20 @@ zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) accumPromote :: forall dt env sto proxy r. proxy dt -> Descr env sto - -> OccEnv env -> (forall stoRepl envPro. - Descr env stoRepl + (Select env stoRepl "merge" ~ '[]) + => Descr env stoRepl -- ^ A revised environment description that switches -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. - -> Subenv (Select env sto "merge") (Select env stoRepl "merge") - -- ^ The new storage has fewer "merge"-storage entries. + -- "merge" storage, to "accum" storage. Any other "merge" + -- entries are deleted. -> SList STy envPro -- ^ New entries on top of the original dual environment, -- that house the accumulators for the promoted arrays in -- the original environment. + -> Subenv (Select env sto "merge") envPro + -- ^ The promoted entries were merge entries in the + -- original environment. -> (forall shbinds. SList STy shbinds -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) @@ -458,16 +454,15 @@ accumPromote :: forall dt env sto proxy r. -- extended with some accumulators. -> r) -> r -accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId) -accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = - accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf -> - case (t, sto, occ) of +accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf -> + case sto of -- Accumulators are left as-is - (_, SAccum, _) -> + SAccum -> k (storepl `DPush` (t, SAccum)) - mergesub envpro + prosub (\shbinds -> autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) @@ -477,34 +472,29 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) - -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro - (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero -> - k (storepl `DPush` (t, SAccum)) - (SENo mergesub) - (STArr arrn arrt `SCons` envpro) - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Used "merge" values must be an array, so reject everything else. (TODO: generalise this) - (_, SMerge, Occ _ c) - | c > Zero -> - error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t - | otherwise -> - k (storepl `DPush` (t, SMerge)) - (SEYes mergesub) - envpro - wf + SMerge -> case t of + -- Arrays with "merge" storage are promoted to an accumulator in envPro + STArr (arrn :: SNat arrn) (arrt :: STy arrt) -> + k (storepl `DPush` (t, SAccum)) + (STArr arrn arrt `SCons` envpro) + (SEYes prosub) + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- "merge" values must be an array, so reject everything else. (TODO: generalise this) + _ -> + error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t -- where -- containsTArr :: STy t' -> Bool -- containsTArr = \case @@ -537,14 +527,6 @@ uninvertTup (t `SCons` list) tcore e = (ESnd ext (EVar ext recT IZ)) (ESnd ext (EFst ext (EVar ext recT IZ)))) --- | @env'@ is a subset of @env@: each element of @env@ is either included in --- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: Subenv env env' -> Subenv (t : env) (t : env') - SENo :: Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') - data Ret env0 sto t = forall shbinds env0Merge. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds @@ -566,24 +548,6 @@ data Rets env0 sto env list = (SList (RetPair env0 sto env shbinds) list) deriving instance Show (Rets env0 sto env list) -subList :: SList f env -> Subenv env env' -> SList f env' -subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) -subList (SCons _ xs) (SENo sub) = subList xs sub - -subenvAll :: SList f env -> Subenv env env -subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) - -subenvNone :: SList f env -> Subenv env '[] -subenvNone SNil = SETop -subenvNone (SCons _ env) = SENo (subenvNone env) - -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} - subenvPlus :: SList STy env -> Subenv env env1 -> Subenv env env2 -> (forall env3. Subenv env env3 @@ -631,7 +595,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) -assertSubenvEmpty :: Subenv env env' -> env' :~: '[] +assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl assertSubenvEmpty SETop = Refl assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" @@ -748,6 +712,10 @@ data Descr env sto where DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) deriving instance Show (Descr env sto) +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des + select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) @@ -755,6 +723,26 @@ select s@SMerge (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' + -> (forall sto'. Descr env' sto' + -> Subenv (Select env sto "merge") (Select env' sto' "merge") + -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + -> Subenv (D1E env) (D1E env') + -> r) + -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) + SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) + SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + sD1eEnv :: Descr env sto -> SList STy (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) @@ -990,16 +978,18 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - EBuild1 _ ne e - | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne - , let eltty = typeOf e -> - accumPromote eltty des (occEnvPop (occCountAll e)) $ \vdes proSub envPro wPro -> - case drev (vdes `DPush` (tIx, SMerge)) e of { Ret e0 e1 sub e2 -> + EBuild1 _ ne (orige :: Ex _ eltty) + | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result + , let eltty = typeOf orige -> + deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> + let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> + case drev (prodes `DPush` (tIx, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> - case assertSubenvEmpty proSub of { Refl -> - let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 in + let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in Ret (bconcat (ne0 `BPush` (tIx, ne1)) - (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) + (fst (weakenBindings weakenExpr (WCopy (wSinksAnd (bindingsBinds ne0) (wUndoSubenv subD1eUsed))) ve0))) (EBuild1 ext (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) &. #binds (tIx `SCons` bindingsBinds ne0) @@ -1007,7 +997,7 @@ drev des = \case #binds ((#ve0 :++: #binds) :++: #tl)) (EVar ext tIx IZ)) - (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of + (subst (\_ t i -> case splitIdx @(TIx : D1E env') (bindingsBinds e0) i of Left ibind -> let ibind' = autoWeak (#ix (auto1 @TIx) @@ -1020,9 +1010,9 @@ drev des = \case in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind') (EVar ext tIx IZ)) Right IZ -> EVar ext tIx IZ -- build lambda index argument - Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv))) + Right (IS ienv) -> EVar ext t (IS (wSinksAnd (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) (wUndoSubenv subD1eUsed) @> ienv))) e1)) - nsub + (subenvCompose subMergeUsed proSub) (ELet ext (uninvertTup (d2e envPro) (STArr (SS SZ) STNil) $ makeAccumulators @_ @_ @(TArr (S Z) TNil) envPro $ @@ -1035,8 +1025,21 @@ drev des = \case #binds (#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl)) (EVar ext tIx IZ)) - -- TODO: use vectoriseExpr - (_ $ + (ELet ext (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (d2 eltty)) + (IS (wSinks @(TArr (S Z) (D2 eltty) : Append (Append (Vectorise (S Z) e_binds) (TIx : ne_binds)) (D2AcE (Select env sto "accum"))) + (d2ace envPro) + @> IZ))) + (EVar ext tIx IZ))) $ + weakenExpr (autoWeak (#i (auto1 @TIx) + &. #dpro (d2ace envPro) + &. #d (d2 eltty `SCons` SNil) + &. #darr (STArr (SS SZ) (d2 eltty) `SCons` SNil) + &. #n (auto1 @TIx) + &. #vbinds (bindingsBinds ve0) + &. #ne0 (bindingsBinds ne0) + &. #tl (d2ace (select SAccum des))) + (#i :++: (#dpro :++: #d) :++: #vbinds :++: #tl) + (#d :++: #i :++: #dpro :++: #darr :++: (#vbinds :++: #n :++: #ne0) :++: #tl)) $ vectoriseExpr (sappend (d2ace envPro) (d2 eltty `SCons` SNil)) (bindingsBinds e0) (d2ace (select SAccum des)) $ weakenExpr (autoWeak (#dpro (d2ace envPro) &. #d (d2 eltty `SCons` SNil) @@ -1044,19 +1047,12 @@ drev des = \case &. #tl (d2ace (select SAccum des))) (#dpro :++: #d :++: #binds :++: #tl) ((#dpro :++: #d) :++: #binds :++: #tl)) $ - weakenExpr (wPro (bindingsBinds e0)) e2)) $ + weakenExpr (wCopies (d2ace envPro) (WCopy @(D2 eltty) (wCopies (bindingsBinds e0) (wUndoSubenv subAccumUsed)))) $ + weakenExpr (wPro (bindingsBinds e0)) $ + e2)) $ ELet ext (ENil ext) $ - weakenExpr (autoWeak (#nil (auto1 @TNil) - &. #d (auto1 @(D2 t)) - &. #nilarr (auto1 @(TArr (S Z) TNil)) - &. #ve0 (bindingsBinds ve0) - &. #n (auto1 @TIx) - &. #binds (bindingsBinds ne0) - &. #tl (d2ace (select SAccum des))) - (#nil :++: #binds :++: #tl) - (#nil :++: #nilarr :++: #d :++: (#ve0 :++: #n :++: #binds) :++: #tl)) - ne2) - }}} + ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ))) + }} EUnit _ e | Ret e0 e1 sub e2 <- drev des e -> @@ -1075,9 +1071,20 @@ drev des = \case (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $ weakenExpr (WCopy WSink) e2) + EIdx1 _ e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + -> + Ret binds + (EIdx1 ext e1 ei1) + sub + (_ e2) + -- These should be the next to be implemented, I think - EIdx1{} -> err_unsupported "EIdx1" EFold1{} -> err_unsupported "EFold1" + EShape{} -> err_unsupported "EShape" + EReplicate{} -> err_unsupported "EReplicate" EIdx{} -> err_unsupported "EIdx" EBuild{} -> err_unsupported "EBuild" |