{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Top-level drivers for the CHAD transformation.
--
-- This module chooses a storage descriptor ('Descr') for the free variables
-- of a term, runs 'drev' under that descriptor, and repackages the output so
-- that callers receive the primal result together with the cotangents of all
-- free variables as a single tuple in environment order.
module CHAD.Top where

import AST
import AST.Weaken.Auto
import CHAD
import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data


-- | The storage list that marks every entry of an environment as @"merge"@.
type family MergeEnv env where
  MergeEnv '[] = '[]
  MergeEnv (t : ts) = "merge" : MergeEnv ts

-- | A descriptor that stores every variable of the environment with
-- 'SMerge', i.e. one that uses no accumulators at all.
mergeDescr :: SList STy env -> Descr env (MergeEnv env)
mergeDescr = \case
  SNil -> DTop
  SCons ty rest -> DPush (mergeDescr rest) (ty, SMerge)

-- | Proof: an all-'SMerge' descriptor selects no @"accum"@ variables.
mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
mergeEnvNoAccum = \case
  SNil -> Refl
  SCons _ rest -> case mergeEnvNoAccum rest of
    Refl -> Refl

-- | Proof: an all-'SMerge' descriptor selects the whole environment as its
-- @"merge"@ variables.
mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge = \case
  SNil -> Refl
  SCons _ rest -> case mergeEnvOnlyMerge rest of
    Refl -> Refl

-- | Build a descriptor in which every variable whose type 'hasArrays' is
-- stored with 'SAccum' and every other variable with 'SMerge'.
--
-- The resulting storage list @sto@ depends on the run-time type
-- representations in @env@, so it is passed to a rank-2 continuation
-- instead of being returned.
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil k = k DTop
accumDescr (SCons ty rest) k =
  accumDescr rest $ \restDescr ->
    -- The two alternatives build descriptors with different storage lists,
    -- so the continuation has to be applied separately in each one.
    case hasArrays ty of
      True -> k (DPush restDescr (ty, SAccum))
      False -> k (DPush restDescr (ty, SMerge))

-- | Proof that 'D1' is the identity on the types handled here.
--
-- Calls 'error' on an accumulator type ('STAccum'), which is not allowed in
-- an input program.
d1Identity :: STy t -> D1 t :~: t
d1Identity STNil = Refl
d1Identity (STPair a b) = case (d1Identity a, d1Identity b) of
  (Refl, Refl) -> Refl
d1Identity (STEither a b) = case (d1Identity a, d1Identity b) of
  (Refl, Refl) -> Refl
d1Identity (STMaybe elt) = case d1Identity elt of
  Refl -> Refl
d1Identity (STArr _ elt) = case d1Identity elt of
  Refl -> Refl
d1Identity (STScal _) = Refl
d1Identity STAccum{} = error "Accumulators not allowed in input program"

-- | 'd1Identity' lifted to a whole environment.
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
d1eIdentity (SCons ty rest) = case (d1Identity ty, d1eIdentity rest) of
  (Refl, Refl) -> Refl

-- | Rebuild the cotangent tuple for the whole environment from a pair
-- @(accumulated, merged)@.  Each component is a tuple over the variables
-- that the descriptor selects with 'SAccum' and 'SMerge' respectively.
--
-- The descriptor is walked from its most recently pushed entry.
--
-- * An 'SAccum' or 'SMerge' entry takes the last element (the second
--   component) of its own sub-tuple.  It then recurses with that sub-tuple
--   shortened to its first component, leaving the other sub-tuple
--   untouched.
-- * An 'SDiscr' entry contributes a zero cotangent ('EZero') and consumes
--   nothing from the input.
--
-- The input is let-bound before it is projected, so the generated code does
-- not duplicate it.
reassembleD2E
  :: Descr env sto
  -> Ex env' (TPair (Tup (D2E (Select env sto "accum")))
                    (Tup (D2E (Select env sto "merge"))))
  -> Ex env' (Tup (D2E env))
