aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Top.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/Top.hs')
-rw-r--r--src/CHAD/Top.hs63
1 files changed, 29 insertions, 34 deletions
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index d058132..484779e 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -10,13 +10,18 @@
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
+import Analysis.Identity
import AST
+import AST.Env
+import AST.Sparse
+import AST.SplitLets
import AST.Weaken.Auto
import CHAD
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
+import qualified Data.VarMap as VarMap
type family MergeEnv env where
@@ -25,7 +30,7 @@ type family MergeEnv env where
mergeDescr :: SList STy env -> Descr env (MergeEnv env)
mergeDescr SNil = DTop
-mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, SMerge)
+mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge)
mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
mergeEnvNoAccum SNil = Refl
@@ -38,38 +43,25 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil k = k DTop
accumDescr (t `SCons` env) k = accumDescr env $ \des ->
- if hasArrays t then k (des `DPush` (t, SAccum))
- else k (des `DPush` (t, SMerge))
-
-d1Identity :: STy t -> D1 t :~: t
-d1Identity = \case
- STNil -> Refl
- STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STMaybe t | Refl <- d1Identity t -> Refl
- STArr _ t | Refl <- d1Identity t -> Refl
- STScal _ -> Refl
- STAccum{} -> error "Accumulators not allowed in input program"
-
-d1eIdentity :: SList STy env -> D1E env :~: env
-d1eIdentity SNil = Refl
-d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+ if hasArrays t then k (des `DPush` (t, Nothing, SAccum))
+ else k (des `DPush` (t, Nothing, SMerge))
reassembleD2E :: Descr env sto
+ -> D1E env :> env'
-> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
-> Ex env' (Tup (D2E env))
-reassembleD2E DTop _ = ENil ext
-reassembleD2E (des `DPush` (_, SAccum)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
- (ESnd ext (EVar ext (typeOf e) IZ))))
- (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (_, SMerge)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
- (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
- (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t)
+reassembleD2E DTop _ _ = ENil ext
+reassembleD2E (des `DPush` (_, _, SAccum)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e1 $ \w2 e11 e12 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
+reassembleD2E (des `DPush` (_, _, SMerge)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e2 $ \w2 e21 e22 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
+reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
+ EPair ext (reassembleD2E des (WPop w) e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
@@ -79,21 +71,24 @@ chad config env (term :: Ex env t)
let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
tvar = STPair t1 (tTup (d2e (select SAccum descr)))
in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
- makeAccumulators (select SAccum descr) $
+ makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #acenv (d2ace (select SAccum descr))
&. #tl (d1e env))
(#d :++: #acenv :++: #tl)
(#acenv :++: #d :++: #tl)) $
- freezeRet descr (drev descr term)) $
+ freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
- (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
- (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+ (reassembleD2E descr (WSink .> WSink)
+ (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
| False <- chcArgArrayAccum config
, Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term)
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
+ where
+ term' = identityAnalysis env (splitLets term)
chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term