{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Top-level entry points for the CHAD transformation.
--
-- 'chad' first preprocesses a term ('splitLets', then 'identityAnalysis').
-- It then runs 'drev' under an environment descriptor that decides, per
-- variable, how that variable's @D2@ value is returned. Finally it repackages
-- the result as a pair of the @D1@ result and a tuple of @D2@ values, one for
-- every environment variable.
module CHAD.Top where

import Analysis.Identity
import AST
import AST.Env
import AST.Sparse
import AST.SplitLets
import AST.Weaken.Auto
import CHAD
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
import qualified Data.VarMap as VarMap


-- | Type-level storage assignment that tags every variable of @env@ with
-- @\"merge\"@. This is the storage list produced by 'mergeDescr'.
type family MergeEnv env where
  MergeEnv '[] = '[]
  MergeEnv (t : ts) = "merge" : MergeEnv ts

-- | Environment descriptor that tags every variable 'SMerge'. The middle
-- component of each entry is always 'Nothing'; its meaning is defined in
-- "CHAD.EnvDescr", which is not shown here.
mergeDescr :: SList STy env -> Descr env (MergeEnv env)
mergeDescr SNil = DTop
mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge)

-- | Proof that, under 'MergeEnv', selecting the @\"accum\"@ variables gives
-- the empty list. Proved by induction on the spine of @env@.
mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
mergeEnvNoAccum SNil = Refl
mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl

-- | Proof that, under 'MergeEnv', selecting the @\"merge\"@ variables gives
-- back the whole environment. Proved by induction on the spine of @env@.
mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge SNil = Refl
mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl

-- | Build an environment descriptor that tags a variable 'SAccum' when
-- 'hasArrays' holds for its type, and 'SMerge' otherwise.
--
-- The storage list @sto@ is chosen here rather than by the caller. The
-- descriptor is therefore handed to a continuation that is polymorphic in
-- @sto@ (continuation-passing style).
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil k = k DTop
accumDescr (t `SCons` env) k = accumDescr env $ \des ->
  if hasArrays t then k (des `DPush` (t, Nothing, SAccum))
                 else k (des `DPush` (t, Nothing, SMerge))

-- | Rebuild the full 'D2E' tuple for @env@ from a descriptor-split pair of
-- tuples.
--
-- The function recurses structurally over the descriptor:
--
-- * 'SAccum' entries take their component from the first (accum) tuple.
-- * 'SMerge' entries take their component from the second (merge) tuple.
-- * 'SDiscr' entries have a component in neither tuple. An 'EZero' is
--   emitted instead; its zero info is computed from the entry's @D1@
--   variable, which is reached through the weakening.
reassembleD2E
  :: Descr env sto
  -- ^ Storage tag for every variable of @env@.
  -> D1E env :> env'
  -- ^ Weakening from the @D1E env@ variables into the current scope. It is
  -- only dereferenced for 'SDiscr' entries; the other cases just pass it on.
  -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
  -- ^ Pair of (tuple for the accum-tagged variables, tuple for the
  -- merge-tagged variables).
  -> Ex env' (Tup (D2E env))
reassembleD2E DTop _ _ = ENil ext
reassembleD2E (des `DPush` (_, _, SAccum)) w e =
  -- Split the outer pair, then split the accum tuple (e1). The component
  -- e12 becomes this variable's entry. The rest (e11) is re-paired with the
  -- untouched merge tuple for the recursive call. 'eunPair' returns
  -- weakenings (w1, w2) into the scope its continuation runs in, so they are
  -- composed onto w and applied to e2. WPop drops the current variable from
  -- w's source environment.
  eunPair e $ \w1 e1 e2 ->
  eunPair e1 $ \w2 e11 e12 ->
    EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
