From 6e85d5b2aee0cf2c089538e74261f1d88d6b1b71 Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Wed, 26 Mar 2025 23:47:59 +0100
Subject: test: Start of a list of compile tests

---
 test/Main.hs | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 57 insertions(+), 9 deletions(-)

(limited to 'test/Main.hs')

diff --git a/test/Main.hs b/test/Main.hs
index 19bd3e6..014ad43 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -93,6 +93,28 @@ closeIsh' h a b =
 closeIsh :: Double -> Double -> Bool
 closeIsh = closeIsh' 1e-5
 
+closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool
+closeIshT' _ STNil () () = True
+closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y'
+closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x'
+closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x'
+closeIshT' _ STEither{} _ _ = False
+closeIshT' _ (STMaybe _) Nothing Nothing = True
+closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x'
+closeIshT' _ STMaybe{} _ _ = False
+closeIshT' h (STArr _ a) arr1 arr2 =
+  arrayShape arr1 == arrayShape arr2 &&
+    and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2))
+closeIshT' _ (STScal STI32) i j = i == j
+closeIshT' _ (STScal STI64) i j = i == j
+closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
+closeIshT' h (STScal STF64) x y = closeIsh' h x y
+closeIshT' _ (STScal STBool) x y = x == y
+closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
+
+closeIshT :: STy t -> Rep t -> Rep t -> Bool
+closeIshT = closeIshT' 1e-5
+
 data a :$ b = a :$ b deriving (Show) ; infixl :$
 
 -- An empty name means "no restrictions".
@@ -189,6 +211,22 @@ genEnv SNil () = return SNil
 genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
 genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
 
+data TypedValue t = TypedValue (STy t) (Rep t)
+instance Show (TypedValue t) where
+  showsPrec d (TypedValue t x) = showValue d t x
+
+compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
+compileTest name expr =
+  let env = knownEnv
+      t = typeOf expr
+  in withCompiled env expr $ \fun ->
+     testProperty name $ property $ do
+       input <- forAllWith (showEnv env) (evalStateT (genEnv env (emptyTemplateE env)) mempty)
+       let resI = interpretOpen False input expr
+       resC <- liftIO $ fun input
+       let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
+       diff (TypedValue t resI) cmp (TypedValue t resC)
+
 adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
 adTest name = adTestCon name (const True)
 
@@ -247,16 +285,16 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
               -> (SList Value env -> IO Double)
               -> TestTree
 adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+  let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+      dtermChadS = simplifyFix dtermChad0
+      dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
+      dtermSChadS = simplifyFix dtermSChad0
+  in
   withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
-  withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS ->
+  withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS ->
     testProperty "chad" $ property $ do
       annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
 
-      let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
-          dtermChadS = simplifyFix dtermChad0
-      let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
-          dtermSChadS = simplifyFix dtermSChad0
-
       -- pack Text for less GC pressure (these values are retained for some reason)
       diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
       diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
@@ -286,11 +324,13 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
       -- annotate (ppExpr knownEnv expr)
       -- annotate (ppExpr env dtermChad0)
       -- annotate (ppExpr env dtermChadS)
+      annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
       diff outChad0      closeIsh outPrimal
       diff outChadS      closeIsh outPrimal
       diff outSChad0     closeIsh outPrimal
       diff outSChadS     closeIsh outPrimal
       diff outCompSChadS closeIsh outPrimal
+      -- TODO: use closeIshT
       let closeIshList x y = and (zipWith closeIsh x y)
       diff scChad       closeIshList scFwd
       diff scChadS      closeIshList scFwd
@@ -338,8 +378,14 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec
         idx0 (sum1i (build1 #wid $ #j :->
                        #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
 
-tests :: TestTree
-tests = testGroup "AD"
+tests_Compile :: TestTree
+tests_Compile = testGroup "Compile"
+  [compileTest "accum f64" $ fromNamed $ lambda #x $ body $
+      with @(TScal TF64) 0.0 $ #ac :->
+        accum SAPHere nil #x #ac]
+
+tests_AD :: TestTree
+tests_AD = testGroup "AD"
   [adTest "id" $ fromNamed $ lambda #x $ body $ #x
 
   ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
@@ -456,4 +502,6 @@ tests = testGroup "AD"
       return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
 
 main :: IO ()
-main = defaultMain tests
+main = defaultMain $ testGroup "All"
+  [tests_Compile
+  ,tests_AD]
-- 
cgit v1.2.3-70-g09d2