diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 65 |
1 files changed, 64 insertions, 1 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 45fcc82..b35836a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -2,6 +2,7 @@ {-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} @@ -23,6 +24,8 @@ module CHAD ( drev, freezeRet, + CHADConfig(..), + defaultConfig, Storage(..), Descr(..), Select, @@ -724,10 +727,27 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) +--------------------------------- CONFIGURATION -------------------------------- + +data CHADConfig = CHADConfig + { -- | D[let] will bind variables containing arrays in accumulator mode. + chcLetArrayAccum :: Bool + , -- | D[case] will bind variables containing arrays in accumulator mode. + chcCaseArrayAccum :: Bool + } + +defaultConfig :: CHADConfig +defaultConfig = CHADConfig + { chcLetArrayAccum = False + , chcCaseArrayAccum = False + } + + ---------------------------- THE CHAD TRANSFORMATION --------------------------- drev :: forall env sto t. - Descr env sto + (?config :: CHADConfig) + => Descr env sto -> Ex env t -> Ret env sto t drev des = \case EVar _ t i -> @@ -753,6 +773,38 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) + ELet _ (rhs :: Ex _ a) body + | chcLetArrayAccum ?config && hasArrays (typeOf rhs) + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 + <- drev des rhs + , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 + <- drev (des `DPush` (typeOf rhs, SAccum)) body + , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> + subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in + Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) + (weakenExpr wbody0' body1) + subBoth + (ELet ext + (EWith (EZero (typeOf rhs)) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #body (subList (bindingsBinds body0) subtapeBody) + &. #ac (auto1 @(TAccum (D2 a))) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: (#body :++: #rhs) :++: #tl)) + body2) $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ + plus_RHS_Body + (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) + ELet _ rhs body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs @@ -848,6 +900,8 @@ drev des = \case (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") (weakenExpr (WCopy (wSinks' @[_,_])) e2))) + ECase{} | chcCaseArrayAccum ?config -> error "chcCaseArrayAccum unsupported" + ECase _ e (a :: Ex _ t) b | STEither t1 t2 <- typeOf e , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e @@ -1187,3 +1241,12 @@ drev des = \case (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) (EZero t)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + + hasArrays :: STy t' -> Bool + hasArrays STNil = False + hasArrays (STPair a b) = hasArrays a || hasArrays b + hasArrays (STEither a b) = hasArrays a || hasArrays b + hasArrays (STMaybe t) = hasArrays t + hasArrays STArr{} = True + hasArrays STScal{} = False + hasArrays STAccum{} = error "Accumulators not allowed in source program" |