reassembleD2E (des `DPush` (_, _, SMerge)) w e =
  -- Mirror image of the SAccum case: the component comes from the merge
  -- tuple (e2), and the accum tuple is passed through unchanged (weakened).
  eunPair e $ \w1 e1 e2 ->
  eunPair e2 $ \w2 e21 e22 ->
    EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
  -- No stored component for this variable: emit a zero of type d2M t. Its
  -- zero info is derived from the variable's D1 value at (w @> IZ).
  EPair ext (reassembleD2E des (WPop w) e)
            (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))

-- | Apply the CHAD transformation to a term.
--
-- Given a term of type @Ex env t@, the result is an expression over an
-- environment extended with an incoming @D2 t@ value. It returns the pair of
-- the @D1 t@ result and a tuple of @D2@ values, one per variable of @env@.
--
-- The term is first preprocessed with 'splitLets' and then
-- 'identityAnalysis'. What happens next depends on 'chcArgArrayAccum':
--
-- * @True@: 'accumDescr' splits the environment. Variables whose type
--   satisfies 'hasArrays' are tagged accum; the rest are tagged merge.
--   'drev' and 'freezeRet' run inside 'makeAccumulators' for the accum-tagged
--   variables. The result is let-bound, and the environment tuple is rebuilt
--   with 'reassembleD2E'.
--
-- * @False@: every variable is tagged merge ('mergeDescr'). The proofs
--   'mergeEnvNoAccum' and 'mergeEnvOnlyMerge' let the output of 'freezeRet'
--   be returned directly at the required type.
--
-- Both branches bind the implicit parameter @?config@. Presumably it is
-- consumed by 'drev' and related helpers, which are not shown here.
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
  -- The pattern signature on `term` brings `t` into scope for the
  -- `auto1 @(D2 t)` type application below (the signature has no forall).
  | True <- chcArgArrayAccum config
  = let ?config = config
    in accumDescr env $ \descr ->
         -- Type of (D1 t, tuple of D2 values for the merge-tagged variables).
         let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
             -- Type of the let-bound result: (t1, tuple for the accum-tagged
             -- variables).
             tvar = STPair t1 (tTup (d2e (select SAccum descr)))
         in ELet ext
              -- NOTE(review): uninvertTup presumably adapts the shape of the
              -- result below to tvar's layout; confirm against its definition.
              (uninvertTup (d2e (select SAccum descr)) t1 $
                 -- Allocate accumulators for the accum-tagged variables. The
                 -- weakening maps their D1 values (a sub-environment of
                 -- D1E env) into the context below the incoming D2 t (WSink).
                 makeAccumulators
                   (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr)))
                   (select SAccum descr) $
                 -- Move the expression from context (d, acenv, tl) into
                 -- context (acenv, d, tl), i.e. put the accumulator variables
                 -- in front of the incoming D2 t.
                 weakenExpr
                   (autoWeak (#d (auto1 @(D2 t))
                              &. #acenv (d2ace (select SAccum descr))
                              &. #tl (d1e env))
                             (#d :++: #acenv :++: #tl)
                             (#acenv :++: #d :++: #tl)) $
                 freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
              -- In the let body the result (type tvar) is at IZ.
              -- fst (fst v) is the D1 t value; (snd v, snd (fst v)) is the
              -- (accum tuple, merge tuple) pair. WSink .> WSink skips the
              -- let-bound variable and the incoming D2 t.
              EPair ext
                (EFst ext (EFst ext (EVar ext tvar IZ)))
                (reassembleD2E descr (WSink .> WSink)
                   (EPair ext (ESnd ext (EVar ext tvar IZ))
                              (ESnd ext (EFst ext (EVar ext tvar IZ)))))
  | False <- chcArgArrayAccum config
  , Refl <- mergeEnvNoAccum env
  , Refl <- mergeEnvOnlyMerge env
  = let ?config = config
    in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
  where
    -- Preprocessed term shared by both branches.
    term' = identityAnalysis env (splitLets term)

-- | Variant of 'chad' whose result type uses @env@ and @t@ directly instead of
-- @D1E env@ and @D1 t@. The equalities come from 'd1eIdentity' and
-- 'd1Identity'.
chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term
  | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term)
  = chad config env term