summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-23 12:10:53 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-23 12:10:53 +0100
commit84f6845803511e24770fbf1dffc6a9a007371edf (patch)
treec5ad97e68ceb6a39149aed6ee0aa0bf8102d3d60
parente8663e189c41637d348ce100cdab40e8d19ed62c (diff)
Benchmark with accum-mode bound variables
-rw-r--r--bench/Main.hs20
1 files changed, 15 insertions, 5 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 32fbc8c..932da9d 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -20,6 +20,7 @@ import GHC.Exts (withDict)
import AST
import Array
import qualified CHAD (defaultConfig)
+import CHAD (CHADConfig(..))
import CHAD.Top
import CHAD.Types
import Data
@@ -31,11 +32,11 @@ import Interpreter.Rep
import Simplify
-gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
-gradCHAD input ctg term =
+gradCHAD :: KnownEnv env => CHADConfig -> SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
+gradCHAD config input ctg term =
interpretOpen False input $
simplifyFix $
- ELet ext (EConst ext STF64 ctg) $ chad' CHAD.defaultConfig knownEnv term
+ ELet ext (EConst ext STF64 ctg) $ chad' config knownEnv term
instance KnownTy t => NFData (Value t) where
rnf = \(Value x) -> go (knownTy @t) x
@@ -110,10 +111,19 @@ makeGMMInputs =
Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
SNil
+accumConfig :: CHADConfig
+accumConfig = CHADConfig
+ { chcLetArrayAccum = True
+ , chcCaseArrayAccum = True }
+
main :: IO ()
main = defaultMain
[env (return makeNeuralInputs) $ \inputs ->
- bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0))
+ bench "neural" (nf (\(inp, ctg) -> gradCHAD CHAD.defaultConfig inp ctg neural) (inputs, 1.0))
+ ,env (return makeNeuralInputs) $ \inputs ->
+ bench "neural-accum" (nf (\(inp, ctg) -> gradCHAD accumConfig inp ctg neural) (inputs, 1.0))
+ ,env (return makeGMMInputs) $ \inputs ->
+ bench "gmm" (nf (\(inp, ctg) -> gradCHAD CHAD.defaultConfig inp ctg (gmmObjective False)) (inputs, 1.0))
,env (return makeGMMInputs) $ \inputs ->
- bench "gmm" (nf (\(inp, ctg) -> gradCHAD inp ctg (gmmObjective False)) (inputs, 1.0))
+ bench "gmm-accum" (nf (\(inp, ctg) -> gradCHAD accumConfig inp ctg (gmmObjective False)) (inputs, 1.0))
]