aboutsummaryrefslogtreecommitdiff
path: root/bench/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'bench/Main.hs')
-rw-r--r--bench/Main.hs33
1 files changed, 18 insertions, 15 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 34e8bae..ec9264b 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -1,11 +1,13 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where
@@ -98,18 +100,19 @@ makeGMMInputs =
accumConfig :: CHADConfig
accumConfig = chcSetAccum CHAD.defaultConfig
-main :: IO ()
-main = defaultMain
- [env (return makeNeuralInputs) $ \inputs -> bgroup "neural"
- [env (gradCHAD CHAD.defaultConfig neural) $ \fun ->
- bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig neural) $ \fun ->
- bench "accum" (nfAppIO fun inputs)
- ]
- ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm"
- [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun ->
+bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env))))
+ => String -> Ex env R -> SList Value env -> Benchmark
+bgroupDefaultAccum name term inputs =
+ bgroup name
+ [env (gradCHAD CHAD.defaultConfig term) $ \fun ->
bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun ->
+ ,env (gradCHAD accumConfig term) $ \fun ->
bench "accum" (nfAppIO fun inputs)
]
+
+main :: IO ()
+main = defaultMain
+ [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural
+ ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False)
+ ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil)
]