-- path: root/bench/Main.hs
-- blob: c62b0f262de6f1885a5b02b76974de43aa93edd7 (plain)
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}

{-# OPTIONS -Wno-orphans #-}
module Main where

import Control.DeepSeq
import Criterion.Main
import Data.Coerce
import Data.Kind (Constraint)
import GHC.Exts (withDict)

import AST
import Array
import CHAD.Top
import CHAD.Types
import Data
import Example
import Interpreter
import Interpreter.Rep
import Simplify


-- | Differentiate @term@ with 'chad'', bind the constant @ctg@ around the
-- result with 'ELet', simplify with 'simplifyFix', and evaluate the outcome
-- with 'interpretOpen' in the environment @input@.
--
-- The result is a 'Double' paired with a @Rep (Tup (D2E env))@; presumably
-- these are the primal value and the gradient with respect to the
-- environment, and @ctg@ is the incoming cotangent -- confirm against
-- "CHAD.Top".
--
-- NOTE(review): the meaning of the 'False' flag passed to 'interpretOpen' is
-- not visible from this module.
gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
gradCHAD input ctg term = interpretOpen False input simplified
  where
    -- The CHAD-transformed term; it is placed under the let-binding below.
    differentiated = chad' knownEnv term
    -- Let-bind the constant @ctg@ around the transformed term.
    withCtg = ELet ext (EConst ext STF64 ctg) differentiated
    simplified = simplifyFix withCtg

-- | Orphan instance: reduce a 'Value' to normal form by walking its
-- representation alongside the singleton 'STy' that describes its type.
--
-- Accumulator values are not supported and raise an 'error'.
instance KnownTy t => NFData (Value t) where
  rnf (Value rep) = deep (knownTy @t) rep
    where
      -- Structural traversal; each singleton constructor reveals the shape of
      -- the corresponding 'Rep'.
      deep :: STy t' -> Rep t' -> ()
      deep STNil () = ()
      deep (STPair tyL tyR) (l, r) = deep tyL l `seq` deep tyR r
      deep (STEither tyL tyR) val = case val of
        Left l -> deep tyL l
        Right r -> deep tyR r
      deep (STMaybe tyElt) val = maybe () (deep tyElt) val
      -- View the elements as 'Value's so that this very instance (plus the
      -- array's own 'NFData' instance) forces them; 'withDict' supplies the
      -- 'KnownTy' dictionary from the element singleton.
      deep (STArr (_ :: SNat n) (tyElt :: STy elt)) arr =
        withDict @(KnownTy elt) tyElt
          (rnf (coerce @(Array n (Rep elt)) @(Array n (Value elt)) arr))
      -- Matching on the scalar singleton fixes the concrete representation
      -- type, which brings its 'NFData' instance into play.
      deep (STScal STI32) x = rnf x
      deep (STScal STI64) x = rnf x
      deep (STScal STF32) x = rnf x
      deep (STScal STF64) x = rnf x
      deep (STScal STBool) x = rnf x
      deep STAccum{} _ = error "Cannot rnf accumulators"

-- | Constraint demanding an 'NFData' instance for @Rep t@ of every type @t@
-- in the given type-level list; the empty list yields the trivial constraint.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep tys where
  AllNFDataRep '[] = ()
  AllNFDataRep (ty ': rest) = (NFData (Rep ty), AllNFDataRep rest)

-- | Orphan instance: reduce every 'Value' in an environment to normal form,
-- front to back.
--
-- The traversal pairs the list with its singleton environment ('knownEnv')
-- and uses 'withDict' to obtain the 'KnownTy' dictionary needed to 'rnf' each
-- element.
--
-- NOTE(review): the @AllNFDataRep env@ context does not appear to be used by
-- the body below; it is kept so the instance head is unchanged.
instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
  rnf vals = forceAll knownEnv vals
    where
      forceAll :: SList STy env' -> SList Value env' -> ()
      forceAll SNil SNil = ()
      forceAll (SCons (ty :: STy t) tys) (SCons val rest) =
        withDict @(KnownTy t) ty (rnf val `seq` forceAll tys rest)

-- | Deterministic input environment for the neural benchmark.
--
-- From the head of the list: an input vector of length 30, a final vector of
-- length 50, and two (matrix, vector) layers of shape (50 x 50, 50) and
-- (50 x 30, 50). Every array is produced by 'arrayGenerateLin' with its
-- generator argument converted to 'Double' (presumably the linear index --
-- confirm against "Array").
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  inputVec `SCons` finalVec `SCons` layerB `SCons` layerA `SCons` SNil
  where
    -- Sizes: input width and the widths of the two layers.
    widthIn = 30
    widthA = 50
    widthB = 50

    -- Array of the given shape filled via the generator's argument.
    linearArray sh = arrayGenerateLin sh (\k -> fromIntegral k :: Double)

    -- A (matrix, vector) pair: matrix of shape fanOut x fanIn, vector of
    -- length fanOut.
    mkLayer fanIn fanOut =
      ( linearArray (ShNil `ShCons` fanOut `ShCons` fanIn)
      , linearArray (ShNil `ShCons` fanOut)
      )

    inputVec = Value (linearArray (ShNil `ShCons` widthIn))
    layerA = Value (mkLayer widthIn widthA)
    layerB = Value (mkLayer widthA widthB)
    finalVec = Value (linearArray (ShNil `ShCons` widthB))

-- | Criterion entry point. Runs a single benchmark, @"neural"@, which
-- evaluates 'gradCHAD' on the 'neural' example to normal form ('nf'), using
-- the environment from 'makeNeuralInputs' and @1.0@ as the @ctg@ argument.
main :: IO ()
main = defaultMain [neuralBench]
  where
    -- The inputs go through criterion's 'env' rather than being referenced
    -- directly inside the benchmarked function.
    neuralBench =
      env (pure makeNeuralInputs) $ \inputs ->
        bench "neural" (nf runGrad (inputs, 1.0))

    -- The function under measurement, applied by 'nf' to the argument pair.
    runGrad (inp, ctg) = gradCHAD inp ctg neural