diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 |
commit | ae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (patch) | |
tree | 1f6afda4b1d6925fe8224ee4f2ca40212fe11aa6 | |
parent | 7774da51c532006da82617ce307d136897693280 (diff) |
WIP accum top-level args
-rw-r--r-- | bench/Main.hs | 5 | ||||
-rw-r--r-- | chad-fast.cabal | 3 | ||||
-rw-r--r-- | src/CHAD.hs | 149 | ||||
-rw-r--r-- | src/CHAD/Accum.hs | 36 | ||||
-rw-r--r-- | src/CHAD/EnvDescr.hs | 75 | ||||
-rw-r--r-- | src/CHAD/Heuristics.hs | 14 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 53 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 26 | ||||
-rw-r--r-- | src/Example.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 42 | ||||
-rw-r--r-- | test/Main.hs | 3 |
11 files changed, 242 insertions, 166 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 932da9d..5d2cb5a 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -20,7 +20,6 @@ import GHC.Exts (withDict) import AST import Array import qualified CHAD (defaultConfig) -import CHAD (CHADConfig(..)) import CHAD.Top import CHAD.Types import Data @@ -112,9 +111,7 @@ makeGMMInputs = SNil accumConfig :: CHADConfig -accumConfig = CHADConfig - { chcLetArrayAccum = True - , chcCaseArrayAccum = True } +accumConfig = chcSetAccum CHAD.defaultConfig main :: IO () main = defaultMain diff --git a/chad-fast.cabal b/chad-fast.cabal index 6c60a31..893c92f 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -20,7 +20,8 @@ library AST.Weaken AST.Weaken.Auto CHAD - CHAD.Heuristics + CHAD.Accum + CHAD.EnvDescr CHAD.Top CHAD.Types -- Compile diff --git a/src/CHAD.hs b/src/CHAD.hs index 9f58f73..3da083b 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -32,18 +32,17 @@ module CHAD ( ) where import Data.Functor.Const -import Data.Kind (Type) import Data.Type.Bool (If) import Data.Type.Equality (type (==)) import GHC.Stack (HasCallStack) -import GHC.TypeLits (Symbol) import AST import AST.Bindings import AST.Count import AST.Env import AST.Weaken.Auto -import CHAD.Heuristics +import CHAD.Accum +import CHAD.EnvDescr import CHAD.Types import Data import Lemmas @@ -197,66 +196,6 @@ reconstructBindings binds tape = ,sreverse (stapeUnfoldings binds)) ----------------------- ENVIRONMENT DESCRIPTION AND STORAGE --------------------- - -type Storage :: Symbol -> Type -data Storage s where - SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator - SMerge :: Storage "merge" -- ^ just return and merge - SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions -deriving instance Show (Storage s) - --- | Environment description -data Descr env sto where - DTop :: Descr '[] '[] - DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) -deriving instance Show (Descr env sto) - -descrList :: Descr env sto -> SList STy env -descrList DTop = SNil -descrList (des `DPush` (t, _)) = t `SCons` descrList des - --- | This could have more precise typing on the output storage. -subDescr :: Descr env sto -> Subenv env env' - -> (forall sto'. Descr env' sto' - -> Subenv (Select env sto "merge") (Select env' sto' "merge") - -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) - -> Subenv (D1E env) (D1E env') - -> r) - -> r -subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, sto)) (SEYes sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) -subDescr (des `DPush` (_, sto)) (SENo sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) - SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) - SDiscr -> k des' submerge subaccum (SENo subd1e) - --- | Select only the types from the environment that have the specified storage -type family Select env sto s where - Select '[] '[] _ = '[] - Select (t : ts) (s : sto) s = t : Select ts sto s - Select (_ : ts) (_ : sto) s = Select ts sto s - -select :: Storage s -> Descr env sto -> SList STy (Select env sto s) -select _ DTop = SNil -select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, SAccum)) = select s des -select s@SDiscr (DPush des (_, SAccum)) = select s des -select s@SAccum (DPush des (_, SMerge)) = select s des -select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) -select s@SDiscr (DPush des (_, SMerge)) = select s des -select s@SAccum (DPush des (_, SDiscr)) = select s des -select s@SMerge (DPush des (_, SDiscr)) = select s des -select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) - - ---------------------------------- DERIVATIVES --------------------------------- d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) @@ -283,24 +222,24 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d))) ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) OIf -> Linear $ \_ -> ENil ext ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 OToFl64 -> Linear $ \_ -> ENil ext ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) where d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) @@ -316,11 +255,11 @@ d2op op = case op of -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) STF32 -> float STF64 -> float - STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) floatingD2 :: ScalIsFloating a ~ True => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -332,13 +271,8 @@ d2op op = case op of integralD2 STI32 k = k integralD2 STI64 k = k -sD1eEnv :: Descr env sto -> SList STy (D1E env) -sD1eEnv DTop = SNil -sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) - -d2ace :: SList STy env -> SList STy (D2AcE env) -d2ace SNil = SNil -d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) +desD1E :: Descr env sto -> SList STy (D1E env) +desD1E = d1e . descrList -- d1W :: env :> env' -> D1E env :> D1E env' -- d1W WId = WId @@ -610,24 +544,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -- STScal{} -> False -- STAccum{} -> error "An accumulator in merge storage?" -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (STArr n t `SCons` envpro) e = - makeAccumulators envpro $ - EWith (zero (STArr n t)) e -makeAccumulators (t `SCons` _) _ = error $ "makeAccumulators: Not only arrays in envpro: " ++ show t - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- @@ -701,7 +617,7 @@ freezeRet :: Descr env sto -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 - e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 + e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) @@ -709,7 +625,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = &. #tape (subList (bindingsBinds e0) subtape) &. #shbinds (bindingsBinds e0) &. #d2ace (d2ace (select SAccum descr)) - &. #tl (sD1eEnv descr)) + &. #tl (desD1E descr)) (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ @@ -782,7 +698,7 @@ drev des = \case subtape (EPair ext a1 b1) subBoth - (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) + (EMaybe ext (zeroTup (subList (select SMerge des) subBoth)) (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ @@ -790,7 +706,8 @@ drev des = \case (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ plus_A_B (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))) + (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) + (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) EFst _ e | Ret e0 subtape e1 sub e2 <- drev des e @@ -799,7 +716,7 @@ drev des = \case subtape (EFst ext e1) sub - (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ + (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ weakenExpr (WCopy WSink) e2) ESnd _ e @@ -809,7 +726,7 @@ drev des = \case subtape (ESnd ext e1) sub - (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EJust ext (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) @@ -820,11 +737,12 @@ drev des = \case subtape (EInl ext (d1 t2) e1) sub - (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) + (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks' @[_,_])) e2) - (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))) + (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) + (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) EInr _ t1 e | Ret e0 subtape e1 sub e2 <- drev des e -> @@ -832,11 +750,12 @@ drev des = \case subtape (EInr ext (d1 t1) e1) sub - (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) + (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy (wSinks' @[_,_])) e2))) + (weakenExpr (WCopy (wSinks' @[_,_])) e2)) + (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) ECase _ e (a :: Ex _ t) b | STEither t1 t2 <- typeOf e @@ -907,7 +826,7 @@ drev des = \case (EInr ext (d2 t1) (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ ELet ext - (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ + (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $ weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ plus_AB_E (EFst ext (EVar ext tCaseRet (IS IZ))) @@ -986,16 +905,16 @@ drev des = \case (EVar ext shty IZ) (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) &. #sh (shty `SCons` SNil) - &. #d1env (sD1eEnv des) - &. #d1env' (sD1eEnv usedDes)) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) (#ix :++: #sh :++: #d1env)) e0)) $ let w = autoWeak (#ix (shty `SCons` SNil) &. #sh (shty `SCons` SNil) &. #e0 (bindingsBinds e0) - &. #d1env (sD1eEnv des) - &. #d1env' (sD1eEnv usedDes)) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) (#e0 :++: #ix :++: #sh :++: #d1env) in EPair ext (weakenExpr w e1) (collectexpr w))) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs new file mode 100644 index 0000000..e26f781 --- /dev/null +++ b/src/CHAD/Accum.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +module CHAD.Accum where + +import AST +import CHAD.Types +import Data + + + +hasArrays :: STy t' -> Bool +hasArrays STNil = False +hasArrays (STPair a b) = hasArrays a || hasArrays b +hasArrays (STEither a b) = hasArrays a || hasArrays b +hasArrays (STMaybe t) = hasArrays t +hasArrays STArr{} = True +hasArrays STScal{} = False +hasArrays STAccum{} = error "Accumulators not allowed in source program" + +makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators SNil e = e +makeAccumulators (t `SCons` envpro) e = + makeAccumulators envpro $ + EWith (EZero t) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs new file mode 100644 index 0000000..fcd91f7 --- /dev/null +++ b/src/CHAD/EnvDescr.hs @@ -0,0 +1,75 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.EnvDescr where + +import Data.Kind (Type) +import GHC.TypeLits (Symbol) + +import AST.Env +import AST.Types +import CHAD.Types +import Data + + +type Storage :: Symbol -> Type +data Storage s where + SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator + SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where + DTop :: Descr '[] '[] + DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des + +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' + -> (forall sto'. Descr env' sto' + -> Subenv (Select env sto "merge") (Select env' sto' "merge") + -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + -> Subenv (D1E env) (D1E env') + -> r) + -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) + SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) + SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) + SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) + +-- | Select only the types from the environment that have the specified storage +type family Select env sto s where + Select '[] '[] _ = '[] + Select (t : ts) (s : sto) s = t : Select ts sto s + Select (_ : ts) (_ : sto) s = Select ts sto s + +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SDiscr (DPush des (_, SAccum)) = select s des +select s@SAccum (DPush des (_, SMerge)) = select s des +select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, SMerge)) = select s des +select s@SAccum (DPush des (_, SDiscr)) = select s des +select s@SMerge (DPush des (_, SDiscr)) = select s des +select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) diff --git a/src/CHAD/Heuristics.hs b/src/CHAD/Heuristics.hs deleted file mode 100644 index 6ab8222..0000000 --- a/src/CHAD/Heuristics.hs +++ /dev/null @@ -1,14 +0,0 @@ -{-# LANGUAGE GADTs #-} -module CHAD.Heuristics where - -import AST - - -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = error "Accumulators not allowed in source program" diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 4b2a844..12594f2 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -1,13 +1,20 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Top where import AST +import AST.Weaken.Auto import CHAD +import CHAD.Accum +import CHAD.EnvDescr import CHAD.Types import Data @@ -28,6 +35,12 @@ mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env mergeEnvOnlyMerge SNil = Refl mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl +accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r +accumDescr SNil k = k DTop +accumDescr (t `SCons` env) k = accumDescr env $ \des -> + if hasArrays t then k (des `DPush` (t, SAccum)) + else k (des `DPush` (t, SMerge)) + d1Identity :: STy t -> D1 t :~: t d1Identity = \case STNil -> Refl @@ -42,9 +55,43 @@ d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl +reassembleD2E :: Descr env sto + -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) + -> Ex env' (Tup (D2E env)) +reassembleD2E DTop _ = ENil ext +reassembleD2E (des `DPush` (_, SAccum)) e = + ELet ext e $ + EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) + (ESnd ext (EVar ext (typeOf e) IZ)))) + (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) +reassembleD2E (des `DPush` (_, SMerge)) e = + ELet ext e $ + EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) + (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) + (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) +reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t) + chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) -chad config env term - | Refl <- mergeEnvNoAccum env +chad config env (term :: Ex env t) + | True <- chcArgArrayAccum config + = let ?config = config + in accumDescr env $ \descr -> + let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) + tvar = STPair t1 (tTup (d2e (select SAccum descr))) + in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ + makeAccumulators (select SAccum descr) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #acenv (d2ace (select SAccum descr)) + &. #tl (d1e env)) + (#d :++: #acenv :++: #tl) + (#acenv :++: #d :++: #tl)) $ + freezeRet descr (drev descr term)) $ + EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) + (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) + + | False <- chcArgArrayAccum config + , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 130493d..6662cbf 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -17,8 +17,8 @@ type family D1 t where type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) + D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) + D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t @@ -51,10 +51,14 @@ d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t d1 STAccum{} = error "Accumulators not allowed in input program" +d1e :: SList STy env -> SList STy (D1E env) +d1e SNil = SNil +d1e (t `SCons` env) = d1 t `SCons` d1e env + d2 :: STy t -> STy (D2 t) d2 STNil = STNil -d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) +d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) +d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) d2 (STMaybe t) = STMaybe (d2 t) d2 (STArr n t) = STArr n (d2 t) d2 (STScal t) = case t of @@ -67,7 +71,11 @@ d2 STAccum{} = error "Accumulators not allowed in input program" d2e :: SList STy env -> SList STy (D2E env) d2e SNil = SNil -d2e (SCons t ts) = SCons (d2 t) (d2e ts) +d2e (t `SCons` ts) = d2 t `SCons` d2e ts + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (t `SCons` ts) = STAccum (d2 t) `SCons` d2ace ts data CHADConfig = CHADConfig @@ -75,10 +83,18 @@ data CHADConfig = CHADConfig chcLetArrayAccum :: Bool , -- | D[case] will bind variables containing arrays in accumulator mode. chcCaseArrayAccum :: Bool + , -- | Introduce top-level arguments containing arrays in accumulator mode. + chcArgArrayAccum :: Bool } defaultConfig :: CHADConfig defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False + , chcArgArrayAccum = False } + +chcSetAccum :: CHADConfig -> CHADConfig +chcSetAccum c = c { chcLetArrayAccum = True + , chcCaseArrayAccum = True + , chcArgArrayAccum = True } diff --git a/src/Example.hs b/src/Example.hs index 94a6934..795229c 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -184,7 +184,7 @@ neuralGo = ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of - (primal', (((((), Right dlay1_1'), Right dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') + (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') _ -> undefined (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 in trace (formatter (ppExpr knownEnv revderiv)) $ diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 37d4a83..56ebf82 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -30,6 +30,7 @@ import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace +import GHC.Stack import Array import AST @@ -185,8 +186,8 @@ interpretOp op arg = case op of zeroD2 :: STy t -> Rep (D2 t) zeroD2 typ = case typ of STNil -> () - STPair _ _ -> Left () - STEither _ _ -> Left () + STPair _ _ -> Nothing + STEither _ _ -> Nothing STMaybe _ -> Nothing STArr SZ t -> arrayUnit (zeroD2 t) STArr n _ -> emptyArray n @@ -202,14 +203,14 @@ addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) addD2s typ a b = case typ of STNil -> () STPair t1 t2 -> case (a, b) of - (Left (), _) -> b - (_, Left ()) -> a - (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2) + (Nothing, _) -> b + (_, Nothing) -> a + (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) STEither t1 t2 -> case (a, b) of - (Left (), _) -> b - (_, Left ()) -> a - (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y)) - (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y)) + (Nothing, _) -> b + (_, Nothing) -> a + (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) _ -> error "Plus of inconsistent Eithers" STMaybe t -> case (a, b) of (Nothing, _) -> b @@ -233,16 +234,14 @@ addD2s typ a b = case typ of onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t) onehotD2 SZ _ () v = v onehotD2 _ STNil _ _ = () -onehotD2 (SS _ ) (STPair _ _ ) (Left _ ) (Left _ ) = Left () -onehotD2 (SS SZ ) (STPair _ _ ) (Right () ) (Right val ) = Right val -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left idx)) (Right (Left val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" -onehotD2 (SS _ ) (STEither _ _ ) (Left _ ) (Left _ ) = Left () -onehotD2 (SS SZ ) (STEither _ _ ) (Right () ) (Right val ) = Right val -onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left idx)) (Right (Left val)) = Right (Left (onehotD2 i t1 idx val)) -onehotD2 (SS (SS i)) (STEither _ t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val)) -onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value" +onehotD2 (SS SZ ) (STPair _ _ ) () val = Just val +onehotD2 (SS (SS i)) (STPair t1 t2) (Left idx) (Left val) = Just (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value" +onehotD2 (SS SZ ) (STEither _ _ ) () val = Just val +onehotD2 (SS (SS i)) (STEither t1 _ ) (Left idx) (Left val) = Just (Left (onehotD2 i t1 idx val)) +onehotD2 (SS (SS i)) (STEither _ t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val)) +onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value" onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val) onehotD2 (SS i ) (STArr n t) idx val = runIdentity $ onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val @@ -251,6 +250,7 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do + putStrLn $ "withAccum: " ++ show t accum <- newAcSparse t SZ () initval out <- case f accum of AcM m -> m val <- readAcSparse t accum @@ -324,7 +324,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n piindexConcat PIIxEnd ix = ix piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of STNil -> return () STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) @@ -372,7 +372,7 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do go mk dep' dim' idx' val' $ \arr pish -> k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) -newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) newAcDense typ SZ () val = case typ of STEither t1 t2 -> case val of Left x -> Left <$> newAcSparse t1 SZ () x diff --git a/test/Main.hs b/test/Main.hs index 7cb15d5..d18884e 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,7 +24,6 @@ import Hedgehog.Main import Array import AST import AST.Pretty -import CHAD (defaultConfig) import CHAD.Top import CHAD.Types import qualified Example @@ -238,7 +237,7 @@ term_pairs = fromNamed $ lambda #x $ lambda #y $ body $ fst_ #q * #x + snd_ #q * fst_ #p tests :: IO Bool -tests = checkSequential $ Group "AD" +tests = checkParallel $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x) |