summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-02-28 23:57:31 +0100
committerTom Smeding <tom@tomsmeding.com>2025-02-28 23:57:31 +0100
commit9eec3fb3ec727e61a34742be7672a4e281127576 (patch)
treecdfc2d9225077e082e18f1d1a00ea9e3ec2deca4
parentb3b7cebfac9d9c54a2e51152e60e04999a7683e3 (diff)
test: Simplify and make it a bit faster
-rw-r--r--chad-fast.cabal2
-rw-r--r--src/CHAD/Types/ToTan.hs42
-rw-r--r--src/Interpreter.hs2
-rw-r--r--src/Simplify.hs2
-rw-r--r--test/Main.hs148
5 files changed, 114 insertions, 82 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 4ee1c19..7a1c641 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -26,6 +26,7 @@ library
CHAD.EnvDescr
CHAD.Top
CHAD.Types
+ CHAD.Types.ToTan
Compile
Compile.Exec
Data
@@ -83,6 +84,7 @@ test-suite test
hedgehog,
tasty,
tasty-hedgehog,
+ text,
transformers,
hs-source-dirs: test
default-language: Haskell2010
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
new file mode 100644
index 0000000..a75fdb8
--- /dev/null
+++ b/src/CHAD/Types/ToTan.hs
@@ -0,0 +1,42 @@
+{-# LANGUAGE GADTs #-}
+module CHAD.Types.ToTan where
+
+import Data.Bifunctor (bimap)
+
+import Array
+import AST.Types
+import CHAD.Types
+import Data
+import ForwardAD
+import Interpreter.Rep
+
+
+toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
+toTanE SNil SNil SNil = SNil
+toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
+ Value (toTan t p x) `SCons` toTanE env primal inp
+
+toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
+toTan typ primal der = case typ of
+ STNil -> der
+ STPair t1 t2 -> case der of
+ Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
+ Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STEither t1 t2 -> case der of
+ Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
+ Just d -> case (primal, d) of
+ (Left p, Left d') -> Left (toTan t1 p d')
+ (Right p, Right d') -> Right (toTan t2 p d')
+ _ -> error "Primal and cotangent disagree on Either alternative"
+ STMaybe t -> liftA2 (toTan t) primal der
+ STArr _ t
+ | shapeSize (arrayShape der) == 0 ->
+ arrayMap (zeroTan t) primal
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
+ STScal sty -> case sty of
+ STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
+ STAccum{} -> error "Accumulators not allowed in input program"
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index deb829b..dd558fe 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -558,7 +558,7 @@ tupRepIdx :: (forall m. f (S m) -> (f m, Int))
tupRepIdx _ SZ _ = ()
tupRepIdx uncons (SS n) tup =
let (tup', i) = uncons tup
- in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i)
+ in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (IxCons idx i) = (idx, i)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 785e2bd..673b58c 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -7,7 +7,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-module Simplify where
+module Simplify (simplifyN, simplifyFix) where
import Data.Function (fix)
import Data.Monoid (Any(..))
diff --git a/test/Main.hs b/test/Main.hs
index de3d39e..9ab09c5 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -16,6 +16,7 @@ import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
+import qualified Data.Text as T
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
@@ -28,6 +29,7 @@ import AST.Pretty
import AST.UnMonoid
import CHAD.Top
import CHAD.Types
+import CHAD.Types.ToTan
import qualified Example
import qualified Example.GMM as Example
import ForwardAD
@@ -40,50 +42,24 @@ import Simplify
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
+simplifyWith :: SimplIters -> SList STy env -> Ex env t -> Ex env t
+simplifyWith iters env | Dict <- envKnown env =
+ case iters of
+ SimplIters n -> simplifyN n
+ SimplFix -> simplifyFix
+
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
-gradientByCHAD = \simplIters env term input ->
- let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
- dterm | Dict <- envKnown env
- = case simplIters of
- SimplIters n -> simplifyN n dtermNonSimpl
- SimplFix -> simplifyFix dtermNonSimpl
+gradientByCHAD simplIters env term input =
+ let dterm = simplifyWith simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
(out, grad) = interpretOpen False input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
-gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input
- where
- toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
- toTanE SNil SNil SNil = SNil
- toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
- Value (toTan t p x) `SCons` toTanE env primal inp
-
- toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
- toTan typ primal der = case typ of
- STNil -> der
- STPair t1 t2 -> case der of
- Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
- Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
- STEither t1 t2 -> case der of
- Nothing -> bimap (zeroTan t1) (zeroTan t2) primal
- Just d -> case (primal, d) of
- (Left p, Left d') -> Left (toTan t1 p d')
- (Right p, Right d') -> Right (toTan t2 p d')
- _ -> error "Primal and cotangent disagree on Either alternative"
- STMaybe t -> liftA2 (toTan t) primal der
- STArr _ t
- | shapeSize (arrayShape der) == 0 ->
- arrayMap (zeroTan t) primal
- | arrayShape primal == arrayShape der ->
- arrayGenerateLin (arrayShape primal) $ \i ->
- toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
- | otherwise ->
- error "Primal and cotangent disagree on array shape"
- STScal sty -> case sty of
- STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
- STAccum{} -> error "Accumulators not allowed in input program"
+gradientByCHAD' simplIters env term input =
+ second (second (toTanE env input)) $
+ gradientByCHAD simplIters env term input
gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
gradientByForward env term input = drevByFwd env term input 1.0
@@ -188,40 +164,52 @@ genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
-adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
-adTest = adTestCon (const True)
+adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
+adTest name = adTestCon name (const True)
-adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
-adTestCon constr term =
+adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree
+adTestCon name constr term =
let env = knownEnv
- in adTestGen term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))
+ in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))
adTestTp :: forall env. KnownEnv env
- => TemplateE env -> Ex env (TScal TF64) -> Property
-adTestTp tmpl term = adTestGen term (evalStateT (genEnv knownEnv tmpl) mempty)
+ => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree
+adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty)
adTestGen :: forall env. KnownEnv env
- => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
-adTestGen expr envGenerator = property $ do
+ => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
+adTestGen name expr envGenerator = testProperty name $ property $ do
let env = knownEnv @env
+
+ annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
+
+ let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+ dtermChadS = simplifyFix dtermChad0
+ dtermChadS20 = simplifyN 20 dtermChad0
+
+ -- pack Text for less GC pressure (these values are retained for some reason)
+ diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20))
+
input <- forAllWith (showEnv env) envGenerator
+
+ let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env)
+ convGrad = toTanE env input . unTup vUnpair (d2e env) . Value
+
let outPrimal = interpretOpen False input expr
+ (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0
+ (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS
+ scChad = envScalars env gradChad0
+ scChadS = envScalars env gradChadS
gradFwd = gradientByForward knownEnv expr input
- (_ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input
- (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input
- (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input
scFwd = envScalars env gradFwd
- scCHAD = envScalars env gradCHAD
- scCHAD_S = envScalars env gradCHAD_S
- annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
+
-- annotate (ppExpr knownEnv expr)
- -- annotate ppdterm
- -- annotate ppdterm_S
- diff ppdterm_S (==) ppdterm_S20
- diff outChad_S closeIsh outChad
- diff outChad_S closeIsh outPrimal
- diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD
- diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd
+ -- annotate (ppExpr env dtermChad0)
+ -- annotate (ppExpr env dtermChadS)
+ diff outChadS closeIsh outChad0
+ diff outChadS closeIsh outPrimal
+ diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad
+ diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
where
envScalars :: SList STy env' -> SList Value (TanE env') -> [Double]
envScalars SNil SNil = []
@@ -249,54 +237,54 @@ term_sparse = fromNamed $ lambda #inp $ body $
tests :: TestTree
tests = testGroup "AD"
- [testProperty "id" $ adTest $ fromNamed $ lambda #x $ body $ #x
+ [adTest "id" $ fromNamed $ lambda #x $ body $ #x
- ,testProperty "idx0" $ adTest $ fromNamed $ lambda #x $ body $ idx0 #x
+ ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
- ,testProperty "sum-vec" $ adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)
+ ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)
- ,testProperty "sum-replicate" $ adTest $ fromNamed $ lambda #x $ body $
+ ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $
idx0 $ sum1i $ replicate1i 10 #x
- ,testProperty "pairs" $ adTest term_pairs
+ ,adTest "pairs" term_pairs
- ,testProperty "build0 const" $ adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
+ ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0
- ,testProperty "build0" $ adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
+ ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
idx0 $
build SZ (shape #x) $ #idx :-> #x ! #idx
- ,testProperty "build1-sum" $ adTest term_build1_sum
+ ,adTest "build1-sum" term_build1_sum
- ,testProperty "build2-sum" $ adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
+ ,adTest "build2-sum" $ fromNamed $ lambda @(TArr N2 _) #x $ body $
idx0 $ sum1i . sum1i $
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
- ,testProperty "maximum" $ adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
+ ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
idx0 $ sum1i $ maximum1i #x
- ,testProperty "minimum" $ adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
+ ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
idx0 $ sum1i $ minimum1i #x
- ,testProperty "unused" $ adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
+ ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
42
- ,testProperty "sparse" $ adTestTp (C "" 5) term_sparse
+ ,adTestTp "sparse" (C "" 5) term_sparse
- ,testProperty "neural" $ adTestGen Example.neural genNeural
+ ,adTestGen "neural" Example.neural genNeural
- ,testProperty "neural-unMonoid" $ adTestGen (unMonoid (simplifyFix Example.neural)) genNeural
+ ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural
- ,testProperty "logsumexp" $ adTestTp (C "" 1) $
+ ,adTestTp "logsumexp" (C "" 1) $
fromNamed $ lambda @(TArr N1 _) #vec $ body $
let_ #m (maximum1i #vec) $
log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
- ,testProperty "mulmatvec" $ adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $
+ ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) $
fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
idx0 $ sum1i $
let_ #hei (snd_ (fst_ (shape #mat))) $
@@ -305,13 +293,13 @@ tests = testGroup "AD"
idx0 (sum1i (build1 #wid $ #j :->
#mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
- ,testProperty "gmm-wrong" $ withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM
+ ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM
- ,testProperty "gmm-wrong-unMonoid" $ withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM
+ ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM
- ,testProperty "gmm" $ withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM
+ ,adTestGen "gmm" (Example.gmmObjective False) genGMM
- ,testProperty "gmm-unMonoid" $ withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
+ ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
]
where
genGMM = do