diff options
Diffstat (limited to 'bench/Main.hs')
| -rw-r--r-- | bench/Main.hs | 92 |
1 files changed, 40 insertions, 52 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index af83ef7..1e8f6f3 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -1,62 +1,49 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS -Wno-orphans #-} module Main where import Control.DeepSeq import Criterion.Main -import Data.Coerce import Data.Int (Int64) import Data.Kind (Constraint) import GHC.Exts (withDict) -import AST -import AST.UnMonoid -import Array -import qualified CHAD (defaultConfig) -import CHAD.Top -import CHAD.Types -import Compile -import Data -import Example -import Example.GMM -import Example.Types -import Interpreter.Rep -import Simplify +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Array +import CHAD.Compile +import CHAD.Data +import CHAD.Drev qualified as CHAD (defaultConfig) +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example +import CHAD.Example.GMM +import CHAD.Example.Types +import CHAD.Interpreter.Rep +import CHAD.Simplify gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env)))) gradCHAD config term = - compile knownEnv $ - simplifyFix $ unMonoid $ simplifyFix $ - ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term - -instance KnownTy t => NFData (Value t) where - rnf = \(Value x) -> go (knownTy @t) x - where - go :: STy t' -> Rep t' -> () - go STNil () = () - go (STPair a b) (x, y) = go a x `seq` go b y - go (STEither a _) (Left x) = go a x - go (STEither _ b) (Right y) = go b y - go (STMaybe _) Nothing = () - go (STMaybe t) (Just x) = go t x - go (STArr (_ :: SNat n) (t :: STy t2)) arr = - withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) - go (STScal t) x = case t of - STI32 -> rnf x - STI64 -> rnf x - STF32 -> rnf x - STF64 -> rnf x - STBool -> rnf x - go STAccum{} _ = error "Cannot rnf accumulators" + compileStderr knownEnv $ + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ + ELet ext (EConst ext STF64 1.0) $ + chad' config knownEnv $ + simplifyFix term type AllNFDataRep :: [Ty] -> Constraint type family AllNFDataRep env where @@ -114,18 +101,19 @@ makeGMMInputs = accumConfig :: CHADConfig accumConfig = chcSetAccum CHAD.defaultConfig -main :: IO () -main = defaultMain - [env (return makeNeuralInputs) $ \inputs -> bgroup "neural" - [env (gradCHAD CHAD.defaultConfig neural) $ \fun -> +bgroupDefaultAccum :: (KnownEnv env, NFData (Rep (Tup (D2E env)))) + => String -> Ex env R -> SList Value env -> Benchmark +bgroupDefaultAccum name term inputs = + bgroup name + [env (gradCHAD CHAD.defaultConfig term) $ \fun -> bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig neural) $ \fun -> - bench "accum" (nfAppIO fun inputs) - ] - ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm" - [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun -> - bench "default" (nfAppIO fun inputs) - ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun -> + ,env (gradCHAD accumConfig term) $ \fun -> bench "accum" (nfAppIO fun inputs) ] + +main :: IO () +main = defaultMain + [env (return makeNeuralInputs) $ bgroupDefaultAccum "neural" neural + ,env (return makeGMMInputs) $ bgroupDefaultAccum "gmm" (gmmObjective False) + ,bgroupDefaultAccum "uniform-free" exUniformFree (Value 42.0 `SCons` Value 1000_000 `SCons` SNil) ] |
