{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}

-- | Criterion benchmark driver: compares two CHAD differentiation
-- configurations ('CHAD.defaultConfig' vs. the accumulator-enabled
-- 'accumConfig') on the @neural@ and @gmmObjective@ example programs.
module Main where

import Control.DeepSeq
import Criterion.Main
import Data.Coerce
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)

import AST
import AST.UnMonoid
import Array
import qualified CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import Compile
import Data
import Example
import Example.GMM
import Example.Types
import Interpreter.Rep
import Simplify

-- | Differentiate a scalar @TF64@ term with CHAD under the given config,
-- simplify the result (before and after 'unMonoid'), and compile it.
-- The returned action maps an input environment to the primal result
-- paired with the environment cotangents ('Rep (Tup (D2E env))').
gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
gradCHAD config term =
  compile knownEnv $
    simplifyFix $
      unMonoid $
        simplifyFix $
          -- NOTE(review): the bound 1.0 constant presumably seeds the
          -- incoming cotangent for the scalar output of 'chad'' — confirm
          -- against CHAD.Top's calling convention.
          ELet ext (EConst ext STF64 1.0) $
            chad' config knownEnv term

-- | Deep evaluation of interpreter 'Value's, dispatching on the type
-- singleton. Orphan instance (hence @-Wno-orphans@ above); the benchmarks
-- need results forced to normal form so setup work isn't timed lazily.
instance KnownTy t => NFData (Value t) where
  rnf = \(Value x) -> go (knownTy @t) x
    where
      go :: STy t' -> Rep t' -> ()
      go STNil () = ()
      go (STPair a b) (x, y) = go a x `seq` go b y
      go (STEither a _) (Left x) = go a x
      go (STEither _ b) (Right y) = go b y
      go (STMaybe _) Nothing = ()
      go (STMaybe t) (Just x) = go t x
      -- Arrays: coerce the array of element 'Rep's to an array of 'Value's
      -- (newtype coercion) and reuse this instance for the elements,
      -- discharging the 'KnownTy' constraint at runtime via 'withDict'.
      go (STArr (_ :: SNat n) (t :: STy t2)) arr =
        withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
      go (STScal t) x = case t of
        STI32 -> rnf x
        STI64 -> rnf x
        STF32 -> rnf x
        STF64 -> rnf x
        STBool -> rnf x
      -- Accumulators are mutable/abstract; forcing them is not supported.
      go STAccum{} _ = error "Cannot rnf accumulators"

-- | Constraint that every type in the environment has an 'NFData'-capable
-- runtime representation.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
  AllNFDataRep '[] = ()
  AllNFDataRep (t : env) = (NFData (Rep t), AllNFDataRep env)

-- | Force an entire environment of values, walking the singleton
-- environment in lockstep with the value list.
instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
  rnf = go knownEnv
    where
      go :: SList STy env' -> SList Value env' -> ()
      go SNil SNil = ()
      go ((t :: STy t) `SCons` ts) (v `SCons` vs) =
        withDict @(KnownTy t) t $ rnf v `seq` go ts vs

-- | Inputs for the @neural@ benchmark: an input vector, a final weight
-- vector, and two (matrix, bias-vector) layers, in the order demanded by
-- the type-level environment list.
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  let genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
      -- A layer is a (nout x nin) weight matrix paired with a bias vector.
      genLayer nin nout =
        (genArray (ShNil `ShCons` nout `ShCons` nin)
        ,genArray (ShNil `ShCons` nout))
  in let nin = 30
         n1 = 50
         n2 = 50
         input = Value (genArray (ShNil `ShCons` nin))
         lay1 = Value (genLayer nin n1)
         lay2 = Value (genLayer n1 n2)
         lay3 = Value (genArray (ShNil `ShCons` n2))
     in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil

-- | Inputs for the GMM objective benchmark. Environment order (head
-- first) matches the type-level list: k3, k2, k1, m, X, L, Q, M, alpha,
-- then the sizes K, D, N as Int64.
makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
makeGMMInputs =
  let genArray sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1]))
      -- these values are completely arbitrary lol
      kN = 10
      kD = 10
      kK = 10
      i2i64 = fromIntegral @Int @Int64
      valpha = genArray (ShNil `ShCons` kK)
      vM = genArray (ShNil `ShCons` kK `ShCons` kD)
      vQ = genArray (ShNil `ShCons` kK `ShCons` kD)
      -- L holds the strictly-lower-triangular entries, kD*(kD-1)/2 per row.
      vL = genArray (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX = genArray (ShNil `ShCons` kN `ShCons` kD)
      vgamma = 0.42
      vm = 2
      k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
      k2 = 0.5 * vgamma * vgamma
      k3 = 0.42  -- don't feel like multigammaing today
  in Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` Value vm
       `SCons` vX `SCons` vL `SCons` vQ `SCons` vM `SCons` valpha
       `SCons` Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN)
       `SCons` SNil

-- | The default CHAD configuration with accumulators switched on; the
-- benchmarks compare this against 'CHAD.defaultConfig' directly.
accumConfig :: CHADConfig
accumConfig = chcSetAccum CHAD.defaultConfig

-- | Two benchmark groups (neural, gmm), each timing the compiled gradient
-- function under the default and the accumulator configuration. 'env'
-- keeps input construction and compilation out of the timed region, and
-- 'nfAppIO' forces the result to normal form (using the instances above).
main :: IO ()
main = defaultMain
  [env (return makeNeuralInputs) $ \inputs ->
     bgroup "neural"
       [env (gradCHAD CHAD.defaultConfig neural) $ \fun ->
          bench "default" (nfAppIO fun inputs)
       ,env (gradCHAD accumConfig neural) $ \fun ->
          bench "accum" (nfAppIO fun inputs)
       ]
  ,env (return makeGMMInputs) $ \inputs ->
     bgroup "gmm"
       [env (gradCHAD CHAD.defaultConfig (gmmObjective False)) $ \fun ->
          bench "default" (nfAppIO fun inputs)
       ,env (gradCHAD accumConfig (gmmObjective False)) $ \fun ->
          bench "accum" (nfAppIO fun inputs)
       ]
  ]