summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-26 15:25:13 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-26 15:25:13 +0100
commitae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (patch)
tree1f6afda4b1d6925fe8224ee4f2ca40212fe11aa6
parent7774da51c532006da82617ce307d136897693280 (diff)
WIP accum top-level args
-rw-r--r--bench/Main.hs5
-rw-r--r--chad-fast.cabal3
-rw-r--r--src/CHAD.hs149
-rw-r--r--src/CHAD/Accum.hs36
-rw-r--r--src/CHAD/EnvDescr.hs75
-rw-r--r--src/CHAD/Heuristics.hs14
-rw-r--r--src/CHAD/Top.hs53
-rw-r--r--src/CHAD/Types.hs26
-rw-r--r--src/Example.hs2
-rw-r--r--src/Interpreter.hs42
-rw-r--r--test/Main.hs3
11 files changed, 242 insertions, 166 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 932da9d..5d2cb5a 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -20,7 +20,6 @@ import GHC.Exts (withDict)
import AST
import Array
import qualified CHAD (defaultConfig)
-import CHAD (CHADConfig(..))
import CHAD.Top
import CHAD.Types
import Data
@@ -112,9 +111,7 @@ makeGMMInputs =
SNil
accumConfig :: CHADConfig
-accumConfig = CHADConfig
- { chcLetArrayAccum = True
- , chcCaseArrayAccum = True }
+accumConfig = chcSetAccum CHAD.defaultConfig
main :: IO ()
main = defaultMain
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 6c60a31..893c92f 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -20,7 +20,8 @@ library
AST.Weaken
AST.Weaken.Auto
CHAD
- CHAD.Heuristics
+ CHAD.Accum
+ CHAD.EnvDescr
CHAD.Top
CHAD.Types
-- Compile
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 9f58f73..3da083b 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -32,18 +32,17 @@ module CHAD (
) where
import Data.Functor.Const
-import Data.Kind (Type)
import Data.Type.Bool (If)
import Data.Type.Equality (type (==))
import GHC.Stack (HasCallStack)
-import GHC.TypeLits (Symbol)
import AST
import AST.Bindings
import AST.Count
import AST.Env
import AST.Weaken.Auto
-import CHAD.Heuristics
+import CHAD.Accum
+import CHAD.EnvDescr
import CHAD.Types
import Data
import Lemmas
@@ -197,66 +196,6 @@ reconstructBindings binds tape =
,sreverse (stapeUnfoldings binds))
----------------------- ENVIRONMENT DESCRIPTION AND STORAGE ---------------------
-
-type Storage :: Symbol -> Type
-data Storage s where
- SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
- SMerge :: Storage "merge" -- ^ just return and merge
- SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions
-deriving instance Show (Storage s)
-
--- | Environment description
-data Descr env sto where
- DTop :: Descr '[] '[]
- DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
-deriving instance Show (Descr env sto)
-
-descrList :: Descr env sto -> SList STy env
-descrList DTop = SNil
-descrList (des `DPush` (t, _)) = t `SCons` descrList des
-
--- | This could have more precise typing on the output storage.
-subDescr :: Descr env sto -> Subenv env env'
- -> (forall sto'. Descr env' sto'
- -> Subenv (Select env sto "merge") (Select env' sto' "merge")
- -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
- -> Subenv (D1E env) (D1E env')
- -> r)
- -> r
-subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, sto)) (SEYes sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
- SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e)
-subDescr (des `DPush` (_, sto)) (SENo sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
- SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
- SDiscr -> k des' submerge subaccum (SENo subd1e)
-
--- | Select only the types from the environment that have the specified storage
-type family Select env sto s where
- Select '[] '[] _ = '[]
- Select (t : ts) (s : sto) s = t : Select ts sto s
- Select (_ : ts) (_ : sto) s = Select ts sto s
-
-select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
-select _ DTop = SNil
-select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
-select s@SMerge (DPush des (_, SAccum)) = select s des
-select s@SDiscr (DPush des (_, SAccum)) = select s des
-select s@SAccum (DPush des (_, SMerge)) = select s des
-select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
-select s@SDiscr (DPush des (_, SMerge)) = select s des
-select s@SAccum (DPush des (_, SDiscr)) = select s des
-select s@SMerge (DPush des (_, SDiscr)) = select s des
-select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des)
-
-
---------------------------------- DERIVATIVES ---------------------------------
d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
@@ -283,24 +222,24 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
OIf -> Linear $ \_ -> ENil ext
ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
OToFl64 -> Linear $ \_ -> ENil ext
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
- OIDiv t -> integralD2 t $ Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
@@ -316,11 +255,11 @@ d2op op = case op of
-> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
-> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
STF32 -> float
STF64 -> float
- STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
floatingD2 :: ScalIsFloating a ~ True
=> SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
@@ -332,13 +271,8 @@ d2op op = case op of
integralD2 STI32 k = k
integralD2 STI64 k = k
-sD1eEnv :: Descr env sto -> SList STy (D1E env)
-sD1eEnv DTop = SNil
-sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
-
-d2ace :: SList STy env -> SList STy (D2AcE env)
-d2ace SNil = SNil
-d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
+desD1E :: Descr env sto -> SList STy (D1E env)
+desD1E = d1e . descrList
-- d1W :: env :> env' -> D1E env :> D1E env'
-- d1W WId = WId
@@ -610,24 +544,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
-- STScal{} -> False
-- STAccum{} -> error "An accumulator in merge storage?"
-makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators SNil e = e
-makeAccumulators (STArr n t `SCons` envpro) e =
- makeAccumulators envpro $
- EWith (zero (STArr n t)) e
-makeAccumulators (t `SCons` _) _ = error $ "makeAccumulators: Not only arrays in envpro: " ++ show t
-
-uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
-uninvertTup SNil _ e = EPair ext e (ENil ext)
-uninvertTup (t `SCons` list) tcore e =
- ELet ext (uninvertTup list (STPair tcore t) e) $
- let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
- in EPair ext
- (EFst ext (EFst ext (EVar ext recT IZ)))
- (EPair ext
- (ESnd ext (EVar ext recT IZ))
- (ESnd ext (EFst ext (EVar ext recT IZ))))
-
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
@@ -701,7 +617,7 @@ freezeRet :: Descr env sto
-> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
- e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
+ e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
@@ -709,7 +625,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
&. #tape (subList (bindingsBinds e0) subtape)
&. #shbinds (bindingsBinds e0)
&. #d2ace (d2ace (select SAccum descr))
- &. #tl (sD1eEnv descr))
+ &. #tl (desD1E descr))
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
(#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
@@ -782,7 +698,7 @@ drev des = \case
subtape
(EPair ext a1 b1)
subBoth
- (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
+ (EMaybe ext
(zeroTup (subList (select SMerge des) subBoth))
(ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
(weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
@@ -790,7 +706,8 @@ drev des = \case
(weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
plus_A_B
(EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))
+ (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))
+ (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
EFst _ e
| Ret e0 subtape e1 sub e2 <- drev des e
@@ -799,7 +716,7 @@ drev des = \case
subtape
(EFst ext e1)
sub
- (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
+ (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
@@ -809,7 +726,7 @@ drev des = \case
subtape
(ESnd ext e1)
sub
- (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
+ (ELet ext (EJust ext (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
@@ -820,11 +737,12 @@ drev des = \case
subtape
(EInl ext (d1 t2) e1)
sub
- (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
+ (EMaybe ext
(zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
(weakenExpr (WCopy (wSinks' @[_,_])) e2)
- (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")))
+ (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
+ (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))
EInr _ t1 e
| Ret e0 subtape e1 sub e2 <- drev des e ->
@@ -832,11 +750,12 @@ drev des = \case
subtape
(EInr ext (d1 t1) e1)
sub
- (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
+ (EMaybe ext
(zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
(EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
- (weakenExpr (WCopy (wSinks' @[_,_])) e2)))
+ (weakenExpr (WCopy (wSinks' @[_,_])) e2))
+ (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ))
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
@@ -907,7 +826,7 @@ drev des = \case
(EInr ext (d2 t1)
(ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
ELet ext
- (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
+ (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $
weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
plus_AB_E
(EFst ext (EVar ext tCaseRet (IS IZ)))
@@ -986,16 +905,16 @@ drev des = \case
(EVar ext shty IZ)
(letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
&. #sh (shty `SCons` SNil)
- &. #d1env (sD1eEnv des)
- &. #d1env' (sD1eEnv usedDes))
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
(#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
(#ix :++: #sh :++: #d1env))
e0)) $
let w = autoWeak (#ix (shty `SCons` SNil)
&. #sh (shty `SCons` SNil)
&. #e0 (bindingsBinds e0)
- &. #d1env (sD1eEnv des)
- &. #d1env' (sD1eEnv usedDes))
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
(#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
(#e0 :++: #ix :++: #sh :++: #d1env)
in EPair ext (weakenExpr w e1) (collectexpr w)))
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
new file mode 100644
index 0000000..e26f781
--- /dev/null
+++ b/src/CHAD/Accum.hs
@@ -0,0 +1,36 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+module CHAD.Accum where
+
+import AST
+import CHAD.Types
+import Data
+
+
+
+hasArrays :: STy t' -> Bool
+hasArrays STNil = False
+hasArrays (STPair a b) = hasArrays a || hasArrays b
+hasArrays (STEither a b) = hasArrays a || hasArrays b
+hasArrays (STMaybe t) = hasArrays t
+hasArrays STArr{} = True
+hasArrays STScal{} = False
+hasArrays STAccum{} = error "Accumulators not allowed in source program"
+
+makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators SNil e = e
+makeAccumulators (t `SCons` envpro) e =
+ makeAccumulators envpro $
+ EWith (EZero t) e
+
+uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
+uninvertTup SNil _ e = EPair ext e (ENil ext)
+uninvertTup (t `SCons` list) tcore e =
+ ELet ext (uninvertTup list (STPair tcore t) e) $
+ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
+ in EPair ext
+ (EFst ext (EFst ext (EVar ext recT IZ)))
+ (EPair ext
+ (ESnd ext (EVar ext recT IZ))
+ (ESnd ext (EFst ext (EVar ext recT IZ))))
+
diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs
new file mode 100644
index 0000000..fcd91f7
--- /dev/null
+++ b/src/CHAD/EnvDescr.hs
@@ -0,0 +1,75 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.EnvDescr where
+
+import Data.Kind (Type)
+import GHC.TypeLits (Symbol)
+
+import AST.Env
+import AST.Types
+import CHAD.Types
+import Data
+
+
+type Storage :: Symbol -> Type
+data Storage s where
+ SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
+ SMerge :: Storage "merge" -- ^ just return and merge
+ SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions
+deriving instance Show (Storage s)
+
+-- | Environment description
+data Descr env sto where
+ DTop :: Descr '[] '[]
+ DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
+deriving instance Show (Descr env sto)
+
+descrList :: Descr env sto -> SList STy env
+descrList DTop = SNil
+descrList (des `DPush` (t, _)) = t `SCons` descrList des
+
+-- | This could have more precise typing on the output storage.
+subDescr :: Descr env sto -> Subenv env env'
+ -> (forall sto'. Descr env' sto'
+ -> Subenv (Select env sto "merge") (Select env' sto' "merge")
+ -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ -> Subenv (D1E env) (D1E env')
+ -> r)
+ -> r
+subDescr DTop SETop k = k DTop SETop SETop SETop
+subDescr (des `DPush` (t, sto)) (SEYes sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
+ SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
+ SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e)
+subDescr (des `DPush` (_, sto)) (SENo sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
+ SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+ SDiscr -> k des' submerge subaccum (SENo subd1e)
+
+-- | Select only the types from the environment that have the specified storage
+type family Select env sto s where
+ Select '[] '[] _ = '[]
+ Select (t : ts) (s : sto) s = t : Select ts sto s
+ Select (_ : ts) (_ : sto) s = Select ts sto s
+
+select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
+select _ DTop = SNil
+select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
+select s@SMerge (DPush des (_, SAccum)) = select s des
+select s@SDiscr (DPush des (_, SAccum)) = select s des
+select s@SAccum (DPush des (_, SMerge)) = select s des
+select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
+select s@SDiscr (DPush des (_, SMerge)) = select s des
+select s@SAccum (DPush des (_, SDiscr)) = select s des
+select s@SMerge (DPush des (_, SDiscr)) = select s des
+select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des)
diff --git a/src/CHAD/Heuristics.hs b/src/CHAD/Heuristics.hs
deleted file mode 100644
index 6ab8222..0000000
--- a/src/CHAD/Heuristics.hs
+++ /dev/null
@@ -1,14 +0,0 @@
-{-# LANGUAGE GADTs #-}
-module CHAD.Heuristics where
-
-import AST
-
-
-hasArrays :: STy t' -> Bool
-hasArrays STNil = False
-hasArrays (STPair a b) = hasArrays a || hasArrays b
-hasArrays (STEither a b) = hasArrays a || hasArrays b
-hasArrays (STMaybe t) = hasArrays t
-hasArrays STArr{} = True
-hasArrays STScal{} = False
-hasArrays STAccum{} = error "Accumulators not allowed in source program"
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 4b2a844..12594f2 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -1,13 +1,20 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Top where
import AST
+import AST.Weaken.Auto
import CHAD
+import CHAD.Accum
+import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -28,6 +35,12 @@ mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
mergeEnvOnlyMerge SNil = Refl
mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
+accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
+accumDescr SNil k = k DTop
+accumDescr (t `SCons` env) k = accumDescr env $ \des ->
+ if hasArrays t then k (des `DPush` (t, SAccum))
+ else k (des `DPush` (t, SMerge))
+
d1Identity :: STy t -> D1 t :~: t
d1Identity = \case
STNil -> Refl
@@ -42,9 +55,43 @@ d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+reassembleD2E :: Descr env sto
+ -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
+ -> Ex env' (Tup (D2E env))
+reassembleD2E DTop _ = ENil ext
+reassembleD2E (des `DPush` (_, SAccum)) e =
+ ELet ext e $
+ EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
+ (ESnd ext (EVar ext (typeOf e) IZ))))
+ (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
+reassembleD2E (des `DPush` (_, SMerge)) e =
+ ELet ext e $
+ EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
+ (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
+ (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
+reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t)
+
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
-chad config env term
- | Refl <- mergeEnvNoAccum env
+chad config env (term :: Ex env t)
+ | True <- chcArgArrayAccum config
+ = let ?config = config
+ in accumDescr env $ \descr ->
+ let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
+ tvar = STPair t1 (tTup (d2e (select SAccum descr)))
+ in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
+ makeAccumulators (select SAccum descr) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #acenv (d2ace (select SAccum descr))
+ &. #tl (d1e env))
+ (#d :++: #acenv :++: #tl)
+ (#acenv :++: #d :++: #tl)) $
+ freezeRet descr (drev descr term)) $
+ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
+ (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+
+ | False <- chcArgArrayAccum config
+ , Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
= let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 130493d..6662cbf 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -17,8 +17,8 @@ type family D1 t where
type family D2 t where
D2 TNil = TNil
- D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
- D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
+ D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
+ D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
@@ -51,10 +51,14 @@ d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1e :: SList STy env -> SList STy (D1E env)
+d1e SNil = SNil
+d1e (t `SCons` env) = d1 t `SCons` d1e env
+
d2 :: STy t -> STy (D2 t)
d2 STNil = STNil
-d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
-d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
+d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))
+d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))
d2 (STMaybe t) = STMaybe (d2 t)
d2 (STArr n t) = STArr n (d2 t)
d2 (STScal t) = case t of
@@ -67,7 +71,11 @@ d2 STAccum{} = error "Accumulators not allowed in input program"
d2e :: SList STy env -> SList STy (D2E env)
d2e SNil = SNil
-d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+d2e (t `SCons` ts) = d2 t `SCons` d2e ts
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (t `SCons` ts) = STAccum (d2 t) `SCons` d2ace ts
data CHADConfig = CHADConfig
@@ -75,10 +83,18 @@ data CHADConfig = CHADConfig
chcLetArrayAccum :: Bool
, -- | D[case] will bind variables containing arrays in accumulator mode.
chcCaseArrayAccum :: Bool
+ , -- | Introduce top-level arguments containing arrays in accumulator mode.
+ chcArgArrayAccum :: Bool
}
defaultConfig :: CHADConfig
defaultConfig = CHADConfig
{ chcLetArrayAccum = False
, chcCaseArrayAccum = False
+ , chcArgArrayAccum = False
}
+
+chcSetAccum :: CHADConfig -> CHADConfig
+chcSetAccum c = c { chcLetArrayAccum = True
+ , chcCaseArrayAccum = True
+ , chcArgArrayAccum = True }
diff --git a/src/Example.hs b/src/Example.hs
index 94a6934..795229c 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -184,7 +184,7 @@ neuralGo =
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
(primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
- (primal', (((((), Right dlay1_1'), Right dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
+ (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
_ -> undefined
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
in trace (formatter (ppExpr knownEnv revderiv)) $
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 37d4a83..56ebf82 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -30,6 +30,7 @@ import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import Debug.Trace
+import GHC.Stack
import Array
import AST
@@ -185,8 +186,8 @@ interpretOp op arg = case op of
zeroD2 :: STy t -> Rep (D2 t)
zeroD2 typ = case typ of
STNil -> ()
- STPair _ _ -> Left ()
- STEither _ _ -> Left ()
+ STPair _ _ -> Nothing
+ STEither _ _ -> Nothing
STMaybe _ -> Nothing
STArr SZ t -> arrayUnit (zeroD2 t)
STArr n _ -> emptyArray n
@@ -202,14 +203,14 @@ addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
addD2s typ a b = case typ of
STNil -> ()
STPair t1 t2 -> case (a, b) of
- (Left (), _) -> b
- (_, Left ()) -> a
- (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2)
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2)
STEither t1 t2 -> case (a, b) of
- (Left (), _) -> b
- (_, Left ()) -> a
- (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y))
- (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y))
+ (Nothing, _) -> b
+ (_, Nothing) -> a
+ (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y))
+ (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y))
_ -> error "Plus of inconsistent Eithers"
STMaybe t -> case (a, b) of
(Nothing, _) -> b
@@ -233,16 +234,14 @@ addD2s typ a b = case typ of
onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)
onehotD2 SZ _ () v = v
onehotD2 _ STNil _ _ = ()
-onehotD2 (SS _ ) (STPair _ _ ) (Left _ ) (Left _ ) = Left ()
-onehotD2 (SS SZ ) (STPair _ _ ) (Right () ) (Right val ) = Right val
-onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left idx)) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2)
-onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val)
-onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
-onehotD2 (SS _ ) (STEither _ _ ) (Left _ ) (Left _ ) = Left ()
-onehotD2 (SS SZ ) (STEither _ _ ) (Right () ) (Right val ) = Right val
-onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left idx)) (Right (Left val)) = Right (Left (onehotD2 i t1 idx val))
-onehotD2 (SS (SS i)) (STEither _ t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val))
-onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value"
+onehotD2 (SS SZ ) (STPair _ _ ) () val = Just val
+onehotD2 (SS (SS i)) (STPair t1 t2) (Left idx) (Left val) = Just (onehotD2 i t1 idx val, zeroD2 t2)
+onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val)
+onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
+onehotD2 (SS SZ ) (STEither _ _ ) () val = Just val
+onehotD2 (SS (SS i)) (STEither t1 _ ) (Left idx) (Left val) = Just (Left (onehotD2 i t1 idx val))
+onehotD2 (SS (SS i)) (STEither _ t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val))
+onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value"
onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val)
onehotD2 (SS i ) (STArr n t) idx val = runIdentity $
onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val
@@ -251,6 +250,7 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
withAccum t _ initval f = AcM $ do
+ putStrLn $ "withAccum: " ++ show t
accum <- newAcSparse t SZ () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
@@ -324,7 +324,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
+newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
newAcSparse typ SZ () val = case typ of
STNil -> return ()
STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
@@ -372,7 +372,7 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
go mk dep' dim' idx' val' $ \arr pish ->
k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
-newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
+newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
newAcDense typ SZ () val = case typ of
STEither t1 t2 -> case val of
Left x -> Left <$> newAcSparse t1 SZ () x
diff --git a/test/Main.hs b/test/Main.hs
index 7cb15d5..d18884e 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -24,7 +24,6 @@ import Hedgehog.Main
import Array
import AST
import AST.Pretty
-import CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import qualified Example
@@ -238,7 +237,7 @@ term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
fst_ #q * #x + snd_ #q * fst_ #p
tests :: IO Bool
-tests = checkSequential $ Group "AD"
+tests = checkParallel $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x)