diff options
Diffstat (limited to 'src/CHAD')
-rw-r--r-- | src/CHAD/Accum.hs | 36 | ||||
-rw-r--r-- | src/CHAD/EnvDescr.hs | 75 | ||||
-rw-r--r-- | src/CHAD/Heuristics.hs | 14 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 53 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 26 |
5 files changed, 182 insertions, 22 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs new file mode 100644 index 0000000..e26f781 --- /dev/null +++ b/src/CHAD/Accum.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +module CHAD.Accum where + +import AST +import CHAD.Types +import Data + + + +hasArrays :: STy t' -> Bool +hasArrays STNil = False +hasArrays (STPair a b) = hasArrays a || hasArrays b +hasArrays (STEither a b) = hasArrays a || hasArrays b +hasArrays (STMaybe t) = hasArrays t +hasArrays STArr{} = True +hasArrays STScal{} = False +hasArrays STAccum{} = error "Accumulators not allowed in source program" + +makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators SNil e = e +makeAccumulators (t `SCons` envpro) e = + makeAccumulators envpro $ + EWith (EZero t) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs new file mode 100644 index 0000000..fcd91f7 --- /dev/null +++ b/src/CHAD/EnvDescr.hs @@ -0,0 +1,75 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.EnvDescr where + +import Data.Kind (Type) +import GHC.TypeLits (Symbol) + +import AST.Env +import AST.Types +import CHAD.Types +import Data + + +type Storage :: Symbol -> Type +data Storage s where + SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator + SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where + DTop :: Descr '[] '[] + DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des + +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' + -> (forall sto'. Descr env' sto' + -> Subenv (Select env sto "merge") (Select env' sto' "merge") + -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + -> Subenv (D1E env) (D1E env') + -> r) + -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) + SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) + SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) + SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) + +-- | Select only the types from the environment that have the specified storage +type family Select env sto s where + Select '[] '[] _ = '[] + Select (t : ts) (s : sto) s = t : Select ts sto s + Select (_ : ts) (_ : sto) s = Select ts sto s + +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SDiscr (DPush des (_, SAccum)) = select s des +select s@SAccum (DPush des (_, SMerge)) = select s des +select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, SMerge)) = select s des +select s@SAccum (DPush des (_, SDiscr)) = select s des +select s@SMerge (DPush des (_, SDiscr)) = select s des +select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) diff --git a/src/CHAD/Heuristics.hs b/src/CHAD/Heuristics.hs deleted file mode 100644 index 6ab8222..0000000 --- a/src/CHAD/Heuristics.hs +++ /dev/null @@ -1,14 +0,0 @@ -{-# LANGUAGE GADTs #-} -module CHAD.Heuristics where - -import AST - - -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = error "Accumulators not allowed in source program" diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 4b2a844..12594f2 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -1,13 +1,20 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Top where import AST +import AST.Weaken.Auto import CHAD +import CHAD.Accum +import CHAD.EnvDescr import CHAD.Types import Data @@ -28,6 +35,12 @@ mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env mergeEnvOnlyMerge SNil = Refl mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl +accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r +accumDescr SNil k = k DTop +accumDescr (t `SCons` env) k = accumDescr env $ \des -> + if hasArrays t then k (des `DPush` (t, SAccum)) + else k (des `DPush` (t, SMerge)) + d1Identity :: STy t -> D1 t :~: t d1Identity = \case STNil -> Refl @@ -42,9 +55,43 @@ d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl +reassembleD2E :: Descr env sto + -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) + -> Ex env' (Tup (D2E env)) +reassembleD2E DTop _ = ENil ext +reassembleD2E (des `DPush` (_, SAccum)) e = + ELet ext e $ + EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) + (ESnd ext (EVar ext (typeOf e) IZ)))) + (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) +reassembleD2E (des `DPush` (_, SMerge)) e = + ELet ext e $ + EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) + (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) + (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) +reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t) + chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) -chad config env term - | Refl <- mergeEnvNoAccum env +chad config env (term :: Ex env t) + | True <- chcArgArrayAccum config + = let ?config = config + in accumDescr env $ \descr -> + let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) + tvar = STPair t1 (tTup (d2e (select SAccum descr))) + in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ + makeAccumulators (select SAccum descr) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #acenv (d2ace (select SAccum descr)) + &. #tl (d1e env)) + (#d :++: #acenv :++: #tl) + (#acenv :++: #d :++: #tl)) $ + freezeRet descr (drev descr term)) $ + EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) + (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) + + | False <- chcArgArrayAccum config + , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 130493d..6662cbf 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -17,8 +17,8 @@ type family D1 t where type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) + D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) + D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t @@ -51,10 +51,14 @@ d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1e :: SList STy env -> SList STy (D1E env) +d1e SNil = SNil +d1e (t `SCons` env) = d1 t `SCons` d1e env + d2 :: STy t -> STy (D2 t) d2 STNil = STNil -d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) +d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) +d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) d2 (STMaybe t) = STMaybe (d2 t) d2 (STArr n t) = STArr n (d2 t) d2 (STScal t) = case t of @@ -67,7 +71,11 @@ d2 STAccum{} = error "Accumulators not allowed in input program" d2e :: SList STy env -> SList STy (D2E env) d2e SNil = SNil -d2e (SCons t ts) = SCons (d2 t) (d2e ts) +d2e (t `SCons` ts) = d2 t `SCons` d2e ts + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (t `SCons` ts) = STAccum (d2 t) `SCons` d2ace ts data CHADConfig = CHADConfig @@ -75,10 +83,18 @@ data CHADConfig = CHADConfig chcLetArrayAccum :: Bool , -- | D[case] will bind variables containing arrays in accumulator mode. chcCaseArrayAccum :: Bool + , -- | Introduce top-level arguments containing arrays in accumulator mode. + chcArgArrayAccum :: Bool } defaultConfig :: CHADConfig defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False + , chcArgArrayAccum = False } + +chcSetAccum :: CHADConfig -> CHADConfig +chcSetAccum c = c { chcLetArrayAccum = True + , chcCaseArrayAccum = True + , chcArgArrayAccum = True } |