From 4ccf1996a5bd739dfb1e62fb3bfb189c04fb6d89 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 7 Nov 2025 21:39:09 +0100 Subject: Rewrite CPSy code as do-code using QualifiedDo Credits for this trick go to Leary on IRC: https://ircbrowse.tomsmeding.com/browse/lchaskell?id=1691743#trid1691743 Advantage: all the binders are on the left-hand side. Disadvantages: - all continuations need to pass exactly one value, i.e. tuples are required - wacky shit --- src/CHAD.hs | 85 +++++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 66 insertions(+), 19 deletions(-) (limited to 'src/CHAD.hs') diff --git a/src/CHAD.hs b/src/CHAD.hs index 298d964..d1d02fa 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -23,6 +23,9 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} + +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE QualifiedDo #-} module CHAD ( drev, freezeRet, @@ -47,6 +50,7 @@ import AST.Weaken.Auto import CHAD.Accum import CHAD.EnvDescr import CHAD.Types +import qualified ContDo as Cont import Data import qualified Data.VarMap as VarMap import Data.VarMap (VarMap) @@ -1493,6 +1497,50 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of , Refl <- lemAppendNil @tapebinds -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 +subDescr' :: Descr env sto -> Subenv env env' + -> (forall sto'. (Descr env' sto' + ,Subenv (Select env sto "merge") (Select env' sto' "merge") + ,Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + ,Subenv (D1E env) (D1E env')) + -> r) + -> r +subDescr' des sub k = + subDescr des sub $ \a b c d -> k (a, b, c, d) + +accumPromote' :: forall dt env sto proxy r. + proxy dt + -> Descr env sto + -> (forall stoRepl envPro. + (Select env stoRepl "merge" ~ '[]) + => (Descr env stoRepl + -- ^ A revised environment description that switches + -- arrays (used in the OccEnv) that are currently on + -- "merge" storage, to "accum" storage. + ,SList STy envPro + -- ^ New entries on top of the original dual environment, + -- that house the accumulators for the promoted arrays in + -- the original environment. + ,Subenv (Select env sto "merge") envPro + -- ^ The promoted entries were merge entries in the + -- original environment. + ,Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) + -- ^ All entries that were accumulators are still + -- accumulators. + ,VarMap Int (D2AcE (Select env stoRepl "accum")) + -- ^ Accumulator map for _only_ the the newly allocated + -- accumulators. + ,(forall shbinds. + SList STy shbinds + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum"))))) + -- ^ A weakening that converts a computation in the + -- revised environment to one in the original environment + -- extended with some accumulators. + -> r) + -> r +accumPromote' dt des k = + accumPromote dt des $ \a b c d e f -> k (a, b, c, d, e, f) + drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) @@ -1514,19 +1562,19 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' -> r) -> r -drevLambda des accumMap (argty, argsto) sd origef k = - let t = typeOf origef in - deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> - let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - case prf1 prodes argty argsto of { Refl -> - case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> - let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in - extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> +drevLambda des accumMap (argty, argsto) sd origef k = Cont.do + let t = typeOf origef + (usedSub :: Subenv env env') <- deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef + (usedDes :: Descr env' _, subMergeUsed, subAccumUsed, subD1eUsed) <- subDescr' des usedSub + (prodes, envPro :: SList _ envPro, proSub, proAccRevSub, accumMapProPart, wPro) <- accumPromote' (applySparse sd (d2 t)) usedDes + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart + mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) + mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub + Refl <- flip id $ prf1 prodes argty argsto + Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 <- flip id $ drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) + (argSp, getSparseArg) <- extractContrib prodes argty argsto subEf let library = #fbinds (bindingsBinds ef0) &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) &. #ftape (auto1 @(Tape e_tape)) @@ -1538,7 +1586,7 @@ drevLambda des accumMap (argty, argsto) sd origef k = &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des)) &. #d2acPro (d2ace envPro) - &. #efPrerebinds efPrerebinds in + &. #efPrerebinds efPrerebinds k envPro (subenvD2E (subenvCompose subMergeUsed proSub)) mergePrimalBindings @@ -1558,17 +1606,16 @@ drevLambda des accumMap (argty, argsto) sd origef k = ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) .> wPro (subList (bindingsBinds ef0) subtapeEf)) (getSparseArg ef2)) - }} where extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) => proxy env sto -> proxy2 a -> Storage s -- if s == "merge", this simplifies to SubenvS '[D2 a] t' -- if s == "discr", this simplifies to SubenvS '[] t' -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' - -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r - extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id - extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) - extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + -> (forall d'. (Sparse (D2 a) d', (forall env'. Ex env' (Tup t') -> Ex env' d')) -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' (SpAbsent, id) + extractContrib _ _ SMerge (SEYes s SETop) k' = k' (s, ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' (SpAbsent, id) prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" -- cgit v1.2.3-70-g09d2