summaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Accum.hs36
-rw-r--r--src/CHAD/EnvDescr.hs75
-rw-r--r--src/CHAD/Heuristics.hs14
-rw-r--r--src/CHAD/Top.hs53
-rw-r--r--src/CHAD/Types.hs26
5 files changed, 182 insertions, 22 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
new file mode 100644
index 0000000..e26f781
--- /dev/null
+++ b/src/CHAD/Accum.hs
@@ -0,0 +1,36 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+module CHAD.Accum where
+
+import AST
+import CHAD.Types
+import Data
+
+
+
+hasArrays :: STy t' -> Bool
+hasArrays STNil = False
+hasArrays (STPair a b) = hasArrays a || hasArrays b
+hasArrays (STEither a b) = hasArrays a || hasArrays b
+hasArrays (STMaybe t) = hasArrays t
+hasArrays STArr{} = True
+hasArrays STScal{} = False
+hasArrays STAccum{} = error "Accumulators not allowed in source program"
+
+makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators SNil e = e
+makeAccumulators (t `SCons` envpro) e =
+ makeAccumulators envpro $
+ EWith (EZero t) e
+
+uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
+uninvertTup SNil _ e = EPair ext e (ENil ext)
+uninvertTup (t `SCons` list) tcore e =
+ ELet ext (uninvertTup list (STPair tcore t) e) $
+ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
+ in EPair ext
+ (EFst ext (EFst ext (EVar ext recT IZ)))
+ (EPair ext
+ (ESnd ext (EVar ext recT IZ))
+ (ESnd ext (EFst ext (EVar ext recT IZ))))
+
diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs
new file mode 100644
index 0000000..fcd91f7
--- /dev/null
+++ b/src/CHAD/EnvDescr.hs
@@ -0,0 +1,75 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.EnvDescr where
+
+import Data.Kind (Type)
+import GHC.TypeLits (Symbol)
+
+import AST.Env
+import AST.Types
+import CHAD.Types
+import Data
+
+
+type Storage :: Symbol -> Type
+data Storage s where
+ SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
+ SMerge :: Storage "merge" -- ^ just return and merge
+ SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions
+deriving instance Show (Storage s)
+
+-- | Environment description
+data Descr env sto where
+ DTop :: Descr '[] '[]
+ DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
+deriving instance Show (Descr env sto)
+
+descrList :: Descr env sto -> SList STy env
+descrList DTop = SNil
+descrList (des `DPush` (t, _)) = t `SCons` descrList des
+
+-- | This could have more precise typing on the output storage.
+subDescr :: Descr env sto -> Subenv env env'
+ -> (forall sto'. Descr env' sto'
+ -> Subenv (Select env sto "merge") (Select env' sto' "merge")
+ -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ -> Subenv (D1E env) (D1E env')
+ -> r)
+ -> r
+subDescr DTop SETop k = k DTop SETop SETop SETop
+subDescr (des `DPush` (t, sto)) (SEYes sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
+ SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
+ SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e)
+subDescr (des `DPush` (_, sto)) (SENo sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
+ SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+ SDiscr -> k des' submerge subaccum (SENo subd1e)
+
+-- | Select only the types from the environment that have the specified storage
+type family Select env sto s where
+ Select '[] '[] _ = '[]
+ Select (t : ts) (s : sto) s = t : Select ts sto s
+ Select (_ : ts) (_ : sto) s = Select ts sto s
+
+select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
+select _ DTop = SNil
+select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
+select s@SMerge (DPush des (_, SAccum)) = select s des
+select s@SDiscr (DPush des (_, SAccum)) = select s des
+select s@SAccum (DPush des (_, SMerge)) = select s des
+select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
+select s@SDiscr (DPush des (_, SMerge)) = select s des
+select s@SAccum (DPush des (_, SDiscr)) = select s des
+select s@SMerge (DPush des (_, SDiscr)) = select s des
+select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des)
diff --git a/src/CHAD/Heuristics.hs b/src/CHAD/Heuristics.hs
deleted file mode 100644
index 6ab8222..0000000
--- a/src/CHAD/Heuristics.hs
+++ /dev/null
@@ -1,14 +0,0 @@
-{-# LANGUAGE GADTs #-}
-module CHAD.Heuristics where
-
-import AST
-
-
-hasArrays :: STy t' -> Bool
-hasArrays STNil = False
-hasArrays (STPair a b) = hasArrays a || hasArrays b
-hasArrays (STEither a b) = hasArrays a || hasArrays b
-hasArrays (STMaybe t) = hasArrays t
-hasArrays STArr{} = True
-hasArrays STScal{} = False
-hasArrays STAccum{} = error "Accumulators not allowed in source program"
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 4b2a844..12594f2 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -1,13 +1,20 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
import AST
+import AST.Weaken.Auto
import CHAD
+import CHAD.Accum
+import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -28,6 +35,12 @@ mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge SNil = Refl
mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
+accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
+accumDescr SNil k = k DTop
+accumDescr (t `SCons` env) k = accumDescr env $ \des ->
+ if hasArrays t then k (des `DPush` (t, SAccum))
+ else k (des `DPush` (t, SMerge))
+
d1Identity :: STy t -> D1 t :~: t
d1Identity = \case
STNil -> Refl
@@ -42,9 +55,43 @@ d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+reassembleD2E :: Descr env sto
+ -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
+ -> Ex env' (Tup (D2E env))
+reassembleD2E DTop _ = ENil ext
+reassembleD2E (des `DPush` (_, SAccum)) e =
+ ELet ext e $
+ EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
+ (ESnd ext (EVar ext (typeOf e) IZ))))
+ (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
+reassembleD2E (des `DPush` (_, SMerge)) e =
+ ELet ext e $
+ EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
+ (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
+ (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
+reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t)
+
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
-chad config env term
- | Refl <- mergeEnvNoAccum env
+chad config env (term :: Ex env t)
+ | True <- chcArgArrayAccum config
+ = let ?config = config
+ in accumDescr env $ \descr ->
+ let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
+ tvar = STPair t1 (tTup (d2e (select SAccum descr)))
+ in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
+ makeAccumulators (select SAccum descr) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #acenv (d2ace (select SAccum descr))
+ &. #tl (d1e env))
+ (#d :++: #acenv :++: #tl)
+ (#acenv :++: #d :++: #tl)) $
+ freezeRet descr (drev descr term)) $
+ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
+ (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+
+ | False <- chcArgArrayAccum config
+ , Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
= let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 130493d..6662cbf 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -17,8 +17,8 @@ type family D1 t where
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
- D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
+ D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
+ D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
@@ -51,10 +51,14 @@ d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1e :: SList STy env -> SList STy (D1E env)
+d1e SNil = SNil
+d1e (t `SCons` env) = d1 t `SCons` d1e env
+
d2 :: STy t -> STy (D2 t)
d2 STNil = STNil
-d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
-d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
+d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))
+d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))
d2 (STMaybe t) = STMaybe (d2 t)
d2 (STArr n t) = STArr n (d2 t)
d2 (STScal t) = case t of
@@ -67,7 +71,11 @@ d2 STAccum{} = error "Accumulators not allowed in input program"
d2e :: SList STy env -> SList STy (D2E env)
d2e SNil = SNil
-d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+d2e (t `SCons` ts) = d2 t `SCons` d2e ts
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (t `SCons` ts) = STAccum (d2 t) `SCons` d2ace ts
data CHADConfig = CHADConfig
@@ -75,10 +83,18 @@ data CHADConfig = CHADConfig
chcLetArrayAccum :: Bool
, -- | D[case] will bind variables containing arrays in accumulator mode.
chcCaseArrayAccum :: Bool
+ , -- | Introduce top-level arguments containing arrays in accumulator mode.
+ chcArgArrayAccum :: Bool
}
defaultConfig :: CHADConfig
defaultConfig = CHADConfig
{ chcLetArrayAccum = False
, chcCaseArrayAccum = False
+ , chcArgArrayAccum = False
}
+
+chcSetAccum :: CHADConfig -> CHADConfig
+chcSetAccum c = c { chcLetArrayAccum = True
+ , chcCaseArrayAccum = True
+ , chcArgArrayAccum = True }