summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:57 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:57 +0100
commitb8c162ce9cb1faeec621b751fff9aff46e022417 (patch)
tree9c31700f34f9a1f1a67e0a73c880938130e87ee6 /src/CHAD.hs
parentbb84f6930702a02ba982795e2bb95a64d61f672b (diff)
Configuration for CHAD
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs65
1 files changed, 64 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 45fcc82..b35836a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -23,6 +24,8 @@
module CHAD (
drev,
freezeRet,
+ CHADConfig(..),
+ defaultConfig,
Storage(..),
Descr(..),
Select,
@@ -724,10 +727,27 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+--------------------------------- CONFIGURATION --------------------------------
+
+data CHADConfig = CHADConfig
+ { -- | D[let] will bind variables containing arrays in accumulator mode.
+ chcLetArrayAccum :: Bool
+ , -- | D[case] will bind variables containing arrays in accumulator mode.
+ chcCaseArrayAccum :: Bool
+ }
+
+defaultConfig :: CHADConfig
+defaultConfig = CHADConfig
+ { chcLetArrayAccum = False
+ , chcCaseArrayAccum = False
+ }
+
+
---------------------------- THE CHAD TRANSFORMATION ---------------------------
drev :: forall env sto t.
- Descr env sto
+ (?config :: CHADConfig)
+ => Descr env sto
-> Ex env t -> Ret env sto t
drev des = \case
EVar _ t i ->
@@ -753,6 +773,38 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
+ ELet _ (rhs :: Ex _ a) body
+ | chcLetArrayAccum ?config && hasArrays (typeOf rhs)
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
+ <- drev des rhs
+ , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2
+ <- drev (des `DPush` (typeOf rhs, SAccum)) body
+ , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
+ , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
+ subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
+ Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
+ (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
+ (weakenExpr wbody0' body1)
+ subBoth
+ (ELet ext
+ (EWith (EZero (typeOf rhs)) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #body (subList (bindingsBinds body0) subtapeBody)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: (#body :++: #rhs) :++: #tl))
+ body2) $
+ ELet ext
+ (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
+ plus_RHS_Body
+ (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EFst ext (EVar ext bodyResType (IS IZ))))
+
ELet _ rhs body
| Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
<- drev des rhs
@@ -848,6 +900,8 @@ drev des = \case
(EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
+ ECase{} | chcCaseArrayAccum ?config -> error "chcCaseArrayAccum unsupported"
+
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
, Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e
@@ -1187,3 +1241,12 @@ drev des = \case
(EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))
(EZero t)) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+
+ hasArrays :: STy t' -> Bool
+ hasArrays STNil = False
+ hasArrays (STPair a b) = hasArrays a || hasArrays b
+ hasArrays (STEither a b) = hasArrays a || hasArrays b
+ hasArrays (STMaybe t) = hasArrays t
+ hasArrays STArr{} = True
+ hasArrays STScal{} = False
+ hasArrays STAccum{} = error "Accumulators not allowed in source program"