diff options
Diffstat (limited to 'bench/Main.hs')
-rw-r--r-- | bench/Main.hs | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/bench/Main.hs b/bench/Main.hs new file mode 100644 index 0000000..c62b0f2 --- /dev/null +++ b/bench/Main.hs @@ -0,0 +1,89 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE FlexibleInstances #-} + +{-# OPTIONS -Wno-orphans #-} +module Main where + +import Control.DeepSeq +import Criterion.Main +import Data.Coerce +import Data.Kind (Constraint) +import GHC.Exts (withDict) + +import AST +import Array +import CHAD.Top +import CHAD.Types +import Data +import Example +import Interpreter +import Interpreter.Rep +import Simplify + + +gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env))) +gradCHAD input ctg term = + interpretOpen False input $ + simplifyFix $ + ELet ext (EConst ext STF64 ctg) $ chad' knownEnv term + +instance KnownTy t => NFData (Value t) where + rnf = \(Value x) -> go (knownTy @t) x + where + go :: STy t' -> Rep t' -> () + go STNil () = () + go (STPair a b) (x, y) = go a x `seq` go b y + go (STEither a _) (Left x) = go a x + go (STEither _ b) (Right y) = go b y + go (STMaybe _) Nothing = () + go (STMaybe t) (Just x) = go t x + go (STArr (_ :: SNat n) (t :: STy t2)) arr = + withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) + go (STScal t) x = case t of + STI32 -> rnf x + STI64 -> rnf x + STF32 -> rnf x + STF64 -> rnf x + STBool -> rnf x + go STAccum{} _ = error "Cannot rnf accumulators" + +type AllNFDataRep :: [Ty] -> Constraint +type family AllNFDataRep env where + AllNFDataRep '[] = () + AllNFDataRep (t : env) = (NFData (Rep t), AllNFDataRep env) + +instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where + rnf = go knownEnv + where + go :: SList STy env' -> SList Value env' -> () + go SNil SNil = () + go ((t :: STy t) `SCons` ts) (v `SCons` vs) = + withDict @(KnownTy t) t $ rnf v `seq` go ts vs + +makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] +makeNeuralInputs = + let genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double) + genLayer nin nout = + (genArray (ShNil `ShCons` nout `ShCons` nin) + ,genArray (ShNil `ShCons` nout)) + in let + nin = 30 + n1 = 50 + n2 = 50 + input = Value (genArray (ShNil `ShCons` nin)) + lay1 = Value (genLayer nin n1) + lay2 = Value (genLayer n1 n2) + lay3 = Value (genArray (ShNil `ShCons` n2)) + in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil + +main :: IO () +main = defaultMain + [env (return makeNeuralInputs) $ \inputs -> + bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0)) + ] |