summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-26 23:47:59 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-26 23:47:59 +0100
commit6e85d5b2aee0cf2c089538e74261f1d88d6b1b71 (patch)
tree1ae5699e3b420e7a61203dfb697ad17eba38f683 /test
parent16a836d078caefc3526031c084e2527cba0da3a8 (diff)
test: Start of a list of compile tests
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs66
1 files changed, 57 insertions, 9 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 19bd3e6..014ad43 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -93,6 +93,28 @@ closeIsh' h a b =
closeIsh :: Double -> Double -> Bool
closeIsh = closeIsh' 1e-5
+closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool
+closeIshT' _ STNil () () = True
+closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y'
+closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x'
+closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x'
+closeIshT' _ STEither{} _ _ = False
+closeIshT' _ (STMaybe _) Nothing Nothing = True
+closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x'
+closeIshT' _ STMaybe{} _ _ = False
+closeIshT' h (STArr _ a) arr1 arr2 =
+ arrayShape arr1 == arrayShape arr2 &&
+ and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2))
+closeIshT' _ (STScal STI32) i j = i == j
+closeIshT' _ (STScal STI64) i j = i == j
+closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
+closeIshT' h (STScal STF64) x y = closeIsh' h x y
+closeIshT' _ (STScal STBool) x y = x == y
+closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
+
+closeIshT :: STy t -> Rep t -> Rep t -> Bool
+closeIshT = closeIshT' 1e-5
+
data a :$ b = a :$ b deriving (Show) ; infixl :$
-- An empty name means "no restrictions".
@@ -189,6 +211,22 @@ genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
+data TypedValue t = TypedValue (STy t) (Rep t)
+instance Show (TypedValue t) where
+ showsPrec d (TypedValue t x) = showValue d t x
+
+compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
+compileTest name expr =
+ let env = knownEnv
+ t = typeOf expr
+ in withCompiled env expr $ \fun ->
+ testProperty name $ property $ do
+ input <- forAllWith (showEnv env) (evalStateT (genEnv env (emptyTemplateE env)) mempty)
+ let resI = interpretOpen False input expr
+ resC <- liftIO $ fun input
+ let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
+ diff (TypedValue t resI) cmp (TypedValue t resC)
+
adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
adTest name = adTestCon name (const True)
@@ -247,16 +285,16 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
-> (SList Value env -> IO Double)
-> TestTree
adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+ let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+ dtermChadS = simplifyFix dtermChad0
+ dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
+ dtermSChadS = simplifyFix dtermSChad0
+ in
withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
- withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS ->
+ withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS ->
testProperty "chad" $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
- let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
- dtermChadS = simplifyFix dtermChad0
- let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
- dtermSChadS = simplifyFix dtermSChad0
-
-- pack Text for less GC pressure (these values are retained for some reason)
diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
@@ -286,11 +324,13 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
+ annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
diff outChad0 closeIsh outPrimal
diff outChadS closeIsh outPrimal
diff outSChad0 closeIsh outPrimal
diff outSChadS closeIsh outPrimal
diff outCompSChadS closeIsh outPrimal
+ -- TODO: use closeIshT
let closeIshList x y = and (zipWith closeIsh x y)
diff scChad closeIshList scFwd
diff scChadS closeIshList scFwd
@@ -338,8 +378,14 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec
idx0 (sum1i (build1 #wid $ #j :->
#mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
-tests :: TestTree
-tests = testGroup "AD"
+tests_Compile :: TestTree
+tests_Compile = testGroup "Compile"
+ [compileTest "accum f64" $ fromNamed $ lambda #x $ body $
+ with @(TScal TF64) 0.0 $ #ac :->
+ accum SAPHere nil #x #ac]
+
+tests_AD :: TestTree
+tests_AD = testGroup "AD"
[adTest "id" $ fromNamed $ lambda #x $ body $ #x
,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
@@ -456,4 +502,6 @@ tests = testGroup "AD"
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
main :: IO ()
-main = defaultMain tests
+main = defaultMain $ testGroup "All"
+ [tests_Compile
+ ,tests_AD]