reassembleD2E DTop _ = ENil ext
reassembleD2E (DPush rest (ty, SDiscr)) pairE =
  EPair ext (reassembleD2E rest pairE) (EZero ext ty)
reassembleD2E (DPush rest (_, SAccum)) pairE =
  ELet ext pairE $
    EPair ext
      (reassembleD2E rest
         (EPair ext
            (EFst ext (EFst ext (EVar ext (typeOf pairE) IZ)))
            (ESnd ext (EVar ext (typeOf pairE) IZ))))
      (ESnd ext (EFst ext (EVar ext (typeOf pairE) IZ)))
reassembleD2E (DPush rest (_, SMerge)) pairE =
  ELet ext pairE $
    EPair ext
      (reassembleD2E rest
         (EPair ext
            (EFst ext (EVar ext (typeOf pairE) IZ))
            (EFst ext (ESnd ext (EVar ext (typeOf pairE) IZ)))))
      (ESnd ext (ESnd ext (EVar ext (typeOf pairE) IZ)))

-- | Run the CHAD transformation on a term with free variables @env@.
--
-- The generated expression lives in an environment that has the cotangent
-- of the result (@D2 t@) in front of the primal environment (@D1E env@).
-- It returns the primal result paired with the cotangents of all free
-- variables, as one tuple in environment order.
--
-- The storage strategy is selected by 'chcArgArrayAccum':
--
-- * 'True': variables whose type 'hasArrays' use 'SAccum' and the rest use
--   'SMerge'.  The transformed term is wrapped with 'makeAccumulators'.
--   'reassembleD2E' then puts both groups of cotangents back into
--   environment order.
-- * 'False': every variable uses 'SMerge' ('mergeDescr').  The proofs
--   'mergeEnvNoAccum' and 'mergeEnvOnlyMerge' show that the result of
--   'freezeRet' already has the required type.
chad :: CHADConfig -> SList STy env -> Ex env t
     -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t) =
  let ?config = config
  in if chcArgArrayAccum config
       then accumDescr env $ \descr ->
         let accums = select SAccum descr
             accumTys = d2e accums
             -- Primal result paired with the tuple of merged cotangents.
             t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
             -- Type of the value bound by the ELet below: t1 paired with
             -- the tuple of accumulated cotangents.
             tvar = STPair t1 (tTup accumTys)
             bound = EVar ext tvar IZ
         in ELet ext
              (uninvertTup accumTys t1 $
                 makeAccumulators accums $
                   -- Reorders the environment segments between the layouts
                   -- (#d, #acenv, #tl) and (#acenv, #d, #tl); presumably
                   -- source then target -- confirm against autoWeak.
                   weakenExpr
                     (autoWeak (#d (auto1 @(D2 t))
                                 &. #acenv (d2ace accums)
                                 &. #tl (d1e env))
                               (#d :++: #acenv :++: #tl)
                               (#acenv :++: #d :++: #tl)) $
                     freezeRet descr (drev descr term))
              (EPair ext
                 (EFst ext (EFst ext bound))
                 (reassembleD2E descr
                    (EPair ext (ESnd ext bound) (ESnd ext (EFst ext bound)))))
       else case (mergeEnvNoAccum env, mergeEnvOnlyMerge env) of
         (Refl, Refl) ->
           let descr = mergeDescr env
           in freezeRet descr (drev descr term)

-- | 'chad' for terms and environments whose types contain no accumulators.
-- On such types 'D1' and 'D1E' are the identity, so the signature is stated
-- without them.
--
-- Inherits the 'error' from 'd1Identity' if an accumulator type occurs in
-- @env@ or in the type of the term.
chad' :: CHADConfig -> SList STy env -> Ex env t
      -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
chad' config env term =
  case (d1eIdentity env, d1Identity (typeOf term)) of
    (Refl, Refl) -> chad config env term