{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS -Wno-orphans #-}
-- | Criterion benchmark: computes the CHAD gradient of the 'neural' example
-- term on fixed, deterministic inputs.
module Main where

import Data.Coerce (coerce)
import GHC.Exts (withDict)

import Control.DeepSeq
import Criterion.Main

import AST
import Array
import CHAD.Top
import CHAD.Types
import Data
import Example
import Interpreter
import Interpreter.Rep
import Simplify


-- | Gradient of a scalar ('TF64') term via the CHAD transformation.
--
-- The transformed program (@chad' knownEnv term@) is wrapped in a @let@ that
-- binds the output cotangent @ctg@ as a constant. It is then simplified with
-- 'simplifyFix' and interpreted against the environment values @input@.
-- Judging by the result type, the pair presumably holds the primal result
-- and the cotangents of the environment (TODO confirm against "CHAD.Top").
gradCHAD :: KnownEnv env
         => SList Value env       -- ^ values for the term's free variables
         -> Double                -- ^ cotangent of the scalar output
         -> Ex env (TScal TF64)   -- ^ term to differentiate
         -> (Double, Rep (Tup (D2E env)))
gradCHAD input ctg term =
  -- NOTE(review): the meaning of the 'False' flag is defined by
  -- 'interpretOpen' in "Interpreter"; it is passed through unchanged here.
  interpretOpen False input $
    simplifyFix $
      ELet ext (EConst ext STF64 ctg) $
        chad' knownEnv term

-- | Deep evaluation of interpreter values, guided by the 'STy' singleton that
-- 'KnownTy' supplies for @t@. Accumulator values cannot be forced; reaching
-- one raises an error.
instance KnownTy t => NFData (Value t) where
  rnf (Value x) = go (knownTy @t) x
    where
      go :: STy t' -> Rep t' -> ()
      go STNil () = ()
      go (STPair ta tb) (a, b) = go ta a `seq` go tb b
      go (STEither ta _) (Left a) = go ta a
      go (STEither _ tb) (Right b) = go tb b
      go (STMaybe _) Nothing = ()
      go (STMaybe ta) (Just a) = go ta a
      go (STArr (_ :: SNat n) (telt :: STy t2)) arr =
        -- Reinterpret the array's elements as 'Value t2' (via 'coerce') to
        -- reuse this very instance element-wise. 'withDict' turns the element
        -- singleton into the 'KnownTy t2' dictionary that instance requires.
        withDict @(KnownTy t2) telt $
          rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
      go (STScal st) v = case st of
        -- Matching on the scalar singleton refines 'Rep', so each branch
        -- picks the 'NFData' instance of the concrete scalar type.
        STI32 -> rnf v
        STI64 -> rnf v
        STF32 -> rnf v
        STF64 -> rnf v
        STBool -> rnf v
      go STAccum{} _ = error "Cannot rnf accumulators"

-- | Deep evaluation of a whole environment of values.
--
-- Each element is forced through the @'NFData' ('Value' t)@ instance. That
-- instance needs only @'KnownTy' t@, which 'withDict' rebuilds from the
-- environment's singleton list. No per-element @NFData (Rep t)@ constraint
-- is required.
instance KnownEnv env => NFData (SList Value env) where
  rnf = go knownEnv
    where
      go :: SList STy env' -> SList Value env' -> ()
      go SNil SNil = ()
      go ((ty :: STy t) `SCons` tys) (v `SCons` vs) =
        withDict @(KnownTy t) ty (rnf v) `seq` go tys vs

-- | Deterministic inputs for the 'neural' benchmark, in environment order:
-- an input vector (length @nin@), a final vector (length @n2@), then two
-- (matrix, vector) layers built by @genLayer@.
--
-- Every array element is @fromIntegral@ of the index that 'arrayGenerateLin'
-- passes in (presumably the linear index; confirm in "Array").
makeNeuralInputs
  :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  let genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
      -- A layer is a matrix with shape built as @nout@ then @nin@, paired
      -- with a vector of length @nout@.
      genLayer nin nout =
        ( genArray (ShNil `ShCons` nout `ShCons` nin)
        , genArray (ShNil `ShCons` nout) )
  in let nin = 30
         n1 = 50
         n2 = 50
         input = Value (genArray (ShNil `ShCons` nin))
         lay1 = Value (genLayer nin n1)
         lay2 = Value (genLayer n1 n2)
         lay3 = Value (genArray (ShNil `ShCons` n2))
     in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil

-- | Benchmark entry point. 'env' builds and fully evaluates the inputs once,
-- before any timing starts. The timed function runs the whole pipeline —
-- CHAD transformation, simplification and interpretation — on each call.
main :: IO ()
main = defaultMain
  [ env (return makeNeuralInputs) $ \inputs ->
      bench "neural" $
        nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0)
  ]