aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Drev/Top.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/Drev/Top.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/Drev/Top.hs')
-rw-r--r--src/CHAD/Drev/Top.hs96
1 files changed, 96 insertions, 0 deletions
diff --git a/src/CHAD/Drev/Top.hs b/src/CHAD/Drev/Top.hs
new file mode 100644
index 0000000..510e73e
--- /dev/null
+++ b/src/CHAD/Drev/Top.hs
@@ -0,0 +1,96 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Drev.Top where
+
+import CHAD.Analysis.Identity
+import CHAD.AST
+import CHAD.AST.Env
+import CHAD.AST.Sparse
+import CHAD.AST.SplitLets
+import CHAD.AST.Weaken.Auto
+import CHAD.Data
+import qualified CHAD.Data.VarMap as VarMap
+import CHAD.Drev
+import CHAD.Drev.Accum
+import CHAD.Drev.EnvDescr
+import CHAD.Drev.Types
+
+
+type family MergeEnv env where
+ MergeEnv '[] = '[]
+ MergeEnv (t : ts) = "merge" : MergeEnv ts
+
+mergeDescr :: SList STy env -> Descr env (MergeEnv env)
+mergeDescr SNil = DTop
+mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge)
+
+mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
+mergeEnvNoAccum SNil = Refl
+mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl
+
+mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
+mergeEnvOnlyMerge SNil = Refl
+mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
+
+accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
+accumDescr SNil k = k DTop
+accumDescr (t `SCons` env) k = accumDescr env $ \des ->
+ if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum))
+ else k (des `DPush` (t, Nothing, SMerge))
+
+reassembleD2E :: Descr env sto
+ -> D1E env :> env'
+ -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
+ -> Ex env' (Tup (D2E env))
+reassembleD2E DTop _ _ = ENil ext
+reassembleD2E (des `DPush` (_, _, SAccum)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e1 $ \w2 e11 e12 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
+reassembleD2E (des `DPush` (_, _, SMerge)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e2 $ \w2 e21 e22 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
+reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
+ EPair ext (reassembleD2E des (WPop w) e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
+
+chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
+chad config env (term :: Ex env t)
+ | True <- chcArgArrayAccum config
+ = let ?config = config
+ in accumDescr env $ \descr ->
+ let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
+ tvar = STPair t1 (tTup (d2e (select SAccum descr)))
+ in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
+ makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #acenv (d2ace (select SAccum descr))
+ &. #tl (d1e env))
+ (#d :++: #acenv :++: #tl)
+ (#acenv :++: #d :++: #tl)) $
+ freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
+ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
+ (reassembleD2E descr (WSink .> WSink)
+ (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+
+ | False <- chcArgArrayAccum config
+ , Refl <- mergeEnvNoAccum env
+ , Refl <- mergeEnvOnlyMerge env
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
+ where
+ term' = identityAnalysis env (splitLets term)
+
+chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+chad' config env term
+ | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term)
+ = chad config env term