blob: c62b0f262de6f1885a5b02b76974de43aa93edd7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Main where
import Control.DeepSeq
import Criterion.Main
import Data.Coerce
import Data.Kind (Constraint)
import GHC.Exts (withDict)
import AST
import Array
import CHAD.Top
import CHAD.Types
import Data
import Example
import Interpreter
import Interpreter.Rep
import Simplify
-- | Gradient of a scalar-valued term via the CHAD transformation.
--
-- The cotangent @ctg@ is bound with 'ELet' as a 'STF64' constant in front of
-- the CHAD-transformed term. The combined program is then simplified with
-- 'simplifyFix' and interpreted against the values in @input@.
--
-- The result type pairs a 'Double' with the representation of
-- @'Tup' ('D2E' env)@. Presumably that is the primal result plus one
-- cotangent per environment entry; confirm against "CHAD.Top".
gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
gradCHAD input ctg term = interpretOpen False input simplified
  where
    -- The transformed term, with the cotangent bound in front of it.
    withCotangent = ELet ext (EConst ext STF64 ctg) (chad' knownEnv term)
    simplified = simplifyFix withCotangent
-- | Deep evaluation of an interpreter 'Value'.
--
-- The runtime representation ('Rep') of a 'Value' depends on its type index.
-- The singleton returned by 'knownTy' therefore drives a walk over the
-- representation that forces every component.
--
-- Accumulators cannot be forced; reaching one raises an error.
instance KnownTy t => NFData (Value t) where
  rnf (Value rep) = forceRep (knownTy @t) rep
    where
      -- Force a representation, dispatching on its singleton type first and
      -- only then inspecting the value.
      forceRep :: STy u -> Rep u -> ()
      forceRep ty r = case ty of
        STNil -> case r of
          () -> ()
        STPair ta tb -> case r of
          (a, b) -> forceRep ta a `seq` forceRep tb b
        STEither ta tb -> case r of
          Left a -> forceRep ta a
          Right b -> forceRep tb b
        STMaybe te -> maybe () (forceRep te) r
        STArr (_ :: SNat n) (te :: STy e) ->
          -- Zero-cost 'coerce' views the array elements as 'Value's, so this
          -- very instance can force them. 'withDict' supplies the 'KnownTy'
          -- evidence that the instance requires.
          withDict @(KnownTy e) te $
            rnf (coerce @(Array n (Rep e)) @(Array n (Value e)) r)
        STScal st -> case st of
          STI32 -> rnf r
          STI64 -> rnf r
          STF32 -> rnf r
          STF64 -> rnf r
          STBool -> rnf r
        STAccum{} -> error "Cannot rnf accumulators"
-- | Constraint asking for an 'NFData' instance on the 'Rep' of every type in
-- a type-level environment list. It expands to a tuple of constraints, one
-- per list element.
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
  AllNFDataRep '[] = ()
  AllNFDataRep (ty ': rest) = (NFData (Rep ty), AllNFDataRep rest)
-- | Deep evaluation of an environment of interpreter values.
--
-- The list is walked in lockstep with the singleton list from 'knownEnv'.
-- Each element's singleton provides the 'KnownTy' evidence (via 'withDict')
-- that the 'NFData' instance for 'Value' needs.
--
-- NOTE(review): the body does not appear to need the 'AllNFDataRep'
-- constraint, because every element is forced through the 'Value' instance.
-- Confirm whether it is meant as a compile-time guard before removing it.
instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
  rnf vals = forceAll knownEnv vals
    where
      forceAll :: SList STy es -> SList Value es -> ()
      forceAll tys vs = case tys of
        SNil -> case vs of
          SNil -> ()
        ((ty :: STy a) `SCons` tys') -> case vs of
          v `SCons` vs' ->
            withDict @(KnownTy a) ty (rnf v `seq` forceAll tys' vs')
-- | Deterministic inputs for the @neural@ benchmark, in environment order:
--
--   1. a vector of length 30;
--   2. a vector of length 50;
--   3. a (matrix, vector) pair built from sizes 50 and 50;
--   4. a (matrix, vector) pair built from sizes 30 and 50.
--
-- Every array is filled by 'arrayGenerateLin' with @fromIntegral i@ as a
-- 'Double'. Presumably @i@ is the element's linear index; confirm against
-- "Array".
makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
makeNeuralInputs =
  Value (iota (vecShape inputSize))
    `SCons` Value (iota (vecShape hidden2))
    `SCons` Value (layer hidden1 hidden2)
    `SCons` Value (layer inputSize hidden1)
    `SCons` SNil
  where
    inputSize = 30
    hidden1 = 50
    hidden2 = 50
    -- Rank-1 shape of the given length.
    vecShape len = ShNil `ShCons` len
    iota sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
    -- Weight matrix of shape (nOut, nIn) together with a length-nOut vector.
    layer nIn nOut = (iota (vecShape nOut `ShCons` nIn), iota (vecShape nOut))
-- | Criterion entry point with a single benchmark, "neural". It computes
-- 'gradCHAD' of 'neural' on the inputs from 'makeNeuralInputs' with a
-- cotangent of 1.0, and forces the result to normal form with 'nf'.
main :: IO ()
main = defaultMain [neuralBench]
  where
    -- 'env' builds and deep-forces the inputs before timing begins.
    neuralBench =
      env (pure makeNeuralInputs) $ \inputs ->
        bench "neural" $ nf runGrad (inputs, 1.0)
    runGrad (inp, ctg) = gradCHAD inp ctg neural
|