diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-07 23:11:36 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-07 23:11:36 +0100 |
commit | 92ddb2263ae495c229badcc209c76a1252bd2752 (patch) | |
tree | d69059d755a04121db23406050a643bf33c5b764 | |
parent | 401e74939fe2a717852acc4b7a452b222d82274a (diff) |
Benchmark
-rw-r--r-- | bench/Main.hs | 89 | ||||
-rw-r--r-- | chad-fast.cabal | 13 | ||||
-rw-r--r-- | src/Array.hs | 14 | ||||
-rw-r--r-- | src/CHAD.hs | 12 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 53 | ||||
-rw-r--r-- | src/Example.hs | 16 | ||||
-rw-r--r-- | test/Main.hs | 93 |
7 files changed, 184 insertions, 106 deletions
diff --git a/bench/Main.hs b/bench/Main.hs new file mode 100644 index 0000000..c62b0f2 --- /dev/null +++ b/bench/Main.hs @@ -0,0 +1,89 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE FlexibleInstances #-} + +{-# OPTIONS -Wno-orphans #-} +module Main where + +import Control.DeepSeq +import Criterion.Main +import Data.Coerce +import Data.Kind (Constraint) +import GHC.Exts (withDict) + +import AST +import Array +import CHAD.Top +import CHAD.Types +import Data +import Example +import Interpreter +import Interpreter.Rep +import Simplify + + +gradCHAD :: KnownEnv env => SList Value env -> Double -> Ex env (TScal TF64) -> (Double, Rep (Tup (D2E env))) +gradCHAD input ctg term = + interpretOpen False input $ + simplifyFix $ + ELet ext (EConst ext STF64 ctg) $ chad' knownEnv term + +instance KnownTy t => NFData (Value t) where + rnf = \(Value x) -> go (knownTy @t) x + where + go :: STy t' -> Rep t' -> () + go STNil () = () + go (STPair a b) (x, y) = go a x `seq` go b y + go (STEither a _) (Left x) = go a x + go (STEither _ b) (Right y) = go b y + go (STMaybe _) Nothing = () + go (STMaybe t) (Just x) = go t x + go (STArr (_ :: SNat n) (t :: STy t2)) arr = + withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) + go (STScal t) x = case t of + STI32 -> rnf x + STI64 -> rnf x + STF32 -> rnf x + STF64 -> rnf x + STBool -> rnf x + go STAccum{} _ = error "Cannot rnf accumulators" + +type AllNFDataRep :: [Ty] -> Constraint +type family AllNFDataRep env where + AllNFDataRep '[] = () + AllNFDataRep (t : env) = (NFData (Rep t), AllNFDataRep env) + +instance (KnownEnv env, AllNFDataRep env) => NFData (SList Value env) where + rnf = go knownEnv + where + go :: SList STy env' -> SList Value env' -> () + go SNil SNil = () + go ((t :: STy t) `SCons` ts) (v `SCons` vs) = + withDict @(KnownTy t) t $ rnf v `seq` go ts vs + +makeNeuralInputs :: SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] +makeNeuralInputs = + let genArray sh = arrayGenerateLin sh (\i -> fromIntegral i :: Double) + genLayer nin nout = + (genArray (ShNil `ShCons` nout `ShCons` nin) + ,genArray (ShNil `ShCons` nout)) + in let + nin = 30 + n1 = 50 + n2 = 50 + input = Value (genArray (ShNil `ShCons` nin)) + lay1 = Value (genLayer nin n1) + lay2 = Value (genLayer n1 n2) + lay3 = Value (genArray (ShNil `ShCons` n2)) + in input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil + +main :: IO () +main = defaultMain + [env (return makeNeuralInputs) $ \inputs -> + bench "neural" (nf (\(inp, ctg) -> gradCHAD inp ctg neural) (inputs, 1.0)) + ] diff --git a/chad-fast.cabal b/chad-fast.cabal index ae8ddf4..8ff3a21 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -20,6 +20,7 @@ library AST.Weaken AST.Weaken.Auto CHAD + CHAD.Top CHAD.Types -- Compile Data @@ -39,6 +40,7 @@ library build-depends: base >= 4.19 && < 4.21, containers, + deepseq, -- template-haskell, process, transformers, @@ -64,3 +66,14 @@ test-suite test hedgehog, default-language: Haskell2010 ghc-options: -Wall -threaded + +benchmark bench + type: exitcode-stdio-1.0 + main-is: bench/Main.hs + build-depends: + chad-fast, + base, + criterion, + deepseq, + default-language: Haskell2010 + ghc-options: -Wall -threaded diff --git a/src/Array.hs b/src/Array.hs index 8507544..ef9bb8d 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} @@ -6,10 +7,12 @@ {-# LANGUAGE TupleSections #-} module Array where +import Control.DeepSeq import Control.Monad.Trans.State.Strict import Data.Foldable (traverse_) import Data.Vector (Vector) import qualified Data.Vector as V +import GHC.Generics (Generic) import Data @@ -20,12 +23,20 @@ data Shape n where deriving instance Show (Shape n) deriving instance Eq (Shape n) +instance NFData (Shape n) where + rnf ShNil = () + rnf (sh `ShCons` n) = rnf n `seq` rnf sh + data Index n where IxNil :: Index Z IxCons :: Index n -> Int -> Index (S n) deriving instance Show (Index n) deriving instance Eq (Index n) +instance NFData (Index n) where + rnf IxNil = () + rnf (sh `IxCons` n) = rnf n `seq` rnf sh + shapeSize :: Shape n -> Int shapeSize ShNil = 1 shapeSize (ShCons sh n) = shapeSize sh * n @@ -51,7 +62,8 @@ enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] -- | TODO: this Vector is a boxed vector, which is horrendously inefficient. data Array (n :: Nat) t = Array (Shape n) (Vector t) - deriving (Show, Functor, Foldable, Traversable) + deriving (Show, Functor, Foldable, Traversable, Generic) +instance NFData t => NFData (Array n t) arrayShape :: Array n t -> Shape n arrayShape (Array sh _) = sh diff --git a/src/CHAD.hs b/src/CHAD.hs index 6b0627d..ffbdcac 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -689,22 +689,20 @@ retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ ta freezeRet :: Descr env sto -> Ret env sto t - -> Ex (D1E env) (D2 t) -- the incoming cotangent value - -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) d = - let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0 + -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = + let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $ - ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) + (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #tape (subList (bindingsBinds e0) subtape) &. #shbinds (bindingsBinds e0) &. #d2ace (d2ace (select SAccum descr)) &. #tl (sD1eEnv descr)) (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) - (#d :++: #shbinds :++: #d2ace :++: #tl)) + (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs new file mode 100644 index 0000000..9df5412 --- /dev/null +++ b/src/CHAD/Top.hs @@ -0,0 +1,53 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Top where + +import AST +import CHAD +import CHAD.Types +import Data + + +type family MergeEnv env where + MergeEnv '[] = '[] + MergeEnv (t : ts) = "merge" : MergeEnv ts + +mergeDescr :: SList STy env -> Descr env (MergeEnv env) +mergeDescr SNil = DTop +mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, SMerge) + +mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[] +mergeEnvNoAccum SNil = Refl +mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl + +mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env +mergeEnvOnlyMerge SNil = Refl +mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl + +chad :: SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) +chad env term + | Refl <- mergeEnvNoAccum env + , Refl <- mergeEnvOnlyMerge env + = freezeRet (mergeDescr env) (drev (mergeDescr env) term) + +chad' :: SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +chad' env term + | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term) + = chad env term diff --git a/src/Example.hs b/src/Example.hs index d0405af..1775bb9 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -11,6 +11,7 @@ import Array import AST import AST.Pretty import CHAD +import CHAD.Top import Data import ForwardAD import Interpreter @@ -23,16 +24,6 @@ import Example.Format -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) -type family MergeEnv env where - MergeEnv '[] = '[] - MergeEnv (t : ts) = "merge" : MergeEnv ts - -mergeDescr :: KnownEnv env => Descr env (MergeEnv env) -mergeDescr = go knownEnv - where go :: SList STy env -> Descr env (MergeEnv env) - go SNil = DTop - go (t `SCons` env) = go env `DPush` (t, SMerge) - bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c bin op a b = EOp ext op (EPair ext a b) @@ -195,9 +186,8 @@ neuralGo = argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) revderiv = simplifyN 20 $ - freezeRet mergeDescr - (drev mergeDescr neural) - (EConst ext STF64 1.0) + ELet ext (EConst ext STF64 1.0) $ + chad knownEnv neural (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen False argument revderiv (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 in trace (formatter (ppExpr knownEnv revderiv)) $ diff --git a/test/Main.hs b/test/Main.hs index e7dda69..d3e55b3 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,19 +1,15 @@ {-# LANGUAGE DataKinds #-} --- {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} --- {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Main where import Data.Bifunctor --- import qualified Data.Dependent.Map as DMap --- import Data.Dependent.Map (DMap) import Data.List (intercalate) import Hedgehog import qualified Hedgehog.Gen as Gen @@ -23,7 +19,7 @@ import Hedgehog.Main import Array import AST import AST.Pretty -import CHAD +import CHAD.Top import CHAD.Types import Data import qualified Example @@ -34,63 +30,19 @@ import Language import Simplify -type family MapMerge env where - MapMerge '[] = '[] - MapMerge (t : ts) = "merge" : MapMerge ts - -mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[] -mapMergeNoAccum SNil = Refl -mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl - -mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env -mapMergeOnlyMerge SNil = Refl -mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl - -primalEnv :: SList STy env' -> SList STy (D1E env') -primalEnv SNil = SNil -primalEnv (t `SCons` env) = d1 t `SCons` primalEnv env - data SimplIters = SimplIters Int | SimplFix deriving (Show) -diffCHAD :: SimplIters -> SList STy env -> Ex env (TScal TF64) - -> Ex (D1E env) (TPair (TScal TF64) (Tup (D2E env))) -diffCHAD = \simplIters env term -> - case (mapMergeNoAccum env, mapMergeOnlyMerge env, envKnown (primalEnv env)) of - (Refl, Refl, Dict) -> - let descr = makeMergeDescr env - simpl = case simplIters of - SimplIters n -> simplifyN n - SimplFix -> simplifyFix - in simpl $ freezeRet descr (drev descr term) (EConst ext STF64 1.0) - where - makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') - makeMergeDescr SNil = DTop - makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) - -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) gradientByCHAD = \simplIters env term input -> - case (mapMergeNoAccum env, mapMergeOnlyMerge env) of - (Refl, Refl) -> - let dterm = diffCHAD simplIters env term - input1 = toPrimalE env input - (out, grad) = interpretOpen False input1 dterm - in (ppExpr (primalEnv env) dterm, (out, unTup vUnpair (d2e env) (Value grad))) - where - toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') - toPrimalE SNil SNil = SNil - toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp - - toPrimal :: STy t -> Rep t -> Rep (D1 t) - toPrimal = \case - STNil -> id - STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2) - STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2) - STMaybe t -> fmap (toPrimal t) - STArr _ t -> fmap (toPrimal t) - STScal _ -> id - STAccum{} -> error "Accumulators not allowed in input program" + let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' env term + dterm | Dict <- envKnown env + = case simplIters of + SimplIters n -> simplifyN n dtermNonSimpl + SimplFix -> simplifyFix dtermNonSimpl + (out, grad) = interpretOpen False input dterm + in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) @@ -172,29 +124,6 @@ genEnv :: SList STy env -> Gen (SList Value env) genEnv SNil = return SNil genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env --- data TemplateVar n = TemplateVar (SNat n) String --- deriving (Show) - --- data Template t where --- TpShape :: TemplateVar n -> STy t -> Template (TArr n t) --- TpAny :: STy t -> Template t --- TpPair :: Template a -> Template b -> Template (TPair a b) --- deriving instance Show (Template t) - --- data ShapeConstraint n = ShapeAtLeast (Shape n) --- deriving (Show) - --- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) --- genTemplate = _ - --- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) --- genEnvTemplateExact shapes env = _ - --- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) --- genEnvTemplate constrs env = do --- shapes <- DMap.traverseWithKey _ constrs --- genEnvTemplateExact shapes env - showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" where @@ -205,11 +134,6 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property adTest = flip adTestGen (genEnv (knownEnv @env)) --- adTestTp :: forall env. KnownEnv env --- => DMap TemplateVar ShapeConstraint -> SList Template env --- -> Ex env (TScal TF64) -> Property --- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp) - adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property adTestGen expr envGenerator = property $ do @@ -268,7 +192,6 @@ tests = checkSequential $ Group "AD" idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx) - -- :hindentstr ppExpr knownEnv $ diffCHAD 20 knownEnv term_build1_sum ,("build1-sum", adTest term_build1_sum) ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ |