aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs209
1 files changed, 133 insertions, 76 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 0a57cbf..05597cc 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
@@ -11,35 +12,38 @@
{-# LANGUAGE UndecidableInstances #-}
module Main where
+import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
-import qualified Data.Map.Strict as Map
-import qualified Data.Text as T
+import Data.Map.Strict qualified as Map
+import Data.Text qualified as T
import Hedgehog
-import qualified Hedgehog.Gen as Gen
-import qualified Hedgehog.Range as Range
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Range qualified as Range
import Test.Framework
-import Array
-import AST hiding ((.>))
-import AST.Pretty
-import AST.UnMonoid
-import CHAD.Top
-import CHAD.Types
-import CHAD.Types.ToTan
-import Compile
-import qualified Example
-import qualified Example.GMM as Example
-import Example.Types
-import ForwardAD
-import ForwardAD.DualNumbers
-import Interpreter
-import Interpreter.Rep
-import Language
-import Simplify
+import CHAD.Array
+import CHAD.AST hiding ((.>))
+import CHAD.AST.Count (pruneExpr)
+import CHAD.AST.Pretty
+import CHAD.AST.UnMonoid
+import CHAD.Compile
+import CHAD.Data
+import CHAD.Drev.Top
+import CHAD.Drev.Types
+import CHAD.Drev.Types.ToTan
+import CHAD.Example qualified as Example
+import CHAD.Example.GMM qualified as Example
+import CHAD.Example.Types
+import CHAD.ForwardAD
+import CHAD.ForwardAD.DualNumbers
+import CHAD.Interpreter
+import CHAD.Interpreter.Rep
+import CHAD.Language
+import CHAD.Simplify
data TypedValue t = TypedValue (STy t) (Rep t)
@@ -63,18 +67,18 @@ simplifyIters iters env | Dict <- envKnown env =
SimplIters n -> simplifyN n
SimplFix -> simplifyFix
--- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
-gradientByCHAD simplIters env term input =
- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
- (out, grad) = interpretOpen False env input dterm
- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
+-- -- In addition to the gradient, also returns the pretty-printed differentiated term.
+-- gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
+-- gradientByCHAD simplIters env term input =
+-- let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
+-- (out, grad) = interpretOpen False env input dterm
+-- in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
--- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
-gradientByCHAD' simplIters env term input =
- second (second (toTanE env input)) $
- gradientByCHAD simplIters env term input
+-- -- In addition to the gradient, also returns the pretty-printed differentiated term.
+-- gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
+-- gradientByCHAD' simplIters env term input =
+-- second (second (toTanE env input)) $
+-- gradientByCHAD simplIters env term input
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
@@ -92,8 +96,8 @@ extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe t) (Just x) = Just <$> extendDN t x
extendDN (STArr _ t) arr = traverse (extendDN t) arr
extendDN (STScal sty) x = case sty of
- STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
- STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
+ STF32 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d)
+ STF64 -> Gen.realFloat (Range.constant (-1) 1) >>= \d -> pure (x, d)
STI32 -> pure x
STI64 -> pure x
STBool -> pure x
@@ -217,8 +221,8 @@ genShape = \n tpl -> do
shapeDiv :: Shape n -> DimNames n -> Int -> Shape n
shapeDiv ShNil _ _ = ShNil
- shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` (max lo (n `div` f))
- shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f))
+ shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` max lo (n `div` f)
+ shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` max lo (n `div` f)
shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f)
shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f)
@@ -240,8 +244,8 @@ genValue topty tpl = case topty of
,liftV Just <$> genValue t (emptyTpl t)]
STArr n t -> genShape n tpl >>= lift . genArray t
STScal sty -> case sty of
- STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
- STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ STF32 -> Value <$> Gen.realFloat (Range.constant (-10) 10)
+ STF64 -> Value <$> Gen.realFloat (Range.constant (-10) 10)
STI32 -> genInt
STI64 -> genInt
STBool -> Gen.choice [return (Value False), return (Value True)]
@@ -302,7 +306,7 @@ adTestGen name expr envGenerator =
exprS = simplifyFix expr
in withCompiled env expr $ \primalfun ->
withCompiled env (simplifyFix expr) $ \primalSfun ->
- testGroupCollapse name
+ groupSetCollapse $ testGroup name
[adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
,adTestGenFwd env envGenerator exprS
,testGroup "chad"
@@ -349,17 +353,23 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS
dtermSChadS = simplifyFix dtermSChad0
dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
+ dtermSChadSUSP = simplifyFix $ pruneExpr env dtermSChadSUS
in
- withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
- withCompiled env dtermSChadSUS $ \dcompSChadSUS ->
+ withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS
+ when (not (null output)) $
+ outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>"
+ return fun) $ \fwdartifactC ->
+ withCompiled env dtermSChadSUSP $ \dcompSChadSUSP ->
testProperty testname $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
-- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason)
- diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
- diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0)))
- diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
- diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0)))
+ let dtermChad20 = simplifyN 20 dtermChad0
+ diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env dtermChad20))
+ diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermChad20)))
+ let dtermSChad20 = simplifyN 20 dtermSChad0
+ diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env dtermSChad20))
+ diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid dtermSChad20)))
input <- forAllWith (showEnv env) envGenerator
outPrimal <- evalIO $ primalSfun input
@@ -369,46 +379,54 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
- let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
- (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
- (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS
- (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
- (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
- (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
- tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
- tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
- tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
- tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
- tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
- tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
+ (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS
+ (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
+ (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
+ (outSChadSUS , gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
+ (outSChadSUSP, gradSChadSUSP) = second unpackGrad $ interpretOpen False env input dtermSChadSUSP
+ tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
+ tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
+ tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
+ tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
+ tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
+ tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
+ tansSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradSChadSUSP
- (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input)
- let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS
+ (outCompSChadSUSP, gradCompSChadSUSP) <- second unpackGrad <$> evalIO (dcompSChadSUSP input)
+ let tansCompSChadSUSP = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUSP
-- annotate (showEnv (d2e env) gradChad0)
-- annotate (showEnv (d2e env) gradChadS)
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outChadSUS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff outSChadSUS closeIsh outPrimal
- diff outCompSChadSUS closeIsh outPrimal
+ annotate (ppExpr env dtermSChadSUSP)
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outChadSUS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff outSChadSUS closeIsh outPrimal
+ diff outSChadSUSP closeIsh outPrimal
+ diff outCompSChadSUSP closeIsh outPrimal
let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
- diff tansChad closeIshE' tansFwd
- diff tansChadS closeIshE' tansFwd
- diff tansChadSUS closeIshE' tansFwd
- diff tansSChad closeIshE' tansFwd
- diff tansSChadS closeIshE' tansFwd
- diff tansSChadSUS closeIshE' tansFwd
- diff tansCompSChadSUS closeIshE' tansFwd
+ diff tansChad closeIshE' tansFwd
+ diff tansChadS closeIshE' tansFwd
+ diff tansChadSUS closeIshE' tansFwd
+ diff tansSChad closeIshE' tansFwd
+ diff tansSChadS closeIshE' tansFwd
+ diff tansSChadSUS closeIshE' tansFwd
+ diff tansSChadSUSP closeIshE' tansFwd
+ diff tansCompSChadSUSP closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
-withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
+withCompiled env expr = withResource' $ do
+ (fun, output) <- compile env expr
+ when (not (null output)) $
+ outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>"
+ return fun
gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64])
gen_gmm = do
@@ -423,7 +441,7 @@ gen_gmm = do
vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
- vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ vgamma <- Gen.realFloat (Range.constant (-10) 10)
vm <- Gen.integral (Range.linear 0 5)
let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
k2 = 0.5 * vgamma * vgamma
@@ -554,6 +572,13 @@ tests_Compile = testGroup "Compile"
nil) $
let_ #_ (accum SAPHere nil #x #ac) $
nil
+
+ ,compileTest "foldd1" $ fromNamed $ lambda @(TVec R) #a $ body $
+ fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a
+
+ ,compileTest "fold-manual" $ fromNamed $ lambda @(TVec R) #a $ lambda #d $ body $
+ let_ #pr (fold1iD1 (#x :-> #y :-> pair (#x * #y) (pair #x #y)) 1 #a) $
+ fold1iD2 (#tape :-> #ctg :-> pair (snd_ #tape * #ctg) (fst_ #tape * #ctg)) (snd_ #pr) #d
]
tests_AD :: TestTree
@@ -659,6 +684,38 @@ tests_AD = testGroup "AD"
,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm
,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
+
+ ,adTestTp "uniform-free" (C "" 0 :& ()) Example.exUniformFree
+
+ ,adTest "reshape1" $ fromNamed $ lambda @(TMat R) #a $ body $
+ let_ #sh (shape #a) $
+ let_ #n (snd_ #sh * snd_ (fst_ #sh)) $
+ idx0 $ sum1i $ reshape (SS SZ) (pair nil #n) #a
+
+ ,adTestTp "reshape2" (C "" 1 :$ NC) $ fromNamed $ lambda @(TMat R) #a $ body $
+ let_ #sh (shape #a) $
+ let_ #innern (snd_ #sh) $
+ let_ #n (#innern * snd_ (fst_ #sh)) $
+ let_ #flata (reshape (SS SZ) (pair nil #n) #a) $
+ -- ensure the input array to EReshape is shared
+ idx0 $ sum1i $
+ build1 #n (#i :-> #flata ! pair nil #i + #a ! pair (pair nil 0) (#i `mod_` #innern))
+
+ ,adTest "fold-sum" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ idx0 $ fold1i (#x :-> #y :-> #x + #y) 0 #a
+
+ ,adTest "fold-prod" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ idx0 $ fold1i (#x :-> #y :-> #x * #y) 1 #a
+
+ ,adTest "fold-freevar" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ let_ #v 2 $
+ idx0 $ fold1i (#x :-> #y :-> #x * #y + #v) 1 #a
+
+ ,adTestTp "fold-freearr" (C "" 1) $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ idx0 $ fold1i (#x :-> #y :-> #x * #y + #a ! pair nil 0) 1 #a
+
+ ,adTest "map" $ fromNamed $ lambda @(TArr N1 R) #a $ body $
+ idx0 $ sum1i $ map_ (#x :-> 2 * #x) #a
]
main :: IO ()