summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-07 23:11:36 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-07 23:11:36 +0100
commit92ddb2263ae495c229badcc209c76a1252bd2752 (patch)
treed69059d755a04121db23406050a643bf33c5b764
parent401e74939fe2a717852acc4b7a452b222d82274a (diff)
Benchmark
-rw-r--r--bench/Main.hs89
-rw-r--r--chad-fast.cabal13
-rw-r--r--src/Array.hs14
-rw-r--r--src/CHAD.hs12
-rw-r--r--src/CHAD/Top.hs53
-rw-r--r--src/Example.hs16
-rw-r--r--test/Main.hs93
7 files changed, 184 insertions, 106 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
new file mode 100644
index 0000000..c62b0f2
--- /dev/null
+++ b/bench/Main.hs
@@ -0,0 +1,89 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE FlexibleInstances #-}
+
+{-# OPTIONS -Wno-orphans #-}
+module Main where
+
+import Control.DeepSeq
+import Criterion.Main
+import Data.Coerce
+import Data.Kind (Constraint)
+import GHC.Exts (withDict)
+
+import AST
+import Array
+import CHAD.Top
+import CHAD.Types
+import Data
+import Example
+import Interpreter
+import Interpreter.Rep
+import Simplify
+
+
+gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env)))
+gradCHAD input ctg term =
+ interpretOpen False input $
+ simplifyFix $
+ ELet ext (EConst ext STF64 ctg) $ chad' knownEnv term
+
+instance KnownTy t => NFData (Value t) where
+ rnf = \(Value x) -> go (knownTy @t) x
+ where
+ go :: STy t' -> Rep t' -> ()
+ go STNil () = ()
+ go (STPair a b) (x, y) = go a x `seq` go b y
+ go (STEither a _) (Left x) = go a x
+ go (STEither _ b) (Right y) = go b y
+ go (STMaybe _) Nothing = ()
+ go (STMaybe t) (Just x) = go t x
+ go (STArr (_ :: SNat n) (t :: STy t2)) arr =
+ withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
+ go (STScal t) x = case t of
+ STI32 -> rnf x
+ STI64 -> rnf x
+ STF32 -> rnf x
+ STF64 -> rnf x
+ STBool -> rnf x
+ go STAccum{} _ = error "Cannot rnf accumulators"
+
+type AllNFDataRep :: [Ty] -> Constraint
+type family AllNFDataRep env where
+ AllNFDataRep '[] = ()
+ AllNFDataRep (t : env) = (NFData (Rep t), AllNFDataRep env)
+
+instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where
+ rnf = go knownEnv
+ where
+ go :: SList STy env' -> SList Value env' -> ()
+ go SNil SNil = ()
+ go ((t :: STy t) `SCons` ts) (v `SCons` vs) =
+ withDict @(KnownTy t) t $ rnf v `seq` go ts vs
+
+makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)]
+makeNeuralInputs =
+ let genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double)
+ genLayer nin nout =
+ (genArray (ShNil `ShCons` nout `ShCons` nin)
+ ,genArray (ShNil `ShCons` nout))
+ in let
+ nin = 30
+ n1 = 50
+ n2 = 50
+ input = Value (genArray (ShNil `ShCons` nin))
+ lay1 = Value (genLayer nin n1)
+ lay2 = Value (genLayer n1 n2)
+ lay3 = Value (genArray (ShNil `ShCons` n2))
+ in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil
+
+main :: IO ()
+main = defaultMain
+ [env (return makeNeuralInputs) $ \inputs ->
+ bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0))
+ ]
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ae8ddf4..8ff3a21 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -20,6 +20,7 @@ library
AST.Weaken
AST.Weaken.Auto
CHAD
+ CHAD.Top
CHAD.Types
-- Compile
Data
@@ -39,6 +40,7 @@ library
build-depends:
base >= 4.19 && < 4.21,
containers,
+ deepseq,
-- template-haskell,
process,
transformers,
@@ -64,3 +66,14 @@ test-suite test
hedgehog,
default-language: Haskell2010
ghc-options: -Wall -threaded
+
+benchmark bench
+ type: exitcode-stdio-1.0
+ main-is: bench/Main.hs
+ build-depends:
+ chad-fast,
+ base,
+ criterion,
+ deepseq,
+ default-language: Haskell2010
+ ghc-options: -Wall -threaded
diff --git a/src/Array.hs b/src/Array.hs
index 8507544..ef9bb8d 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
@@ -6,10 +7,12 @@
{-# LANGUAGE TupleSections #-}
module Array where
+import Control.DeepSeq
import Control.Monad.Trans.State.Strict
import Data.Foldable (traverse_)
import Data.Vector (Vector)
import qualified Data.Vector as V
+import GHC.Generics (Generic)
import Data
@@ -20,12 +23,20 @@ data Shape n where
deriving instance Show (Shape n)
deriving instance Eq (Shape n)
+instance NFData (Shape n) where
+ rnf ShNil = ()
+ rnf (sh `ShCons` n) = rnf n `seq` rnf sh
+
data Index n where
IxNil :: Index Z
IxCons :: Index n -> Int -> Index (S n)
deriving instance Show (Index n)
deriving instance Eq (Index n)
+instance NFData (Index n) where
+ rnf IxNil = ()
+ rnf (sh `IxCons` n) = rnf n `seq` rnf sh
+
shapeSize :: Shape n -> Int
shapeSize ShNil = 1
shapeSize (ShCons sh n) = shapeSize sh * n
@@ -51,7 +62,8 @@ enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1]
-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
- deriving (Show, Functor, Foldable, Traversable)
+ deriving (Show, Functor, Foldable, Traversable, Generic)
+instance NFData t => NFData (Array n t)
arrayShape :: Array n t -> Shape n
arrayShape (Array sh _) = sh
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 6b0627d..ffbdcac 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -689,22 +689,20 @@ retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ ta
freezeRet :: Descr env sto
-> Ret env sto t
- -> Ex (D1E env) (D2 t) -- the incoming cotangent value
- -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
-freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) d =
- let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0
+ -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
+freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
+ let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
- (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $
- ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #tape (subList (bindingsBinds e0) subtape)
&. #shbinds (bindingsBinds e0)
&. #d2ace (d2ace (select SAccum descr))
&. #tl (sD1eEnv descr))
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
- (#d :++: #shbinds :++: #d2ace :++: #tl))
+ (#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
new file mode 100644
index 0000000..9df5412
--- /dev/null
+++ b/src/CHAD/Top.hs
@@ -0,0 +1,53 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Top where
+
+import AST
+import CHAD
+import CHAD.Types
+import Data
+
+
+type family MergeEnv env where
+ MergeEnv '[] = '[]
+ MergeEnv (t : ts) = "merge" : MergeEnv ts
+
+mergeDescr :: SList STy env -> Descr env (MergeEnv env)
+mergeDescr SNil = DTop
+mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, SMerge)
+
+mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[]
+mergeEnvNoAccum SNil = Refl
+mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl
+
+mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env
+mergeEnvOnlyMerge SNil = Refl
+mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
+
+chad :: SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
+chad env term
+ | Refl <- mergeEnvNoAccum env
+ , Refl <- mergeEnvOnlyMerge env
+ = freezeRet (mergeDescr env) (drev (mergeDescr env) term)
+
+chad' :: SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))
+chad' env term
+ | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term)
+ = chad env term
diff --git a/src/Example.hs b/src/Example.hs
index d0405af..1775bb9 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -11,6 +11,7 @@ import Array
import AST
import AST.Pretty
import CHAD
+import CHAD.Top
import Data
import ForwardAD
import Interpreter
@@ -23,16 +24,6 @@ import Example.Format
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
-type family MergeEnv env where
- MergeEnv '[] = '[]
- MergeEnv (t : ts) = "merge" : MergeEnv ts
-
-mergeDescr :: KnownEnv env => Descr env (MergeEnv env)
-mergeDescr = go knownEnv
- where go :: SList STy env -> Descr env (MergeEnv env)
- go SNil = DTop
- go (t `SCons` env) = go env `DPush` (t, SMerge)
-
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
bin op a b = EOp ext op (EPair ext a b)
@@ -195,9 +186,8 @@ neuralGo =
argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
revderiv =
simplifyN 20 $
- freezeRet mergeDescr
- (drev mergeDescr neural)
- (EConst ext STF64 1.0)
+ ELet ext (EConst ext STF64 1.0) $
+ chad knownEnv neural
(primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen False argument revderiv
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
in trace (formatter (ppExpr knownEnv revderiv)) $
diff --git a/test/Main.hs b/test/Main.hs
index e7dda69..d3e55b3 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,19 +1,15 @@
{-# LANGUAGE DataKinds #-}
--- {-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
--- {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where
import Data.Bifunctor
--- import qualified Data.Dependent.Map as DMap
--- import Data.Dependent.Map (DMap)
import Data.List (intercalate)
import Hedgehog
import qualified Hedgehog.Gen as Gen
@@ -23,7 +19,7 @@ import Hedgehog.Main
import Array
import AST
import AST.Pretty
-import CHAD
+import CHAD.Top
import CHAD.Types
import Data
import qualified Example
@@ -34,63 +30,19 @@ import Language
import Simplify
-type family MapMerge env where
- MapMerge '[] = '[]
- MapMerge (t : ts) = "merge" : MapMerge ts
-
-mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
-mapMergeNoAccum SNil = Refl
-mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl
-
-mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
-mapMergeOnlyMerge SNil = Refl
-mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl
-
-primalEnv :: SList STy env' -> SList STy (D1E env')
-primalEnv SNil = SNil
-primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env
-
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
-diffCHAD :: SimplIters -> SList STy env -> Ex env (TScal TF64)
- -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env)))
-diffCHAD = \simplIters env term ->
- case (mapMergeNoAccum env, mapMergeOnlyMerge env, envKnown (primalEnv env)) of
- (Refl, Refl, Dict) ->
- let descr = makeMergeDescr env
- simpl = case simplIters of
- SimplIters n -> simplifyN n
- SimplFix -> simplifyFix
- in simpl $ freezeRet descr (drev descr term) (EConst ext STF64 1.0)
- where
- makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
- makeMergeDescr SNil = DTop
- makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)
-
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD = \simplIters env term input ->
- case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
- (Refl, Refl) ->
- let dterm = diffCHAD simplIters env term
- input1 = toPrimalE env input
- (out, grad) = interpretOpen False input1 dterm
- in (ppExpr (primalEnv env) dterm, (out, unTup vUnpair (d2e env) (Value grad)))
- where
- toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
- toPrimalE SNil SNil = SNil
- toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp
-
- toPrimal :: STy t -> Rep t -> Rep (D1 t)
- toPrimal = \case
- STNil -> id
- STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
- STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
- STMaybe t -> fmap (toPrimal t)
- STArr _ t -> fmap (toPrimal t)
- STScal _ -> id
- STAccum{} -> error "Accumulators not allowed in input program"
+ let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' env term
+ dterm | Dict <- envKnown env
+ = case simplIters of
+ SimplIters n -> simplifyN n dtermNonSimpl
+ SimplFix -> simplifyFix dtermNonSimpl
+ (out, grad) = interpretOpen False input dterm
+ in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
@@ -172,29 +124,6 @@ genEnv :: SList STy env -> Gen (SList Value env)
genEnv SNil = return SNil
genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
--- data TemplateVar n = TemplateVar (SNat n) String
--- deriving (Show)
-
--- data Template t where
--- TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
--- TpAny :: STy t -> Template t
--- TpPair :: Template a -> Template b -> Template (TPair a b)
--- deriving instance Show (Template t)
-
--- data ShapeConstraint n = ShapeAtLeast (Shape n)
--- deriving (Show)
-
--- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
--- genTemplate = _
-
--- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
--- genEnvTemplateExact shapes env = _
-
--- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
--- genEnvTemplate constrs env = do
--- shapes <- DMap.traverseWithKey _ constrs
--- genEnvTemplateExact shapes env
-
showEnv :: SList STy env -> SList Value env -> String
showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
where
@@ -205,11 +134,6 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
adTest = flip adTestGen (genEnv (knownEnv @env))
--- adTestTp :: forall env. KnownEnv env
--- => DMap TemplateVar ShapeConstraint -> SList Template env
--- -> Ex env (TScal TF64) -> Property
--- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp)
-
adTestGen :: forall env. KnownEnv env
=> Ex env (TScal TF64) -> Gen (SList Value env) -> Property
adTestGen expr envGenerator = property $ do
@@ -268,7 +192,6 @@ tests = checkSequential $ Group "AD"
idx0 $
build SZ (shape #x) $ #idx :-> #x ! #idx)
- -- :hindentstr ppExpr knownEnv $ diffCHAD 20 knownEnv term_build1_sum
,("build1-sum", adTest term_build1_sum)
,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $