summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:57 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-14 19:27:57 +0100
commitb8c162ce9cb1faeec621b751fff9aff46e022417 (patch)
tree9c31700f34f9a1f1a67e0a73c880938130e87ee6
parentbb84f6930702a02ba982795e2bb95a64d61f672b (diff)
Configuration for CHAD
-rw-r--r--bench/Main.hs3
-rw-r--r--src/CHAD.hs65
-rw-r--r--src/CHAD/Top.hs13
-rw-r--r--src/Example.hs2
-rw-r--r--test/Main.hs3
5 files changed, 76 insertions, 10 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 5bb81ac..32fbc8c 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -19,6 +19,7 @@ import GHC.Exts (withDict)
import AST
import Array
+import qualified CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import Data
@@ -34,7 +35,7 @@ gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) ->
gradCHAD input ctg term =
interpretOpen False input $
simplifyFix $
- ELet ext (EConst ext STF64 ctg) $ chad' knownEnv term
+ ELet ext (EConst ext STF64 ctg) $ chad' CHAD.defaultConfig knownEnv term
instance KnownTy t => NFData (Value t) where
rnf = \(Value x) -> go (knownTy @t) x
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 45fcc82..b35836a 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -23,6 +24,8 @@
module CHAD (
drev,
freezeRet,
+ CHADConfig(..),
+ defaultConfig,
Storage(..),
Descr(..),
Select,
@@ -724,10 +727,27 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
+--------------------------------- CONFIGURATION --------------------------------
+
+data CHADConfig = CHADConfig
+ { -- | D[let] will bind variables containing arrays in accumulator mode.
+ chcLetArrayAccum :: Bool
+ , -- | D[case] will bind variables containing arrays in accumulator mode.
+ chcCaseArrayAccum :: Bool
+ }
+
+defaultConfig :: CHADConfig
+defaultConfig = CHADConfig
+ { chcLetArrayAccum = False
+ , chcCaseArrayAccum = False
+ }
+
+
---------------------------- THE CHAD TRANSFORMATION ---------------------------
drev :: forall env sto t.
- Descr env sto
+ (?config :: CHADConfig)
+ => Descr env sto
-> Ex env t -> Ret env sto t
drev des = \case
EVar _ t i ->
@@ -753,6 +773,38 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
+ ELet _ (rhs :: Ex _ a) body
+ | chcLetArrayAccum ?config && hasArrays (typeOf rhs)
+ , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
+ <- drev des rhs
+ , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2
+ <- drev (des `DPush` (typeOf rhs, SAccum)) body
+ , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
+ , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) ->
+ subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body ->
+ let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in
+ Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
+ (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody)
+ (weakenExpr wbody0' body1)
+ subBoth
+ (ELet ext
+ (EWith (EZero (typeOf rhs)) $
+ weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #body (subList (bindingsBinds body0) subtapeBody)
+ &. #ac (auto1 @(TAccum (D2 a)))
+ &. #rhs (subList (bindingsBinds rhs0) subtapeRHS)
+ &. #tl (d2ace (select SAccum des)))
+ (#d :++: #body :++: #ac :++: #tl)
+ (#ac :++: #d :++: (#body :++: #rhs) :++: #tl))
+ body2) $
+ ELet ext
+ (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $
+ plus_RHS_Body
+ (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EFst ext (EVar ext bodyResType (IS IZ))))
+
ELet _ rhs body
| Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2
<- drev des rhs
@@ -848,6 +900,8 @@ drev des = \case
(EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
+ ECase{} | chcCaseArrayAccum ?config -> error "chcCaseArrayAccum unsupported"
+
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
, Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e
@@ -1187,3 +1241,12 @@ drev des = \case
(EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))
(EZero t)) $
weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+
+ hasArrays :: STy t' -> Bool
+ hasArrays STNil = False
+ hasArrays (STPair a b) = hasArrays a || hasArrays b
+ hasArrays (STEither a b) = hasArrays a || hasArrays b
+ hasArrays (STMaybe t) = hasArrays t
+ hasArrays STArr{} = True
+ hasArrays STScal{} = False
+ hasArrays STAccum{} = error "Accumulators not allowed in source program"
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 9df5412..4b2a844 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -41,13 +42,13 @@ d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
-chad :: SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
-chad env term
+chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
+chad config env term
| Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = freezeRet (mergeDescr env) (drev (mergeDescr env) term)
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term)
-chad' :: SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
-chad' env term
+chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+chad' config env term
| Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term)
- = chad env term
+ = chad config env term
diff --git a/src/Example.hs b/src/Example.hs
index 390031e..94a6934 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -182,7 +182,7 @@ neuralGo =
revderiv =
simplifyN 20 $
ELet ext (EConst ext STF64 1.0) $
- chad knownEnv neural
+ chad defaultConfig knownEnv neural
(primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of
(primal', (((((), Right dlay1_1'), Right dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')
_ -> undefined
diff --git a/test/Main.hs b/test/Main.hs
index d617228..7cb15d5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -24,6 +24,7 @@ import Hedgehog.Main
import Array
import AST
import AST.Pretty
+import CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import qualified Example
@@ -41,7 +42,7 @@ data SimplIters = SimplIters Int | SimplFix
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD = \simplIters env term input ->
- let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' env term
+ let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
dterm | Dict <- envKnown env
= case simplIters of
SimplIters n -> simplifyN n dtermNonSimpl