diff options
Diffstat (limited to 'bench')
-rw-r--r-- | bench/Main.hs | 54 |
1 files changed, 19 insertions, 35 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index 5d2cb5a..358ba31 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -12,50 +12,30 @@ module Main where import Control.DeepSeq import Criterion.Main -import Data.Coerce import Data.Int (Int64) import Data.Kind (Constraint) import GHC.Exts (withDict) import AST +import AST.UnMonoid import Array import qualified CHAD (defaultConfig) import CHAD.Top import CHAD.Types +import Compile import Data import Example import Example.GMM import Example.Types -import Interpreter import Interpreter.Rep import Simplify -gradCHAD :: KnownEnv env => CHADConfig -> SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env))) -gradCHAD config input ctg term = - interpretOpen False input $ - simplifyFix $ - ELet ext (EConst ext STF64 ctg) $ chad' config knownEnv term - -instance KnownTy t => NFData (Value t) where - rnf = \(Value x) -> go (knownTy @t) x - where - go :: STy t' -> Rep t' -> () - go STNil () = () - go (STPair a b) (x, y) = go a x `seq` go b y - go (STEither a _) (Left x) = go a x - go (STEither _ b) (Right y) = go b y - go (STMaybe _) Nothing = () - go (STMaybe t) (Just x) = go t x - go (STArr (_ :: SNat n) (t :: STy t2)) arr = - withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) - go (STScal t) x = case t of - STI32 -> rnf x - STI64 -> rnf x - STF32 -> rnf x - STF64 -> rnf x - STBool -> rnf x - go STAccum{} _ = error "Cannot rnf accumulators" +gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env)))) +gradCHAD config term = + compile knownEnv $ + simplifyFix $ unMonoid $ simplifyFix $ + ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term type AllNFDataRep :: [Ty] -> Constraint type family AllNFDataRep env where @@ -115,12 +95,16 @@ accumConfig = chcSetAccum CHAD.defaultConfig main :: IO () main = defaultMain - [env (return makeNeuralInputs) $ \inputs -> - bench "neural" (nf (\(inp, ctg) -> gradCHAD CHAD.defaultConfig inp ctg neural) (inputs, 1.0)) - ,env (return makeNeuralInputs) $ \inputs -> - bench "neural-accum" (nf (\(inp, ctg) -> gradCHAD accumConfig inp ctg neural) (inputs, 1.0)) - ,env (return makeGMMInputs) $ \inputs -> - bench "gmm" (nf (\(inp, ctg) -> gradCHAD CHAD.defaultConfig inp ctg (gmmObjective False)) (inputs, 1.0)) - ,env (return makeGMMInputs) $ \inputs -> - bench "gmm-accum" (nf (\(inp, ctg) -> gradCHAD accumConfig inp ctg (gmmObjective False)) (inputs, 1.0)) + [env (return makeNeuralInputs) $ \inputs -> bgroup "neural" + [env (gradCHAD CHAD.defaultConfig neural) $ \fun -> + bench "default" (nfAppIO fun inputs) + ,env (gradCHAD accumConfig neural) $ \fun -> + bench "accum" (nfAppIO fun inputs) + ] + ,env (return makeGMMInputs) $ \inputs -> bgroup "gmm" + [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun -> + bench "default" (nfAppIO fun inputs) + ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun -> + bench "accum" (nfAppIO fun inputs) + ] ] |