diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 23:47:59 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-26 23:47:59 +0100 |
commit | 6e85d5b2aee0cf2c089538e74261f1d88d6b1b71 (patch) | |
tree | 1ae5699e3b420e7a61203dfb697ad17eba38f683 | |
parent | 16a836d078caefc3526031c084e2527cba0da3a8 (diff) |
test: Start of a list of compile tests
-rw-r--r-- | src/Array.hs | 3 | ||||
-rw-r--r-- | test/Main.hs | 66 |
2 files changed, 60 insertions, 9 deletions
diff --git a/src/Array.hs b/src/Array.hs index 059600f..707dce2 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -88,6 +88,9 @@ emptyArray n = Array (emptyShape n) V.empty arrayFromList :: Shape n -> [t] -> Array n t arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) +arrayToList :: Array n t -> [t] +arrayToList (Array _ v) = V.toList v + arrayUnit :: t -> Array Z t arrayUnit x = Array ShNil (V.singleton x) diff --git a/test/Main.hs b/test/Main.hs index 19bd3e6..014ad43 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -93,6 +93,28 @@ closeIsh' h a b = closeIsh :: Double -> Double -> Bool closeIsh = closeIsh' 1e-5 +closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool +closeIshT' _ STNil () () = True +closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y' +closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x' +closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x' +closeIshT' _ STEither{} _ _ = False +closeIshT' _ (STMaybe _) Nothing Nothing = True +closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x' +closeIshT' _ STMaybe{} _ _ = False +closeIshT' h (STArr _ a) arr1 arr2 = + arrayShape arr1 == arrayShape arr2 && + and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2)) +closeIshT' _ (STScal STI32) i j = i == j +closeIshT' _ (STScal STI64) i j = i == j +closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y) +closeIshT' h (STScal STF64) x y = closeIsh' h x y +closeIshT' _ (STScal STBool) x y = x == y +closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators" + +closeIshT :: STy t -> Rep t -> Rep t -> Bool +closeIshT = closeIshT' 1e-5 + data a :$ b = a :$ b deriving (Show) ; infixl :$ -- An empty name means "no restrictions". @@ -189,6 +211,22 @@ genEnv SNil () = return SNil genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl +data TypedValue t = TypedValue (STy t) (Rep t) +instance Show (TypedValue t) where + showsPrec d (TypedValue t x) = showValue d t x + +compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree +compileTest name expr = + let env = knownEnv + t = typeOf expr + in withCompiled env expr $ \fun -> + testProperty name $ property $ do + input <- forAllWith (showEnv env) (evalStateT (genEnv env (emptyTemplateE env)) mempty) + let resI = interpretOpen False input expr + resC <- liftIO $ fun input + let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y + diff (TypedValue t resI) cmp (TypedValue t resC) + adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree adTest name = adTestCon name (const True) @@ -247,16 +285,16 @@ adTestGenChad :: forall env. SList STy env -> Gen (SList Value env) -> (SList Value env -> IO Double) -> TestTree adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = + let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr + dtermChadS = simplifyFix dtermChad0 + dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS + dtermSChadS = simplifyFix dtermSChad0 + in withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> - withCompiled env (simplifyFix (unMonoid (simplifyFix (ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS)))) $ \dcompSChadS -> + withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS -> testProperty "chad" $ property $ do annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) - let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr - dtermChadS = simplifyFix dtermChad0 - let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS - dtermSChadS = simplifyFix dtermSChad0 - -- pack Text for less GC pressure (these values are retained for some reason) diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0))) diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0))) @@ -286,11 +324,13 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) + annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS))) diff outChad0 closeIsh outPrimal diff outChadS closeIsh outPrimal diff outSChad0 closeIsh outPrimal diff outSChadS closeIsh outPrimal diff outCompSChadS closeIsh outPrimal + -- TODO: use closeIshT let closeIshList x y = and (zipWith closeIsh x y) diff scChad closeIshList scFwd diff scChadS closeIshList scFwd @@ -338,8 +378,14 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec idx0 (sum1i (build1 #wid $ #j :-> #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) -tests :: TestTree -tests = testGroup "AD" +tests_Compile :: TestTree +tests_Compile = testGroup "Compile" + [compileTest "accum f64" $ fromNamed $ lambda #x $ body $ + with @(TScal TF64) 0.0 $ #ac :-> + accum SAPHere nil #x #ac] + +tests_AD :: TestTree +tests_AD = testGroup "AD" [adTest "id" $ fromNamed $ lambda #x $ body $ #x ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x @@ -456,4 +502,6 @@ tests = testGroup "AD" return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) main :: IO () -main = defaultMain tests +main = defaultMain $ testGroup "All" + [tests_Compile + ,tests_AD] |