{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- The NFData instance for 'SList Value' below is an orphan; silence the
-- warning for it. (OPTIONS_GHC replaces the deprecated OPTIONS pragma.)
{-# OPTIONS_GHC -Wno-orphans #-}

-- | Criterion benchmarks for CHAD-derived gradients, comparing
-- 'CHAD.defaultConfig' against the same configuration with 'chcSetAccum'
-- applied.
module Main where

import Control.DeepSeq
import Criterion.Main
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)

import AST
import AST.Count
import AST.UnMonoid
import Array
import qualified CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import Compile
import Data
import Example
import Example.GMM
import Example.Types
import Interpreter.Rep
import Simplify


-- | Differentiate a scalar-valued term with CHAD under the given
-- configuration and compile the resulting program.
--
-- Pipeline (read bottom-up): 'simplifyFix' the input term, apply 'chad''
-- with the given config over 'knownEnv', wrap the result in an 'ELet' that
-- binds the constant @1.0@ (presumably the seed cotangent of the scalar
-- output -- confirm against 'chad''), then interleave 'simplifyFix' with
-- 'unMonoid' and 'pruneExpr' before handing the term to 'compile'.
--
-- The returned action yields a function from input values to a pair of a
-- 'Double' (presumably the primal result) and a @Rep (Tup (D2E env))@
-- (the gradient tuple for the environment).
gradCHAD
  :: KnownEnv env
  => CHADConfig
  -> Ex env (TScal TF64)
  -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
gradCHAD config term =
  compile knownEnv $
    simplifyFix $
      pruneExpr knownEnv $
        simplifyFix $
          unMonoid $
            simplifyFix $
              ELet ext (EConst ext STF64 1.0) $
                chad' config knownEnv $
                  simplifyFix term

-- | Requires an 'NFData' instance for @'Rep' t@ for every type @t@ in the
-- list.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
  AllNFDataRep '[] = ()
  AllNFDataRep (t : env) = (NFData (Rep t), AllNFDataRep env)

-- | Deep-forces every 'Value' in the list. The 'STy' list from 'knownEnv'
-- is walked in lockstep with the values so that @KnownTy t@ can be brought
-- into scope (via 'withDict') for each element before calling 'rnf' on it.
instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
  rnf = go knownEnv
    where
      go :: SList STy env' -> SList Value env' -> ()
      go SNil SNil = ()
      go ((t :: STy t) `SCons` ts) (v `SCons` vs) =
        withDict @(KnownTy t) t $ rnf v `seq` go ts vs

-- | Deterministic inputs for the @neural@ benchmark.
--
-- Every array is filled by 'arrayGenerateLin' with its linear index
-- converted to 'Double'. List order (head first): an input vector of length
-- 30, a final vector of length 50, then two (matrix, vector) layers built
-- by @genLayer@ for 50 -> 50 and 30 -> 50.
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  let genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
      genLayer nin nout =
        ( genArray (ShNil `ShCons` nout `ShCons` nin)
        , genArray (ShNil `ShCons` nout)
        )
  in let nin = 30
         n1 = 50
         n2 = 50
         input = Value (genArray (ShNil `ShCons` nin))
         lay1 = Value (genLayer nin n1)
         lay2 = Value (genLayer n1 n2)
         lay3 = Value (genArray (ShNil `ShCons` n2))
     in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil

-- | Inputs for @'gmmObjective' False@ with N = D = K = 10.
--
-- The values are arbitrary and not meant to describe a realistic GMM:
-- every array holds @0, 1, 2, ...@ in 'arrayFromList' order, and the
-- scalar constants are fixed placeholders.
--
-- List order (head first): k3, k2, k1, m, X (dims N, D), L (dims K,
-- D(D-1)/2), Q (dims K, D), M (dims K, D), alpha (dim K), then K, D and N
-- as 'Int64'.
makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
makeGMMInputs =
  let genArray sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1]))
      -- Problem sizes; the concrete values are arbitrary.
      kN = 10
      kD = 10
      kK = 10
      i2i64 = fromIntegral @Int @Int64
      valpha = genArray (ShNil `ShCons` kK)
      vM = genArray (ShNil `ShCons` kK `ShCons` kD)
      vQ = genArray (ShNil `ShCons` kK `ShCons` kD)
      vL = genArray (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
      vX = genArray (ShNil `ShCons` kN `ShCons` kD)
      vgamma = 0.42
      vm = 2
      k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
      k2 = 0.5 * vgamma * vgamma
      k3 = 0.42  -- placeholder: the multigamma-based constant is not computed
  in Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` Value vm
       `SCons` vX `SCons` vL `SCons` vQ `SCons` vM `SCons` valpha
       `SCons` Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN)
       `SCons` SNil

-- | 'CHAD.defaultConfig' with 'chcSetAccum' applied.
accumConfig :: CHADConfig
accumConfig = chcSetAccum CHAD.defaultConfig

-- | Benchmark group @name@ timing the compiled gradient of @term@ on
-- @inputs@ under two configurations: @"default"@ ('CHAD.defaultConfig') and
-- @"accum"@ ('accumConfig').
--
-- Compilation via 'gradCHAD' runs inside Criterion's 'env', so it happens
-- once per configuration and is not part of the measured time.
bgroupDefaultAccum
  :: (KnownEnv env, NFData (Rep (Tup (D2E env))))
  => String
  -> Ex env R
  -> SList Value env
  -> Benchmark
bgroupDefaultAccum name term inputs =
  bgroup name
    [ env (gradCHAD CHAD.defaultConfig term) $ \fun ->
        bench "default" (nfAppIO fun inputs)
    , env (gradCHAD accumConfig term) $ \fun ->
        bench "accum" (nfAppIO fun inputs)
    ]

-- | Run all benchmark groups. Every group's inputs are constructed and
-- deep-forced through 'env', so input construction is never timed.
main :: IO ()
main =
  defaultMain
    [ env (pure makeNeuralInputs) $
        bgroupDefaultAccum "neural" neural
    , env (pure makeGMMInputs) $
        bgroupDefaultAccum "gmm" (gmmObjective False)
    , env (pure (Value 42.0 `SCons` Value 1_000_000 `SCons` SNil)) $
        bgroupDefaultAccum "uniform-free" exUniformFree
    ]