aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-07 21:39:09 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-07 21:39:09 +0100
commit4ccf1996a5bd739dfb1e62fb3bfb189c04fb6d89 (patch)
tree3f9cb86e846705c87f5539f2aecfdb2b00d76545
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Rewrite CPSy code as do-code using QualifiedDoqualified-contdo
Credits for this trick go to Leary on IRC: https://ircbrowse.tomsmeding.com/browse/lchaskell?id=1691743#trid1691743 Advantage: all the binders are on the left-hand side. Disadvantages: - all continuations need to pass exactly one value, i.e. tuples are required - wacky shit
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/CHAD.hs85
-rw-r--r--src/ContDo.hs15
3 files changed, 82 insertions, 19 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index df0409d..689ecc6 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -36,6 +36,7 @@ library
CHAD.Types.ToTan
Compile
Compile.Exec
+ ContDo
Data
Data.VarMap
Example.GMM
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 298d964..d1d02fa 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -23,6 +23,9 @@
-- useful here.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
+
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE QualifiedDo #-}
module CHAD (
drev,
freezeRet,
@@ -47,6 +50,7 @@ import AST.Weaken.Auto
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
+import qualified ContDo as Cont
import Data
import qualified Data.VarMap as VarMap
import Data.VarMap (VarMap)
@@ -1493,6 +1497,50 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
, Refl <- lemAppendNil @tapebinds ->
RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+subDescr' :: Descr env sto -> Subenv env env'
+ -> (forall sto'. (Descr env' sto'
+ ,Subenv (Select env sto "merge") (Select env' sto' "merge")
+ ,Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ ,Subenv (D1E env) (D1E env'))
+ -> r)
+ -> r
+subDescr' des sub k =
+ subDescr des sub $ \a b c d -> k (a, b, c, d)
+
+accumPromote' :: forall dt env sto proxy r.
+ proxy dt
+ -> Descr env sto
+ -> (forall stoRepl envPro.
+ (Select env stoRepl "merge" ~ '[])
+ => (Descr env stoRepl
+ -- ^ A revised environment description that switches
+ -- arrays (used in the OccEnv) that are currently on
+ -- "merge" storage, to "accum" storage.
+ ,SList STy envPro
+ -- ^ New entries on top of the original dual environment,
+ -- that house the accumulators for the promoted arrays in
+ -- the original environment.
+ ,Subenv (Select env sto "merge") envPro
+ -- ^ The promoted entries were merge entries in the
+ -- original environment.
+ ,Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum"))
+ -- ^ All entries that were accumulators are still
+ -- accumulators.
+ ,VarMap Int (D2AcE (Select env stoRepl "accum"))
+ -- ^ Accumulator map for _only_ the the newly allocated
+ -- accumulators.
+ ,(forall shbinds.
+ SList STy shbinds
+ -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))))
+ -- ^ A weakening that converts a computation in the
+ -- revised environment to one in the original environment
+ -- extended with some accumulators.
+ -> r)
+ -> r
+accumPromote' dt des k =
+ accumPromote dt des $ \a b c d e f -> k (a, b, c, d, e, f)
+
drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
=> Descr env sto
-> VarMap Int (D2AcE (Select env sto "accum"))
@@ -1514,19 +1562,19 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
-> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
-> r)
-> r
-drevLambda des accumMap (argty, argsto) sd origef k =
- let t = typeOf origef in
- deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') ->
- let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in
- subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
- let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
- let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
- case prf1 prodes argty argsto of { Refl ->
- case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
- let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
- extractContrib prodes argty argsto subEf $ \argSp getSparseArg ->
+drevLambda des accumMap (argty, argsto) sd origef k = Cont.do
+ let t = typeOf origef
+ (usedSub :: Subenv env env') <- deleteUnused (descrList des) (occEnvPopSome (occCountAll origef))
+ let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef
+ (usedDes :: Descr env' _, subMergeUsed, subAccumUsed, subD1eUsed) <- subDescr' des usedSub
+ (prodes, envPro :: SList _ envPro, proSub, proAccRevSub, accumMapProPart, wPro) <- accumPromote' (applySparse sd (d2 t)) usedDes
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart
+ mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub)
+ mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub
+ Refl <- flip id $ prf1 prodes argty argsto
+ Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 <- flip id $ drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef
+ let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf)
+ (argSp, getSparseArg) <- extractContrib prodes argty argsto subEf
let library = #fbinds (bindingsBinds ef0)
&. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
&. #ftape (auto1 @(Tape e_tape))
@@ -1538,7 +1586,7 @@ drevLambda des accumMap (argty, argsto) sd origef k =
&. #d2acUsed (d2ace (select SAccum usedDes))
&. #d2acEnv (d2ace (select SAccum des))
&. #d2acPro (d2ace envPro)
- &. #efPrerebinds efPrerebinds in
+ &. #efPrerebinds efPrerebinds
k envPro
(subenvD2E (subenvCompose subMergeUsed proSub))
mergePrimalBindings
@@ -1558,17 +1606,16 @@ drevLambda des accumMap (argty, argsto) sd origef k =
((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv)
.> wPro (subList (bindingsBinds ef0) subtapeEf))
(getSparseArg ef2))
- }}
where
extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False)
=> proxy env sto -> proxy2 a -> Storage s
-- if s == "merge", this simplifies to SubenvS '[D2 a] t'
-- if s == "discr", this simplifies to SubenvS '[] t'
-> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t'
- -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r
- extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id
- extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext)
- extractContrib _ _ SDiscr SETop k' = k' SpAbsent id
+ -> (forall d'. (Sparse (D2 a) d', (forall env'. Ex env' (Tup t') -> Ex env' d')) -> r) -> r
+ extractContrib _ _ SMerge (SENo SETop) k' = k' (SpAbsent, id)
+ extractContrib _ _ SMerge (SEYes s SETop) k' = k' (s, ESnd ext)
+ extractContrib _ _ SDiscr SETop k' = k' (SpAbsent, id)
prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s
-> Select (a : env) (s : sto) "accum" :~: Select env sto "accum"
diff --git a/src/ContDo.hs b/src/ContDo.hs
new file mode 100644
index 0000000..255e21a
--- /dev/null
+++ b/src/ContDo.hs
@@ -0,0 +1,15 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MonoLocalBinds #-}
+module ContDo where
+
+import GHC.TypeLits
+
+(>>=) :: (a -> b) -> a -> b
+(>>=) = ($)
+
+class AlwaysFail a
+instance TypeError (Text "fail") => AlwaysFail a
+
+fail :: AlwaysFail a => String -> a
+fail = error