summaryrefslogtreecommitdiff
path: root/bench/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'bench/Main.hs')
-rw-r--r--bench/Main.hs89
1 files changed, 89 insertions, 0 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
new file mode 100644
index 0000000..c62b0f2
--- /dev/null
+++ b/bench/Main.hs
@@ -0,0 +1,89 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE FlexibleInstances #-}
+
+{-# OPTIONS -Wno-orphans #-}
+module Main where
+
+import Control.DeepSeq
+import Criterion.Main
+import Data.Coerce
+import Data.Kind (Constraint)
+import GHC.Exts (withDict)
+
+import AST
+import Array
+import CHAD.Top
+import CHAD.Types
+import Data
+import Example
+import Interpreter
+import Interpreter.Rep
+import Simplify
+
+
+gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
+gradCHAD input ctg term =
+ interpretOpen False input $
+ simplifyFix $
+ ELet ext (EConst ext STF64 ctg) $ chad' knownEnv term
+
+instance KnownTy t => NFData (Value t) where
+ rnf = \(Value x) -> go (knownTy @t) x
+ where
+ go :: STy t' -> Rep t' -> ()
+ go STNil () = ()
+ go (STPair a b) (x, y) = go a x `seq` go b y
+ go (STEither a _) (Left x) = go a x
+ go (STEither _ b) (Right y) = go b y
+ go (STMaybe _) Nothing = ()
+ go (STMaybe t) (Just x) = go t x
+ go (STArr (_ :: SNat n) (t :: STy t2)) arr =
+ withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
+ go (STScal t) x = case t of
+ STI32 -> rnf x
+ STI64 -> rnf x
+ STF32 -> rnf x
+ STF64 -> rnf x
+ STBool -> rnf x
+ go STAccum{} _ = error "Cannot rnf accumulators"
+
+type AllNFDataRep :: [Ty] -> Constraint
+type family AllNFDataRep env where
+ AllNFDataRep '[] = ()
+ AllNFDataRep (t : env) = (NFData (Rep t), AllNFDataRep env)
+
+instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
+ rnf = go knownEnv
+ where
+ go :: SList STy env' -> SList Value env' -> ()
+ go SNil SNil = ()
+ go ((t :: STy t) `SCons` ts) (v `SCons` vs) =
+ withDict @(KnownTy t) t $ rnf v `seq` go ts vs
+
+makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
+makeNeuralInputs =
+ let genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
+ genLayer nin nout =
+ (genArray (ShNil `ShCons` nout `ShCons` nin)
+ ,genArray (ShNil `ShCons` nout))
+ in let
+ nin = 30
+ n1 = 50
+ n2 = 50
+ input = Value (genArray (ShNil `ShCons` nin))
+ lay1 = Value (genLayer nin n1)
+ lay2 = Value (genLayer n1 n2)
+ lay3 = Value (genArray (ShNil `ShCons` n2))
+ in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil
+
+main :: IO ()
+main = defaultMain
+ [env (return makeNeuralInputs) $ \inputs ->
+ bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0))
+ ]