diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:57 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:57 +0100 |
commit | b8c162ce9cb1faeec621b751fff9aff46e022417 (patch) | |
tree | 9c31700f34f9a1f1a67e0a73c880938130e87ee6 | |
parent | bb84f6930702a02ba982795e2bb95a64d61f672b (diff) |
Configuration for CHAD
-rw-r--r-- | bench/Main.hs | 3 | ||||
-rw-r--r-- | src/CHAD.hs | 65 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 13 | ||||
-rw-r--r-- | src/Example.hs | 2 | ||||
-rw-r--r-- | test/Main.hs | 3 |
5 files changed, 76 insertions, 10 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 5bb81ac..32fbc8c 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -19,6 +19,7 @@ import GHC.Exts (withDict) import AST import Array +import qualified CHAD (defaultConfig) import CHAD.Top import CHAD.Types import Data @@ -34,7 +35,7 @@ gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> gradCHAD input ctg term = interpretOpen False input $ simplifyFix $ - ELet ext (EConst ext STF64 ctg) $ chad' knownEnv term + ELet ext (EConst ext STF64 ctg) $ chad' CHAD.defaultConfig knownEnv term instance KnownTy t => NFData (Value t) where rnf = \(Value x) -> go (knownTy @t) x diff --git a/src/CHAD.hs b/src/CHAD.hs index 45fcc82..b35836a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -2,6 +2,7 @@ {-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} @@ -23,6 +24,8 @@ module CHAD ( drev, freezeRet, + CHADConfig(..), + defaultConfig, Storage(..), Descr(..), Select, @@ -724,10 +727,27 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) +--------------------------------- CONFIGURATION -------------------------------- + +data CHADConfig = CHADConfig + { -- | D[let] will bind variables containing arrays in accumulator mode. + chcLetArrayAccum :: Bool + , -- | D[case] will bind variables containing arrays in accumulator mode. + chcCaseArrayAccum :: Bool + } + +defaultConfig :: CHADConfig +defaultConfig = CHADConfig + { chcLetArrayAccum = False + , chcCaseArrayAccum = False + } + + ---------------------------- THE CHAD TRANSFORMATION --------------------------- drev :: forall env sto t. - Descr env sto + (?config :: CHADConfig) + => Descr env sto -> Ex env t -> Ret env sto t drev des = \case EVar _ t i -> @@ -753,6 +773,38 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) + ELet _ (rhs :: Ex _ a) body + | chcLetArrayAccum ?config && hasArrays (typeOf rhs) + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 + <- drev des rhs + , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 + <- drev (des `DPush` (typeOf rhs, SAccum)) body + , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> + subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in + Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') + (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) + (weakenExpr wbody0' body1) + subBoth + (ELet ext + (EWith (EZero (typeOf rhs)) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #body (subList (bindingsBinds body0) subtapeBody) + &. #ac (auto1 @(TAccum (D2 a))) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: (#body :++: #rhs) :++: #tl)) + body2) $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ + plus_RHS_Body + (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) + ELet _ rhs body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs @@ -848,6 +900,8 @@ drev des = \case (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") (weakenExpr (WCopy (wSinks' @[_,_])) e2))) + ECase{} | chcCaseArrayAccum ?config -> error "chcCaseArrayAccum unsupported" + ECase _ e (a :: Ex _ t) b | STEither t1 t2 <- typeOf e , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e @@ -1187,3 +1241,12 @@ drev des = \case (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) (EZero t)) $ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + + hasArrays :: STy t' -> Bool + hasArrays STNil = False + hasArrays (STPair a b) = hasArrays a || hasArrays b + hasArrays (STEither a b) = hasArrays a || hasArrays b + hasArrays (STMaybe t) = hasArrays t + hasArrays STArr{} = True + hasArrays STScal{} = False + hasArrays STAccum{} = error "Accumulators not allowed in source program" diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 9df5412..4b2a844 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -41,13 +42,13 @@ d1eIdentity :: SList STy env -> D1E env :~: env d1eIdentity SNil = Refl d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl -chad :: SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) -chad env term +chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) +chad config env term | Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = freezeRet (mergeDescr env) (drev (mergeDescr env) term) + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) -chad' :: SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) -chad' env term +chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +chad' config env term | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term) - = chad env term + = chad config env term diff --git a/src/Example.hs b/src/Example.hs index 390031e..94a6934 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -182,7 +182,7 @@ neuralGo = revderiv = simplifyN 20 $ ELet ext (EConst ext STF64 1.0) $ - chad knownEnv neural + chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of (primal', (((((), Right dlay1_1'), Right dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1') _ -> undefined diff --git a/test/Main.hs b/test/Main.hs index d617228..7cb15d5 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,6 +24,7 @@ import Hedgehog.Main import Array import AST import AST.Pretty +import CHAD (defaultConfig) import CHAD.Top import CHAD.Types import qualified Example @@ -41,7 +42,7 @@ data SimplIters = SimplIters Int | SimplFix -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD = \simplIters env term input -> - let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' env term + let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term dterm | Dict <- envKnown env = case simplIters of SimplIters n -> simplifyN n dtermNonSimpl |