{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where

import Control.DeepSeq
import Criterion.Main
import Data.Coerce
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)

import AST
import Array
import CHAD.Top
import CHAD.Types
import Data
import Example
import Example.GMM
import Example.Types
import Interpreter
import Interpreter.Rep
import Simplify


-- | Differentiate a scalar-valued term with CHAD and evaluate the derivative
-- program on concrete inputs.
--
-- The output cotangent @ctg@ is let-bound as an 'STF64' constant in front of
-- the CHAD-transformed term. The resulting program is simplified with
-- 'simplifyFix' and then run by 'interpretOpen' against @input@.
--
-- NOTE(review): judging by the result type, the pair is presumably
-- (primal value, cotangents of the environment); confirm against "CHAD.Top".
-- The meaning of the 'False' flag passed to 'interpretOpen' is not visible
-- from this module.
gradCHAD
  :: KnownEnv env
  => SList Value env              -- ^ concrete values for the free variables
  -> Double                       -- ^ cotangent of the scalar result
  -> Ex env (TScal TF64)          -- ^ term to differentiate
  -> (Double, Rep (Tup (D2E env)))
gradCHAD input ctg term =
  interpretOpen False input $
    simplifyFix $
      ELet ext (EConst ext STF64 ctg) $
        chad' knownEnv term

-- | Deep evaluation of an interpreter value, guided by the runtime type
-- witness obtained from 'KnownTy'.
--
-- Accumulators cannot be forced: reaching one calls 'error'.
instance KnownTy t => NFData (Value t) where
  rnf = \(Value x) -> go (knownTy @t) x
    where
      go :: STy t' -> Rep t' -> ()
      go STNil () = ()
      go (STPair a b) (x, y) = go a x `seq` go b y
      go (STEither a _) (Left x) = go a x
      go (STEither _ b) (Right y) = go b y
      go (STMaybe _) Nothing = ()
      go (STMaybe t) (Just x) = go t x
      go (STArr (_ :: SNat n) (t :: STy t2)) arr =
        -- View the elements as 'Value's so the array's own NFData instance
        -- can recurse through this instance; 'withDict' turns the element
        -- witness into the 'KnownTy' dictionary that recursion needs.
        withDict @(KnownTy t2) t $
          rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
      go (STScal t) x = case t of
        STI32 -> rnf x
        STI64 -> rnf x
        STF32 -> rnf x
        STF64 -> rnf x
        STBool -> rnf x
      go STAccum{} _ = error "Cannot rnf accumulators"

-- | Conjunction of 'NFData' constraints on the 'Rep' of every type in @env@.
--
-- Not required by the 'NFData' ('SList' 'Value' env) instance in this
-- module, which obtains each element's instance from 'KnownTy' instead.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
  AllNFDataRep '[] = ()
  AllNFDataRep (t : env) = (NFData (Rep t), AllNFDataRep env)

-- | Deep evaluation of an environment of values. This is what lets the
-- benchmark inputs be handed to criterion's 'env'.
--
-- Each element is forced through the 'NFData' ('Value' t) instance above.
-- The 'KnownTy' dictionary that instance needs is recovered per element from
-- the 'KnownEnv' witness, so no extra per-element constraint is needed.
instance KnownEnv env => NFData (SList Value env) where
  rnf = go knownEnv
    where
      go :: SList STy env' -> SList Value env' -> ()
      go SNil SNil = ()
      go ((t :: STy t) `SCons` ts) (v `SCons` vs) =
        withDict @(KnownTy t) t $ rnf v `seq` go ts vs

-- | Deterministic inputs for the @neural@ benchmark, in environment order:
-- input vector, final weight vector, then the two (matrix, vector) layers.
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil
  where
    -- Sizes: input vector length, then the output widths of the two layers.
    nin = 30
    n1 = 50
    n2 = 50

    input = Value (genArray (ShNil `ShCons` nin))
    lay1 = Value (genLayer nin n1)
    lay2 = Value (genLayer n1 n2)
    lay3 = Value (genArray (ShNil `ShCons` n2))

    -- An array of the given shape whose i-th element (in linear order) is i.
    genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)

    -- A (matrix of shape [fanOut, fanIn], vector of length fanOut) pair.
    genLayer fanIn fanOut =
      ( genArray (ShNil `ShCons` fanOut `ShCons` fanIn)
      , genArray (ShNil `ShCons` fanOut)
      )

-- | Deterministic inputs for the @gmm@ benchmark, in environment order:
-- k3, k2, k1, m, X, L, Q, M, alpha, K, D, N.
--
-- The contents are arbitrary placeholder data; presumably this order matches
-- the free variables of 'gmmObjective' (TODO confirm against "Example.GMM").
makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
makeGMMInputs =
  Value k3 `SCons` Value k2 `SCons` Value k1 `SCons` Value vm
    `SCons` vX `SCons` vL `SCons` vQ `SCons` vM `SCons` valpha
    `SCons` Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN)
    `SCons` SNil
  where
    -- Problem sizes (N, D and K in the shapes below).
    kN = 10
    kD = 10
    kK = 10

    i2i64 = fromIntegral @Int @Int64

    -- An array of the given shape holding 0, 1, 2, ... in linear order.
    genArray sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1]))

    valpha = genArray (ShNil `ShCons` kK)
    vM = genArray (ShNil `ShCons` kK `ShCons` kD)
    vQ = genArray (ShNil `ShCons` kK `ShCons` kD)
    -- One row of D*(D-1)/2 entries per component.
    vL = genArray (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
    vX = genArray (ShNil `ShCons` kN `ShCons` kD)

    vgamma = 0.42
    vm = 2

    k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
    k2 = 0.5 * vgamma * vgamma
    -- Placeholder constant standing in for the multigamma term.
    k3 = 0.42

-- | Run the CHAD gradient benchmarks for the neural-network and GMM examples,
-- each with an output cotangent of 1.0.
main :: IO ()
main = defaultMain [neuralBench, gmmBench]
  where
    neuralBench =
      env (pure makeNeuralInputs) $ \inputs ->
        bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0))
    gmmBench =
      env (pure makeGMMInputs) $ \inputs ->
        bench "gmm" (nf (\(inp, ctg) -> gradCHAD inp ctg (gmmObjective False)) (inputs, 1.0))