aboutsummaryrefslogtreecommitdiff
path: root/bench/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'bench/Main.hs')
-rw-r--r--bench/Main.hs54
1 files changed, 19 insertions, 35 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 5d2cb5a..358ba31 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -12,50 +12,30 @@ module Main where
import Control.DeepSeq
import Criterion.Main
-import Data.Coerce
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
import AST
+import AST.UnMonoid
import Array
import qualified CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
+import Compile
import Data
import Example
import Example.GMM
import Example.Types
-import Interpreter
import Interpreter.Rep
import Simplify
-gradCHAD :: KnownEnv env => CHADConfig -> SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
-gradCHAD config input ctg term =
- interpretOpen False input $
- simplifyFix $
- ELet ext (EConst ext STF64 ctg) $ chad' config knownEnv term
-
-instance KnownTy t => NFData (Value t) where
- rnf = \(Value x) -> go (knownTy @t) x
- where
- go :: STy t' -> Rep t' -> ()
- go STNil () = ()
- go (STPair a b) (x, y) = go a x `seq` go b y
- go (STEither a _) (Left x) = go a x
- go (STEither _ b) (Right y) = go b y
- go (STMaybe _) Nothing = ()
- go (STMaybe t) (Just x) = go t x
- go (STArr (_ :: SNat n) (t :: STy t2)) arr =
- withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
- go (STScal t) x = case t of
- STI32 -> rnf x
- STI64 -> rnf x
- STF32 -> rnf x
- STF64 -> rnf x
- STBool -> rnf x
- go STAccum{} _ = error "Cannot rnf accumulators"
+gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
+gradCHAD config term =
+ compile knownEnv $
+ simplifyFix $ unMonoid $ simplifyFix $
+ ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
@@ -115,12 +95,16 @@ accumConfig = chcSetAccum CHAD.defaultConfig
main :: IO ()
main = defaultMain
- [env (return makeNeuralInputs) $ \inputs ->
- bench "neural" (nf (\(inp, ctg) -> gradCHAD CHAD.defaultConfig inp ctg neural) (inputs, 1.0))
- ,env (return makeNeuralInputs) $ \inputs ->
- bench "neural-accum" (nf (\(inp, ctg) -> gradCHAD accumConfig inp ctg neural) (inputs, 1.0))
- ,env (return makeGMMInputs) $ \inputs ->
- bench "gmm" (nf (\(inp, ctg) -> gradCHAD CHAD.defaultConfig inp ctg (gmmObjective False)) (inputs, 1.0))
- ,env (return makeGMMInputs) $ \inputs ->
- bench "gmm-accum" (nf (\(inp, ctg) -> gradCHAD accumConfig inp ctg (gmmObjective False)) (inputs, 1.0))
+ [env (return makeNeuralInputs) $ \inputs -> bgroup "neural"
+ [env (gradCHAD CHAD.defaultConfig neural) $ \fun ->
+ bench "default" (nfAppIO fun inputs)
+ ,env (gradCHAD accumConfig neural) $ \fun ->
+ bench "accum" (nfAppIO fun inputs)
+ ]
+ ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm"
+ [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun ->
+ bench "default" (nfAppIO fun inputs)
+ ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun ->
+ bench "accum" (nfAppIO fun inputs)
+ ]
]