From 174af2ba568de66e0d890825b8bda930b8e7bb96 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Nov 2025 21:49:45 +0100 Subject: Move module hierarchy under CHAD. --- src/CHAD/Drev/Top.hs | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/CHAD/Drev/Top.hs (limited to 'src/CHAD/Drev/Top.hs') diff --git a/src/CHAD/Drev/Top.hs b/src/CHAD/Drev/Top.hs new file mode 100644 index 0000000..510e73e --- /dev/null +++ b/src/CHAD/Drev/Top.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.Top where + +import CHAD.Analysis.Identity +import CHAD.AST +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.SplitLets +import CHAD.AST.Weaken.Auto +import CHAD.Data +import qualified CHAD.Data.VarMap as VarMap +import CHAD.Drev +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types + + +type family MergeEnv env where + MergeEnv '[] = '[] + MergeEnv (t : ts) = "merge" : MergeEnv ts + +mergeDescr :: SList STy env -> Descr env (MergeEnv env) +mergeDescr SNil = DTop +mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge) + +mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[] +mergeEnvNoAccum SNil = Refl +mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl + +mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env +mergeEnvOnlyMerge SNil = Refl +mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl + +accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r +accumDescr SNil k = k DTop +accumDescr (t `SCons` env) k = accumDescr env $ \des -> + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) + +reassembleD2E :: Descr env sto + -> D1E env :> env' + -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) + -> Ex env' (Tup (D2E env)) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + +chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) +chad config env (term :: Ex env t) + | True <- chcArgArrayAccum config + = let ?config = config + in accumDescr env $ \descr -> + let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) + tvar = STPair t1 (tTup (d2e (select SAccum descr))) + in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #acenv (d2ace (select SAccum descr)) + &. #tl (d1e env)) + (#d :++: #acenv :++: #tl) + (#acenv :++: #d :++: #tl)) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ + EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) + + | False <- chcArgArrayAccum config + , Refl <- mergeEnvNoAccum env + , Refl <- mergeEnvOnlyMerge env + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') + where + term' = identityAnalysis env (splitLets term) + +chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +chad' config env term + | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term) + = chad config env term -- cgit v1.2.3-70-g09d2