aboutsummaryrefslogtreecommitdiff
path: root/bench
diff options
context:
space:
mode:
Diffstat (limited to 'bench')
-rw-r--r--bench/Main.hs92
1 files changed, 40 insertions, 52 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index af83ef7..1e8f6f3 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -1,62 +1,49 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where
import Control.DeepSeq
import Criterion.Main
-import Data.Coerce
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
-import AST
-import AST.UnMonoid
-import Array
-import qualified CHAD (defaultConfig)
-import CHAD.Top
-import CHAD.Types
-import Compile
-import Data
-import Example
-import Example.GMM
-import Example.Types
-import Interpreter.Rep
-import Simplify
+import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.UnMonoid
+import CHAD.Array
+import CHAD.Compile
+import CHAD.Data
+import CHAD.Drev qualified as CHAD (defaultConfig)
+import CHAD.Drev.Top
+import CHAD.Drev.Types
+import CHAD.Example
+import CHAD.Example.GMM
+import CHAD.Example.Types
+import CHAD.Interpreter.Rep
+import CHAD.Simplify
gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
gradCHAD config term =
- compile knownEnv $
- simplifyFix $ unMonoid $ simplifyFix $
- ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term
-
-instance KnownTy t => NFData (Value t) where
- rnf = \(Value x) -> go (knownTy @t) x
- where
- go :: STy t' -> Rep t' -> ()
- go STNil () = ()
- go (STPair a b) (x, y) = go a x `seq` go b y
- go (STEither a _) (Left x) = go a x
- go (STEither _ b) (Right y) = go b y
- go (STMaybe _) Nothing = ()
- go (STMaybe t) (Just x) = go t x
- go (STArr (_ :: SNat n) (t :: STy t2)) arr =
- withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
- go (STScal t) x = case t of
- STI32 -> rnf x
- STI64 -> rnf x
- STF32 -> rnf x
- STF64 -> rnf x
- STBool -> rnf x
- go STAccum{} _ = error "Cannot rnf accumulators"
+ compileStderr knownEnv $
+ simplifyFix $ pruneExpr knownEnv $
+ simplifyFix $ unMonoid $
+ simplifyFix $
+ ELet ext (EConst ext STF64 1.0) $
+ chad' config knownEnv $
+ simplifyFix term
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
@@ -114,18 +101,19 @@ makeGMMInputs =
accumConfig :: CHADConfig
accumConfig = chcSetAccum CHAD.defaultConfig
-main :: IO ()
-main = defaultMain
- [env (return makeNeuralInputs) $ \inputs -> bgroup "neural"
- [env (gradCHAD CHAD.defaultConfig neural) $ \fun ->
+bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env))))
+ => String -> Ex env R -> SList Value env -> Benchmark
+bgroupDefaultAccum name term inputs =
+ bgroup name
+ [env (gradCHAD CHAD.defaultConfig term) $ \fun ->
bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig neural) $ \fun ->
- bench "accum" (nfAppIO fun inputs)
- ]
- ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm"
- [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun ->
- bench "default" (nfAppIO fun inputs)
- ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun ->
+ ,env (gradCHAD accumConfig term) $ \fun ->
bench "accum" (nfAppIO fun inputs)
]
+
+main :: IO ()
+main = defaultMain
+ [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural
+ ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False)
+ ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil)
]