summaryrefslogtreecommitdiff
path: root/src/CHAD/Top.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Top.hs')
-rw-r--r--src/CHAD/Top.hs53
1 files changed, 50 insertions, 3 deletions
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 4b2a844..12594f2 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -1,13 +1,20 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
import AST
+import AST.Weaken.Auto
import CHAD
+import CHAD.Accum
+import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -28,6 +35,12 @@ mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge SNil = Refl
mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
+accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
+accumDescr SNil k = k DTop
+accumDescr (t `SCons` env) k = accumDescr env $ \des ->
+ if hasArrays t then k (des `DPush` (t, SAccum))
+ else k (des `DPush` (t, SMerge))
+
d1Identity :: STy t -> D1 t :~: t
d1Identity = \case
STNil -> Refl
@@ -42,9 +55,43 @@ d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+reassembleD2E :: Descr env sto
+ -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
+ -> Ex env' (Tup (D2E env))
+reassembleD2E DTop _ = ENil ext
+reassembleD2E (des `DPush` (_, SAccum)) e =
+ ELet ext e $
+ EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
+ (ESnd ext (EVar ext (typeOf e) IZ))))
+ (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
+reassembleD2E (des `DPush` (_, SMerge)) e =
+ ELet ext e $
+ EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
+ (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
+ (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
+reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t)
+
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
-chad config env term
- | Refl <- mergeEnvNoAccum env
+chad config env (term :: Ex env t)
+ | True <- chcArgArrayAccum config
+ = let ?config = config
+ in accumDescr env $ \descr ->
+ let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
+ tvar = STPair t1 (tTup (d2e (select SAccum descr)))
+ in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
+ makeAccumulators (select SAccum descr) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #acenv (d2ace (select SAccum descr))
+ &. #tl (d1e env))
+ (#d :++: #acenv :++: #tl)
+ (#acenv :++: #d :++: #tl)) $
+ freezeRet descr (drev descr term)) $
+ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
+ (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+
+ | False <- chcArgArrayAccum config
+ , Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
= let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term)