diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-02-28 23:57:31 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-02-28 23:57:31 +0100 |
commit | 9eec3fb3ec727e61a34742be7672a4e281127576 (patch) | |
tree | cdfc2d9225077e082e18f1d1a00ea9e3ec2deca4 | |
parent | b3b7cebfac9d9c54a2e51152e60e04999a7683e3 (diff) |
test: Simplify and make it a bit faster
-rw-r--r-- | chad-fast.cabal | 2 | ||||
-rw-r--r-- | src/CHAD/Types/ToTan.hs | 42 | ||||
-rw-r--r-- | src/Interpreter.hs | 2 | ||||
-rw-r--r-- | src/Simplify.hs | 2 | ||||
-rw-r--r-- | test/Main.hs | 148 |
5 files changed, 114 insertions, 82 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 4ee1c19..7a1c641 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -26,6 +26,7 @@ library CHAD.EnvDescr CHAD.Top CHAD.Types + CHAD.Types.ToTan Compile Compile.Exec Data @@ -83,6 +84,7 @@ test-suite test hedgehog, tasty, tasty-hedgehog, + text, transformers, hs-source-dirs: test default-language: Haskell2010 diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs new file mode 100644 index 0000000..a75fdb8 --- /dev/null +++ b/src/CHAD/Types/ToTan.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE GADTs #-} +module CHAD.Types.ToTan where + +import Data.Bifunctor (bimap) + +import Array +import AST.Types +import CHAD.Types +import Data +import ForwardAD +import Interpreter.Rep + + +toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) +toTanE SNil SNil SNil = SNil +toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = + Value (toTan t p x) `SCons` toTanE env primal inp + +toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) +toTan typ primal der = case typ of + STNil -> der + STPair t1 t2 -> case der of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STEither t1 t2 -> case der of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just d -> case (primal, d) of + (Left p, Left d') -> Left (toTan t1 p d') + (Right p, Right d') -> Right (toTan t2 p d') + _ -> error "Primal and cotangent disagree on Either alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | shapeSize (arrayShape der) == 0 -> + arrayMap (zeroTan t) primal + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index deb829b..dd558fe 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -558,7 +558,7 @@ tupRepIdx :: (forall m. f (S m) -> (f m, Int)) tupRepIdx _ SZ _ = () tupRepIdx uncons (SS n) tup = let (tup', i) = uncons tup - in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i) + in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i ixUncons :: Index (S n) -> (Index n, Int) ixUncons (IxCons idx i) = (idx, i) diff --git a/src/Simplify.hs b/src/Simplify.hs index 785e2bd..673b58c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -7,7 +7,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Simplify where +module Simplify (simplifyN, simplifyFix) where import Data.Function (fix) import Data.Monoid (Any(..)) diff --git a/test/Main.hs b/test/Main.hs index de3d39e..9ab09c5 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -16,6 +16,7 @@ import Data.Bifunctor import Data.Int (Int64) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map +import qualified Data.Text as T import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range @@ -28,6 +29,7 @@ import AST.Pretty import AST.UnMonoid import CHAD.Top import CHAD.Types +import CHAD.Types.ToTan import qualified Example import qualified Example.GMM as Example import ForwardAD @@ -40,50 +42,24 @@ import Simplify data SimplIters = SimplIters Int | SimplFix deriving (Show) +simplifyWith :: SimplIters -> SList STy env -> Ex env t -> Ex env t +simplifyWith iters env | Dict <- envKnown env = + case iters of + SimplIters n -> simplifyN n + SimplFix -> simplifyFix + -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env))) -gradientByCHAD = \simplIters env term input -> - let dtermNonSimpl = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term - dterm | Dict <- envKnown env - = case simplIters of - SimplIters n -> simplifyN n dtermNonSimpl - SimplFix -> simplifyFix dtermNonSimpl +gradientByCHAD simplIters env term input = + let dterm = simplifyWith simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term (out, grad) = interpretOpen False input dterm in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad))) -- In addition to the gradient, also returns the pretty-printed differentiated term. gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env))) -gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input)) $ gradientByCHAD simplIters env term input - where - toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) - toTanE SNil SNil SNil = SNil - toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = - Value (toTan t p x) `SCons` toTanE env primal inp - - toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) - toTan typ primal der = case typ of - STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal - STEither t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just d -> case (primal, d) of - (Left p, Left d') -> Left (toTan t1 p d') - (Right p, Right d') -> Right (toTan t2 p d') - _ -> error "Primal and cotangent disagree on Either alternative" - STMaybe t -> liftA2 (toTan t) primal der - STArr _ t - | shapeSize (arrayShape der) == 0 -> - arrayMap (zeroTan t) primal - | arrayShape primal == arrayShape der -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) - | otherwise -> - error "Primal and cotangent disagree on array shape" - STScal sty -> case sty of - STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der - STAccum{} -> error "Accumulators not allowed in input program" +gradientByCHAD' simplIters env term input = + second (second (toTanE env input)) $ + gradientByCHAD simplIters env term input gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) gradientByForward env term input = drevByFwd env term input 1.0 @@ -188,40 +164,52 @@ genEnv SNil () = return SNil genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl -adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property -adTest = adTestCon (const True) +adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree +adTest name = adTestCon name (const True) -adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property -adTestCon constr term = +adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree +adTestCon name constr term = let env = knownEnv - in adTestGen term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)) + in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty)) adTestTp :: forall env. KnownEnv env - => TemplateE env -> Ex env (TScal TF64) -> Property -adTestTp tmpl term = adTestGen term (evalStateT (genEnv knownEnv tmpl) mempty) + => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree +adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty) adTestGen :: forall env. KnownEnv env - => Ex env (TScal TF64) -> Gen (SList Value env) -> Property -adTestGen expr envGenerator = property $ do + => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree +adTestGen name expr envGenerator = testProperty name $ property $ do let env = knownEnv @env + + annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) + + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr + dtermChadS = simplifyFix dtermChad0 + dtermChadS20 = simplifyN 20 dtermChad0 + + -- pack Text for less GC pressure (these values are retained for some reason) + diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChadS20)) + input <- forAllWith (showEnv env) envGenerator + + let convGrad :: Rep (Tup (D2E env)) -> SList Value (TanE env) + convGrad = toTanE env input . unTup vUnpair (d2e env) . Value + let outPrimal = interpretOpen False input expr + (outChad0, gradChad0) = second convGrad $ interpretOpen False input dtermChad0 + (outChadS, gradChadS) = second convGrad $ interpretOpen False input dtermChadS + scChad = envScalars env gradChad0 + scChadS = envScalars env gradChadS gradFwd = gradientByForward knownEnv expr input - (_ppdterm, (outChad, gradCHAD)) = gradientByCHAD' (SimplIters 0) knownEnv expr input - (ppdterm_S, (outChad_S, gradCHAD_S)) = gradientByCHAD' SimplFix knownEnv expr input - (ppdterm_S20, _) = gradientByCHAD' (SimplIters 20) knownEnv expr input scFwd = envScalars env gradFwd - scCHAD = envScalars env gradCHAD - scCHAD_S = envScalars env gradCHAD_S - annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) + -- annotate (ppExpr knownEnv expr) - -- annotate ppdterm - -- annotate ppdterm_S - diff ppdterm_S (==) ppdterm_S20 - diff outChad_S closeIsh outChad - diff outChad_S closeIsh outPrimal - diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scCHAD - diff scCHAD_S (\x y -> and (zipWith closeIsh x y)) scFwd + -- annotate (ppExpr env dtermChad0) + -- annotate (ppExpr env dtermChadS) + diff outChadS closeIsh outChad0 + diff outChadS closeIsh outPrimal + diff scChadS (\x y -> and (zipWith closeIsh x y)) scChad + diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd where envScalars :: SList STy env' -> SList Value (TanE env') -> [Double] envScalars SNil SNil = [] @@ -249,54 +237,54 @@ term_sparse = fromNamed $ lambda #inp $ body $ tests :: TestTree tests = testGroup "AD" - [testProperty "id" $ adTest $ fromNamed $ lambda #x $ body $ #x + [adTest "id" $ fromNamed $ lambda #x $ body $ #x - ,testProperty "idx0" $ adTest $ fromNamed $ lambda #x $ body $ idx0 #x + ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x - ,testProperty "sum-vec" $ adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x) + ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x) - ,testProperty "sum-replicate" $ adTest $ fromNamed $ lambda #x $ body $ + ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $ idx0 $ sum1i $ replicate1i 10 #x - ,testProperty "pairs" $ adTest term_pairs + ,adTest "pairs" term_pairs - ,testProperty "build0 const" $ adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $ + ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $ idx0 $ build SZ nil $ #idx :-> const_ 0.0 - ,testProperty "build0" $ adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $ + ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $ idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx - ,testProperty "build1-sum" $ adTest term_build1_sum + ,adTest "build1-sum" term_build1_sum - ,testProperty "build2-sum" $ adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ + ,adTest "build2-sum" $ fromNamed $ lambda @(TArr N2 _) #x $ body $ idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx - ,testProperty "maximum" $ adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ idx0 $ sum1i $ maximum1i #x - ,testProperty "minimum" $ adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ idx0 $ sum1i $ minimum1i #x - ,testProperty "unused" $ adTest $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ + ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $ let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ 42 - ,testProperty "sparse" $ adTestTp (C "" 5) term_sparse + ,adTestTp "sparse" (C "" 5) term_sparse - ,testProperty "neural" $ adTestGen Example.neural genNeural + ,adTestGen "neural" Example.neural genNeural - ,testProperty "neural-unMonoid" $ adTestGen (unMonoid (simplifyFix Example.neural)) genNeural + ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural - ,testProperty "logsumexp" $ adTestTp (C "" 1) $ + ,adTestTp "logsumexp" (C "" 1) $ fromNamed $ lambda @(TArr N1 _) #vec $ body $ let_ #m (maximum1i #vec) $ log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m - ,testProperty "mulmatvec" $ adTestTp ((C "" 0 :$ C "n" 0) :& C "n" 0) $ + ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) $ fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $ idx0 $ sum1i $ let_ #hei (snd_ (fst_ (shape #mat))) $ @@ -305,13 +293,13 @@ tests = testGroup "AD" idx0 (sum1i (build1 #wid $ #j :-> #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) - ,testProperty "gmm-wrong" $ withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM + ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM - ,testProperty "gmm-wrong-unMonoid" $ withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM + ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM - ,testProperty "gmm" $ withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM + ,adTestGen "gmm" (Example.gmmObjective False) genGMM - ,testProperty "gmm-unMonoid" $ withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM + ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM ] where genGMM = do |