aboutsummaryrefslogtreecommitdiff
path: root/bench/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'bench/Main.hs')
-rw-r--r--bench/Main.hs71
1 files changed, 40 insertions, 31 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 358ba31..1e8f6f3 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -1,11 +1,14 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where
@@ -16,26 +19,31 @@ import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
-import AST
-import AST.UnMonoid
-import Array
-import qualified CHAD (defaultConfig)
-import CHAD.Top
-import CHAD.Types
-import Compile
-import Data
-import Example
-import Example.GMM
-import Example.Types
-import Interpreter.Rep
-import Simplify
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.UnMonoid
+import CHAD.Array
+import CHAD.Compile
+import CHAD.Data
+import CHAD.Drev qualified as CHAD (defaultConfig)
+import CHAD.Drev.Top
+import CHAD.Drev.Types
+import CHAD.Example
+import CHAD.Example.GMM
+import CHAD.Example.Types
+import CHAD.Interpreter.Rep
+import CHAD.Simplify
gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
gradCHAD config term =
- compile knownEnv $
- simplifyFix $ unMonoid $ simplifyFix $
- ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term
+ compileStderr knownEnv $
+ simplifyFix $ pruneExpr knownEnv $
+ simplifyFix $ unMonoid $
+ simplifyFix $
+ ELet ext (EConst ext STF64 1.0) $
+ chad' config knownEnv $
+ simplifyFix term
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
@@ -93,18 +101,19 @@ makeGMMInputs =
accumConfig :: CHADConfig
accumConfig = chcSetAccum CHAD.defaultConfig
-main :: IO ()
-main = defaultMain
- [env (return makeNeuralInputs) $ \inputs -> bgroup "neural"
- [env (gradCHAD CHAD.defaultConfig neural) $ \fun ->
- bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig neural) $ \fun ->
- bench "accum" (nfAppIO fun inputs)
- ]
- ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm"
- [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun ->
+bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env))))
+ => String -> Ex env R -> SList Value env -> Benchmark
+bgroupDefaultAccum name term inputs =
+ bgroup name
+ [env (gradCHAD CHAD.defaultConfig term) $ \fun ->
bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun ->
+ ,env (gradCHAD accumConfig term) $ \fun ->
bench "accum" (nfAppIO fun inputs)
]
+
+main :: IO ()
+main = defaultMain
+ [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural
+ ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False)
+ ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil)
]