diff options
| -rw-r--r-- | bench/Main.hs | 5 | ||||
| -rw-r--r-- | chad-fast.cabal | 3 | ||||
| -rw-r--r-- | src/CHAD.hs | 149 | ||||
| -rw-r--r-- | src/CHAD/Accum.hs | 36 | ||||
| -rw-r--r-- | src/CHAD/EnvDescr.hs | 75 | ||||
| -rw-r--r-- | src/CHAD/Heuristics.hs | 14 | ||||
| -rw-r--r-- | src/CHAD/Top.hs | 53 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 26 | ||||
| -rw-r--r-- | src/Example.hs | 2 | ||||
| -rw-r--r-- | src/Interpreter.hs | 42 | ||||
| -rw-r--r-- | test/Main.hs | 3 | 
11 files changed, 242 insertions, 166 deletions
| diff --git a/bench/Main.hs b/bench/Main.hs index 932da9d..5d2cb5a 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -20,7 +20,6 @@ import GHC.Exts (withDict)  import AST  import Array  import qualified CHAD (defaultConfig) -import CHAD (CHADConfig(..))  import CHAD.Top  import CHAD.Types  import Data @@ -112,9 +111,7 @@ makeGMMInputs =       SNil  accumConfig :: CHADConfig -accumConfig = CHADConfig -  { chcLetArrayAccum = True -  , chcCaseArrayAccum = True } +accumConfig = chcSetAccum CHAD.defaultConfig  main :: IO ()  main = defaultMain diff --git a/chad-fast.cabal b/chad-fast.cabal index 6c60a31..893c92f 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -20,7 +20,8 @@ library      AST.Weaken      AST.Weaken.Auto      CHAD -    CHAD.Heuristics +    CHAD.Accum +    CHAD.EnvDescr      CHAD.Top      CHAD.Types      -- Compile diff --git a/src/CHAD.hs b/src/CHAD.hs index 9f58f73..3da083b 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -32,18 +32,17 @@ module CHAD (  ) where  import Data.Functor.Const -import Data.Kind (Type)  import Data.Type.Bool (If)  import Data.Type.Equality (type (==))  import GHC.Stack (HasCallStack) -import GHC.TypeLits (Symbol)  import AST  import AST.Bindings  import AST.Count  import AST.Env  import AST.Weaken.Auto -import CHAD.Heuristics +import CHAD.Accum +import CHAD.EnvDescr  import CHAD.Types  import Data  import Lemmas @@ -197,66 +196,6 @@ reconstructBindings binds tape =       ,sreverse (stapeUnfoldings binds)) ----------------------- ENVIRONMENT DESCRIPTION AND STORAGE --------------------- - -type Storage :: Symbol -> Type -data Storage s where -  SAccum :: Storage "accum"  -- ^ in the monad state as a mutable accumulator -  SMerge :: Storage "merge"  -- ^ just return and merge -  SDiscr :: Storage "discr"  -- ^ we happen to know this is a discrete type and won't need any contributions -deriving instance Show (Storage s) - --- | Environment description -data Descr env sto where -  DTop :: Descr '[] '[] -  DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) -deriving instance Show (Descr env sto) - -descrList :: Descr env sto -> SList STy env -descrList DTop = SNil -descrList (des `DPush` (t, _)) = t `SCons` descrList des - --- | This could have more precise typing on the output storage. -subDescr :: Descr env sto -> Subenv env env' -         -> (forall sto'. Descr env' sto' -                       -> Subenv (Select env sto "merge") (Select env' sto' "merge") -                       -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) -                       -> Subenv (D1E env) (D1E env') -                       -> r) -         -> r -subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, sto)) (SEYes sub) k = -  subDescr des sub $ \des' submerge subaccum subd1e -> -    case sto of -      SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) -      SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) -      SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) -subDescr (des `DPush` (_, sto)) (SENo sub) k = -  subDescr des sub $ \des' submerge subaccum subd1e -> -    case sto of -      SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) -      SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) -      SDiscr -> k des' submerge subaccum (SENo subd1e) - --- | Select only the types from the environment that have the specified storage -type family Select env sto s where -  Select '[] '[] _ = '[] -  Select (t : ts) (s : sto) s = t : Select ts sto s -  Select (_ : ts) (_ : sto) s = Select ts sto s - -select :: Storage s -> Descr env sto -> SList STy (Select env sto s) -select _ DTop = SNil -select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, SAccum)) = select s des -select s@SDiscr (DPush des (_, SAccum)) = select s des -select s@SAccum (DPush des (_, SMerge)) = select s des -select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) -select s@SDiscr (DPush des (_, SMerge)) = select s des -select s@SAccum (DPush des (_, SDiscr)) = select s des -select s@SMerge (DPush des (_, SDiscr)) = select s des -select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) - -  ---------------------------------- DERIVATIVES ---------------------------------  d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) @@ -283,24 +222,24 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))  d2op :: SOp a t -> D2Op a t  d2op op = case op of -  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) +  OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)    OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> -    EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) -                              (EOp ext (OMul t) (EPair ext (EFst ext e) d))) +    EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) +                         (EOp ext (OMul t) (EPair ext (EFst ext e) d)))    ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d -  OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) +  OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) +  OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) +  OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))    ONot -> Linear $ \_ -> ENil ext -  OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -  OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +  OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +  OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)    OIf -> Linear $ \_ -> ENil ext    ORound64 -> Linear $ \_ -> EConst ext STF64 0.0    OToFl64 -> Linear $ \_ -> ENil ext    ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)    OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)    OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) -  OIDiv t -> integralD2 t $ Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +  OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)    where      d2opUnArrangeInt :: SScalTy a                       -> (D2s a ~ TScal a => D2Op (TScal a) t) @@ -316,11 +255,11 @@ d2op op = case op of                        -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)                        -> D2Op (TPair (TScal a) (TScal a)) t      d2opBinArrangeInt ty float = case ty of -      STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) -      STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +      STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) +      STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)        STF32 -> float        STF64 -> float -      STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +      STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)      floatingD2 :: ScalIsFloating a ~ True                 => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -332,13 +271,8 @@ d2op op = case op of      integralD2 STI32 k = k      integralD2 STI64 k = k -sD1eEnv :: Descr env sto -> SList STy (D1E env) -sD1eEnv DTop = SNil -sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) - -d2ace :: SList STy env -> SList STy (D2AcE env) -d2ace SNil = SNil -d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) +desD1E :: Descr env sto -> SList STy (D1E env) +desD1E = d1e . descrList  -- d1W :: env :> env' -> D1E env :> D1E env'  -- d1W WId = WId @@ -610,24 +544,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =      --   STScal{} -> False      --   STAccum{} -> error "An accumulator in merge storage?" -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (STArr n t `SCons` envpro) e = -  makeAccumulators envpro $ -    EWith (zero (STArr n t)) e -makeAccumulators (t `SCons` _) _ = error $ "makeAccumulators: Not only arrays in envpro: " ++ show t - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = -  ELet ext (uninvertTup list (STPair tcore t) e) $ -    let recT = STPair (STPair tcore t) (tTup list)  -- type of the RHS of that let binding -    in EPair ext -         (EFst ext (EFst ext (EVar ext recT IZ))) -         (EPair ext -            (ESnd ext (EVar ext recT IZ)) -            (ESnd ext (EFst ext (EVar ext recT IZ)))) -  ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- @@ -701,7 +617,7 @@ freezeRet :: Descr env sto            -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))  freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =    let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 -      e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 +      e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2    in letBinds e0' $         EPair ext           (weakenExpr wInsertD2Ac e1) @@ -709,7 +625,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =                                            &. #tape (subList (bindingsBinds e0) subtape)                                            &. #shbinds (bindingsBinds e0)                                            &. #d2ace (d2ace (select SAccum descr)) -                                          &. #tl (sD1eEnv descr)) +                                          &. #tl (desD1E descr))                                           (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)                                           (#shbinds :++: #d :++: #d2ace :++: #tl))                        e2') $ @@ -782,7 +698,7 @@ drev des = \case          subtape          (EPair ext a1 b1)          subBoth -        (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) +        (EMaybe ext             (zeroTup (subList (select SMerge des) subBoth))             (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))                        (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ @@ -790,7 +706,8 @@ drev des = \case                        (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $              plus_A_B               (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) -             (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))) +             (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) +           (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))    EFst _ e      | Ret e0 subtape e1 sub e2 <- drev des e @@ -799,7 +716,7 @@ drev des = \case          subtape          (EFst ext e1)          sub -        (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ +        (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $             weakenExpr (WCopy WSink) e2)    ESnd _ e @@ -809,7 +726,7 @@ drev des = \case          subtape          (ESnd ext e1)          sub -        (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ +        (ELet ext (EJust ext (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $             weakenExpr (WCopy WSink) e2)    ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) @@ -820,11 +737,12 @@ drev des = \case          subtape          (EInl ext (d1 t2) e1)          sub -        (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) +        (EMaybe ext             (zeroTup (subList (select SMerge des) sub))             (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)                (weakenExpr (WCopy (wSinks' @[_,_])) e2) -              (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))) +              (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) +           (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))    EInr _ t1 e      | Ret e0 subtape e1 sub e2 <- drev des e -> @@ -832,11 +750,12 @@ drev des = \case          subtape          (EInr ext (d1 t1) e1)          sub -        (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) +        (EMaybe ext             (zeroTup (subList (select SMerge des) sub))             (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)                (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") -              (weakenExpr (WCopy (wSinks' @[_,_])) e2))) +              (weakenExpr (WCopy (wSinks' @[_,_])) e2)) +           (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ))    ECase _ e (a :: Ex _ t) b      | STEither t1 t2 <- typeOf e @@ -907,7 +826,7 @@ drev des = \case                        (EInr ext (d2 t1)                          (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $           ELet ext -           (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ +           (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $                weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $           plus_AB_E             (EFst ext (EVar ext tCaseRet (IS IZ))) @@ -986,16 +905,16 @@ drev des = \case                           (EVar ext shty IZ)                           (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)                                                                                &. #sh (shty `SCons` SNil) -                                                                              &. #d1env (sD1eEnv des) -                                                                              &. #d1env' (sD1eEnv usedDes)) +                                                                              &. #d1env (desD1E des) +                                                                              &. #d1env' (desD1E usedDes))                                                                               (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))                                                                               (#ix :++: #sh :++: #d1env))                                                                     e0)) $                              let w = autoWeak (#ix (shty `SCons` SNil)                                                     &. #sh (shty `SCons` SNil)                                                     &. #e0 (bindingsBinds e0) -                                                   &. #d1env (sD1eEnv des) -                                                   &. #d1env' (sD1eEnv usedDes)) +                                                   &. #d1env (desD1E des) +                                                   &. #d1env' (desD1E usedDes))                                                    (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))                                                    (#e0 :++: #ix :++: #sh :++: #d1env)                              in EPair ext (weakenExpr w e1) (collectexpr w))) diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs new file mode 100644 index 0000000..e26f781 --- /dev/null +++ b/src/CHAD/Accum.hs @@ -0,0 +1,36 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +module CHAD.Accum where + +import AST +import CHAD.Types +import Data + + + +hasArrays :: STy t' -> Bool +hasArrays STNil = False +hasArrays (STPair a b) = hasArrays a || hasArrays b +hasArrays (STEither a b) = hasArrays a || hasArrays b +hasArrays (STMaybe t) = hasArrays t +hasArrays STArr{} = True +hasArrays STScal{} = False +hasArrays STAccum{} = error "Accumulators not allowed in source program" + +makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators SNil e = e +makeAccumulators (t `SCons` envpro) e = +  makeAccumulators envpro $ +    EWith (EZero t) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = +  ELet ext (uninvertTup list (STPair tcore t) e) $ +    let recT = STPair (STPair tcore t) (tTup list)  -- type of the RHS of that let binding +    in EPair ext +         (EFst ext (EFst ext (EVar ext recT IZ))) +         (EPair ext +            (ESnd ext (EVar ext recT IZ)) +            (ESnd ext (EFst ext (EVar ext recT IZ)))) + diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs new file mode 100644 index 0000000..fcd91f7 --- /dev/null +++ b/src/CHAD/EnvDescr.hs @@ -0,0 +1,75 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.EnvDescr where + +import Data.Kind (Type) +import GHC.TypeLits (Symbol) + +import AST.Env +import AST.Types +import CHAD.Types +import Data + + +type Storage :: Symbol -> Type +data Storage s where +  SAccum :: Storage "accum"  -- ^ in the monad state as a mutable accumulator +  SMerge :: Storage "merge"  -- ^ just return and merge +  SDiscr :: Storage "discr"  -- ^ we happen to know this is a discrete type and won't need any contributions +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where +  DTop :: Descr '[] '[] +  DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des + +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' +         -> (forall sto'. Descr env' sto' +                       -> Subenv (Select env sto "merge") (Select env' sto' "merge") +                       -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) +                       -> Subenv (D1E env) (D1E env') +                       -> r) +         -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = +  subDescr des sub $ \des' submerge subaccum subd1e -> +    case sto of +      SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) +      SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) +      SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = +  subDescr des sub $ \des' submerge subaccum subd1e -> +    case sto of +      SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) +      SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) +      SDiscr -> k des' submerge subaccum (SENo subd1e) + +-- | Select only the types from the environment that have the specified storage +type family Select env sto s where +  Select '[] '[] _ = '[] +  Select (t : ts) (s : sto) s = t : Select ts sto s +  Select (_ : ts) (_ : sto) s = Select ts sto s + +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SDiscr (DPush des (_, SAccum)) = select s des +select s@SAccum (DPush des (_, SMerge)) = select s des +select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, SMerge)) = select s des +select s@SAccum (DPush des (_, SDiscr)) = select s des +select s@SMerge (DPush des (_, SDiscr)) = select s des +select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) diff --git a/src/CHAD/Heuristics.hs b/src/CHAD/Heuristics.hs deleted file mode 100644 index 6ab8222..0000000 --- a/src/CHAD/Heuristics.hs +++ /dev/null @@ -1,14 +0,0 @@ -{-# LANGUAGE GADTs #-} -module CHAD.Heuristics where - -import AST - - -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = error "Accumulators not allowed in source program" diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 4b2a844..12594f2 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -1,13 +1,20 @@  {-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-}  {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  module CHAD.Top where  import AST +import AST.Weaken.Auto  import CHAD +import CHAD.Accum +import CHAD.EnvDescr  import CHAD.Types  import Data @@ -28,6 +35,12 @@ mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env  mergeEnvOnlyMerge SNil = Refl  mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl +accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r +accumDescr SNil k = k DTop +accumDescr (t `SCons` env) k = accumDescr env $ \des -> +  if hasArrays t then k (des `DPush` (t, SAccum)) +                 else k (des `DPush` (t, SMerge)) +  d1Identity :: STy t -> D1 t :~: t  d1Identity = \case    STNil -> Refl @@ -42,9 +55,43 @@ d1eIdentity :: SList STy env -> D1E env :~: env  d1eIdentity SNil = Refl  d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl +reassembleD2E :: Descr env sto +              -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) +              -> Ex env' (Tup (D2E env)) +reassembleD2E DTop _ = ENil ext +reassembleD2E (des `DPush` (_, SAccum)) e = +  ELet ext e $ +    EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) +                                            (ESnd ext (EVar ext (typeOf e) IZ)))) +              (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) +reassembleD2E (des `DPush` (_, SMerge)) e = +  ELet ext e $ +    EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) +                                            (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) +              (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) +reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero t) +  chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) -chad config env term -  | Refl <- mergeEnvNoAccum env +chad config env (term :: Ex env t) +  | True <- chcArgArrayAccum config +  = let ?config = config +    in accumDescr env $ \descr -> +         let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) +             tvar = STPair t1 (tTup (d2e (select SAccum descr))) +         in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ +                        makeAccumulators (select SAccum descr) $ +                          weakenExpr (autoWeak (#d (auto1 @(D2 t)) +                                                &. #acenv (d2ace (select SAccum descr)) +                                                &. #tl (d1e env)) +                                               (#d :++: #acenv :++: #tl) +                                               (#acenv :++: #d :++: #tl)) $ +                            freezeRet descr (drev descr term)) $ +              EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) +                        (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) +                                                        (ESnd ext (EFst ext (EVar ext tvar IZ))))) + +  | False <- chcArgArrayAccum config +  , Refl <- mergeEnvNoAccum env    , Refl <- mergeEnvOnlyMerge env    = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 130493d..6662cbf 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -17,8 +17,8 @@ type family D1 t where  type family D2 t where    D2 TNil = TNil -  D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) -  D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) +  D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) +  D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))    D2 (TMaybe t) = TMaybe (D2 t)    D2 (TArr n t) = TArr n (D2 t)    D2 (TScal t) = D2s t @@ -51,10 +51,14 @@ d1 (STArr n t) = STArr n (d1 t)  d1 (STScal t) = STScal t  d1 STAccum{} = error "Accumulators not allowed in input program" +d1e :: SList STy env -> SList STy (D1E env) +d1e SNil = SNil +d1e (t `SCons` env) = d1 t `SCons` d1e env +  d2 :: STy t -> STy (D2 t)  d2 STNil = STNil -d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) +d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) +d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))  d2 (STMaybe t) = STMaybe (d2 t)  d2 (STArr n t) = STArr n (d2 t)  d2 (STScal t) = case t of @@ -67,7 +71,11 @@ d2 STAccum{} = error "Accumulators not allowed in input program"  d2e :: SList STy env -> SList STy (D2E env)  d2e SNil = SNil -d2e (SCons t ts) = SCons (d2 t) (d2e ts) +d2e (t `SCons` ts) = d2 t `SCons` d2e ts + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (t `SCons` ts) = STAccum (d2 t) `SCons` d2ace ts  data CHADConfig = CHADConfig @@ -75,10 +83,18 @@ data CHADConfig = CHADConfig      chcLetArrayAccum :: Bool    , -- | D[case] will bind variables containing arrays in accumulator mode.      chcCaseArrayAccum :: Bool +  , -- | Introduce top-level arguments containing arrays in accumulator mode. +    chcArgArrayAccum :: Bool    }  defaultConfig :: CHADConfig  defaultConfig = CHADConfig    { chcLetArrayAccum = False    , chcCaseArrayAccum = False +  , chcArgArrayAccum = False    } + +chcSetAccum :: CHADConfig -> CHADConfig +chcSetAccum c = c { chcLetArrayAccum = True +                  , chcCaseArrayAccum = True +                  , chcArgArrayAccum = True } diff --git a/src/Example.hs b/src/Example.hs index 94a6934..795229c 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -184,7 +184,7 @@ neuralGo =            ELet ext (EConst ext STF64 1.0) $              chad defaultConfig knownEnv neural        (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of -        (primal', (((((), Right dlay1_1'), Right dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') +        (primal', (((((), Just dlay1_1'), Just dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')          _ -> undefined        (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0    in trace (formatter (ppExpr knownEnv revderiv)) $ diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 37d4a83..56ebf82 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -30,6 +30,7 @@ import System.IO (hPutStrLn, stderr)  import System.IO.Unsafe (unsafePerformIO)  import Debug.Trace +import GHC.Stack  import Array  import AST @@ -185,8 +186,8 @@ interpretOp op arg = case op of  zeroD2 :: STy t -> Rep (D2 t)  zeroD2 typ = case typ of    STNil -> () -  STPair _ _ -> Left () -  STEither _ _ -> Left () +  STPair _ _ -> Nothing +  STEither _ _ -> Nothing    STMaybe _ -> Nothing    STArr SZ t -> arrayUnit (zeroD2 t)    STArr n _ -> emptyArray n @@ -202,14 +203,14 @@ addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)  addD2s typ a b = case typ of    STNil -> ()    STPair t1 t2 -> case (a, b) of -    (Left (), _) -> b -    (_, Left ()) -> a -    (Right (x1, x2), Right (y1, y2)) -> Right (addD2s t1 x1 y1, addD2s t2 x2 y2) +    (Nothing, _) -> b +    (_, Nothing) -> a +    (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2)    STEither t1 t2 -> case (a, b) of -    (Left (), _) -> b -    (_, Left ()) -> a -    (Right (Left x), Right (Left y)) -> Right (Left (addD2s t1 x y)) -    (Right (Right x), Right (Right y)) -> Right (Right (addD2s t2 x y)) +    (Nothing, _) -> b +    (_, Nothing) -> a +    (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) +    (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y))      _ -> error "Plus of inconsistent Eithers"    STMaybe t -> case (a, b) of      (Nothing, _) -> b @@ -233,16 +234,14 @@ addD2s typ a b = case typ of  onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)  onehotD2 SZ _ () v = v  onehotD2 _ STNil _ _ = () -onehotD2 (SS _     ) (STPair _  _ ) (Left  _          ) (Left  _          ) = Left () -onehotD2 (SS SZ    ) (STPair _  _ ) (Right ()         ) (Right val        ) = Right val -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Left  idx)) (Right (Left  val)) = Right (onehotD2 i t1 idx val, zeroD2 t2) -onehotD2 (SS (SS i)) (STPair t1 t2) (Right (Right idx)) (Right (Right val)) = Right (zeroD2 t1, onehotD2 i t2 idx val) -onehotD2 (SS _     ) (STPair _  _ ) _                   _                   = error "onehotD2: pair: mismatched index and value" -onehotD2 (SS _     ) (STEither _  _ ) (Left  _          ) (Left  _          ) = Left () -onehotD2 (SS SZ    ) (STEither _  _ ) (Right ()         ) (Right val        ) = Right val -onehotD2 (SS (SS i)) (STEither t1 _ ) (Right (Left  idx)) (Right (Left  val)) = Right (Left (onehotD2 i t1 idx val)) -onehotD2 (SS (SS i)) (STEither _  t2) (Right (Right idx)) (Right (Right val)) = Right (Right (onehotD2 i t2 idx val)) -onehotD2 (SS _     ) (STEither _  _ ) _                   _                   = error "onehotD2: either: mismatched index and value" +onehotD2 (SS SZ    ) (STPair _  _ ) ()          val         = Just val +onehotD2 (SS (SS i)) (STPair t1 t2) (Left  idx) (Left  val) = Just (onehotD2 i t1 idx val, zeroD2 t2) +onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val) +onehotD2 (SS _     ) (STPair _  _ ) _           _           = error "onehotD2: pair: mismatched index and value" +onehotD2 (SS SZ    ) (STEither _  _ ) ()          val         = Just val +onehotD2 (SS (SS i)) (STEither t1 _ ) (Left  idx) (Left  val) = Just (Left (onehotD2 i t1 idx val)) +onehotD2 (SS (SS i)) (STEither _  t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val)) +onehotD2 (SS _     ) (STEither _  _ ) _           _           = error "onehotD2: either: mismatched index and value"  onehotD2 (SS i     ) (STMaybe t) idx val = Just (onehotD2 i t idx val)  onehotD2 (SS i     ) (STArr n t) idx val = runIdentity $    onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val @@ -251,6 +250,7 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"  withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)  withAccum t _ initval f = AcM $ do +  putStrLn $ "withAccum: " ++ show t    accum <- newAcSparse t SZ () initval    out <- case f accum of AcM m -> m    val <- readAcSparse t accum @@ -324,7 +324,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n  piindexConcat PIIxEnd ix = ix  piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)  newAcSparse typ SZ () val = case typ of    STNil -> return ()    STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) @@ -372,7 +372,7 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do        go mk dep' dim' idx' val' $ \arr pish ->          k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) -newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)  newAcDense typ SZ () val = case typ of    STEither t1 t2 -> case val of      Left x -> Left <$> newAcSparse t1 SZ () x diff --git a/test/Main.hs b/test/Main.hs index 7cb15d5..d18884e 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,7 +24,6 @@ import Hedgehog.Main  import Array  import AST  import AST.Pretty -import CHAD (defaultConfig)  import CHAD.Top  import CHAD.Types  import qualified Example @@ -238,7 +237,7 @@ term_pairs = fromNamed $ lambda #x $ lambda #y $ body $      fst_ #q * #x + snd_ #q * fst_ #p  tests :: IO Bool -tests = checkSequential $ Group "AD" +tests = checkParallel $ Group "AD"    [("id", adTest $ fromNamed $ lambda #x $ body $ #x)    ,("idx0", adTest $ fromNamed $ lambda #x $ body $ idx0 #x) | 
