1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where
import Control.DeepSeq
import Criterion.Main
import Data.Coerce
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
import AST
import Array
import qualified CHAD (defaultConfig)
import CHAD.Top
import CHAD.Types
import Data
import Example
import Example.GMM
import Example.Types
import Interpreter
import Interpreter.Rep
import Simplify
-- | Run CHAD on a scalar-valued term and evaluate the resulting program.
--
-- The reverse-mode transform of @term@ is placed under a @let@ that binds the
-- constant cotangent @ctg@. The whole program is simplified with
-- 'simplifyFix' and then interpreted in the environment @input@.
-- The result is the interpreter's output pair; presumably the primal value
-- and the gradient with respect to every environment entry (TODO confirm).
gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
gradCHAD input ctg term = interpretOpen False input (simplifyFix program)
  where
    cotangent = EConst ext STF64 ctg
    program = ELet ext cotangent (chad' CHAD.defaultConfig knownEnv term)
-- | Deep evaluation of interpreter values.
--
-- The value's type is recovered as a singleton via 'knownTy', and the
-- underlying 'Rep' is walked structurally under the guidance of that
-- singleton. Accumulators cannot be forced and raise an error.
instance KnownTy t => NFData (Value t) where
  rnf (Value rep) = forceRep (knownTy @t) rep
    where
      -- Force a representation to normal form, steered by its type singleton.
      forceRep :: STy ty -> Rep ty -> ()
      forceRep STNil () = ()
      forceRep (STPair ta tb) (a, b) = forceRep ta a `seq` forceRep tb b
      forceRep (STEither ta tb) sumVal = either (forceRep ta) (forceRep tb) sumVal
      forceRep (STMaybe ta) optVal = maybe () (forceRep ta) optVal
      forceRep (STArr (_ :: SNat n) (elemTy :: STy el)) arr =
        -- View the array of raw 'Rep el' as an array of 'Value el', so that the
        -- array's own NFData instance can be reused. 'withDict' supplies the
        -- 'KnownTy el' evidence from the element singleton.
        withDict @(KnownTy el) elemTy $
          rnf (coerce @(Array n (Rep el)) @(Array n (Value el)) arr)
      forceRep (STScal STI32) x = rnf x
      forceRep (STScal STI64) x = rnf x
      forceRep (STScal STF32) x = rnf x
      forceRep (STScal STF64) x = rnf x
      forceRep (STScal STBool) x = rnf x
      forceRep STAccum{} _ = error "Cannot rnf accumulators"
-- | @AllNFDataRep tys@ is the conjunction of @NFData (Rep ty)@ over every
-- type @ty@ in the type-level list @tys@. The empty list gives the trivial
-- constraint.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep tys where
  AllNFDataRep '[] = ()
  AllNFDataRep (ty ': rest) = (NFData (Rep ty), AllNFDataRep rest)
-- | Deep evaluation of an entire environment of values, element by element.
--
-- Each element's type singleton (from 'knownEnv') is turned into a 'KnownTy'
-- dictionary with 'withDict'. The element is then forced through its
-- @NFData (Value t)@ instance.
--
-- NOTE(review): the body never consults the @AllNFDataRep env@ context.
-- Elements are forced through @NFData (Value t)@, which needs only 'KnownTy'.
-- Confirm whether the context is kept on purpose, e.g. as a static guard.
instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
  rnf vals = forceAll knownEnv vals
    where
      forceAll :: SList STy env' -> SList Value env' -> ()
      forceAll SNil SNil = ()
      forceAll (SCons (ty :: STy t) tys) (SCons v vs) =
        withDict @(KnownTy t) ty (rnf v) `seq` forceAll tys vs
-- | Deterministic inputs for the @neural@ benchmark.
--
-- Environment order (head first):
--
-- * a vector of length 30;
-- * a vector of length 50;
-- * two (matrix, vector) pairs, built with dims (50, 50) / 50 and (50, 30) / 50.
--
-- Every element is @fromIntegral@ of the index that 'arrayGenerateLin' passes
-- in (presumably the linear index), as a 'Double'.
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  inputVec `SCons` lastLayer `SCons` secondLayer `SCons` firstLayer `SCons` SNil
  where
    nIn = 30
    nHidden1 = 50
    nHidden2 = 50

    inputVec = Value (indexFilled (ShNil `ShCons` nIn))
    firstLayer = Value (layerOf nIn nHidden1)
    secondLayer = Value (layerOf nHidden1 nHidden2)
    lastLayer = Value (indexFilled (ShNil `ShCons` nHidden2))

    -- A (matrix, vector) pair mapping 'fanIn' values to 'fanOut' values.
    layerOf fanIn fanOut =
      ( indexFilled (ShNil `ShCons` fanOut `ShCons` fanIn)
      , indexFilled (ShNil `ShCons` fanOut)
      )

    -- An array of the given shape whose elements are their own indices.
    indexFilled sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
-- | Deterministic inputs for the @gmm@ benchmark.
--
-- Sizes are N = D = K = 10. All values are arbitrary; they are not meant to
-- form a meaningful model. Array elements are 0, 1, 2, ... in the order
-- 'arrayFromList' consumes them.
--
-- Environment order (head first):
--
-- * k3, k2, k1;
-- * m;
-- * X, with dims (N, D);
-- * L, with dims (K, D*(D-1)/2);
-- * Q, with dims (K, D);
-- * M, with dims (K, D);
-- * alpha, with dim K;
-- * K, D and N as 'Int64'.
makeGMMInputs :: SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64]
makeGMMInputs =
  Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
  Value mVal `SCons` arrX `SCons`
  arrL `SCons` arrQ `SCons` arrM `SCons` arrAlpha `SCons`
  Value (toI64 nK) `SCons` Value (toI64 nD) `SCons` Value (toI64 nN) `SCons`
  SNil
  where
    -- Problem sizes (arbitrary).
    nN = 10
    nD = 10
    nK = 10
    toI64 = fromIntegral @Int @Int64

    -- A value wrapping an array of the given shape filled with 0, 1, 2, ...
    countingArray sh = Value (arrayFromList sh (map fromIntegral [0 .. shapeSize sh - 1]))

    arrAlpha = countingArray (ShNil `ShCons` nK)
    arrM = countingArray (ShNil `ShCons` nK `ShCons` nD)
    arrQ = countingArray (ShNil `ShCons` nK `ShCons` nD)
    arrL = countingArray (ShNil `ShCons` nK `ShCons` (nD * (nD - 1) `div` 2))
    arrX = countingArray (ShNil `ShCons` nN `ShCons` nD)

    gammaVal = 0.42
    mVal = 2
    k1 = 0.5 * fromIntegral (nN * nD) * log (2 * pi)
    k2 = 0.5 * gammaVal * gammaVal
    -- Placeholder constant: no multigamma-based value is computed here.
    k3 = 0.42
-- | Criterion entry point.
--
-- Benchmarks 'gradCHAD' on the neural-network and GMM example terms, each
-- with a cotangent of 1.0. Inputs are built once per benchmark via 'env'.
main :: IO ()
main = defaultMain [neuralBench, gmmBench]
  where
    neuralBench =
      env (pure makeNeuralInputs) $ \inputs ->
        bench "neural" $
          nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0)

    gmmBench =
      env (pure makeGMMInputs) $ \inputs ->
        bench "gmm" $
          nf (\(inp, ctg) -> gradCHAD inp ctg (gmmObjective False)) (inputs, 1.0)
|