summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs187
1 files changed, 139 insertions, 48 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 7dbafab..20b4ef0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -41,6 +41,9 @@ import Language
import Simplify
+type R = TScal TF64
+
+
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
@@ -51,19 +54,19 @@ simplifyIters iters env | Dict <- envKnown env =
SimplFix -> simplifyFix
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
+gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
(out, grad) = interpretOpen False input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
+gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
second (second (toTanE env input)) $
gradientByCHAD simplIters env term input
-gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
@@ -93,6 +96,28 @@ closeIsh' h a b =
closeIsh :: Double -> Double -> Bool
closeIsh = closeIsh' 1e-5
+closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool
+closeIshT' _ STNil () () = True
+closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y'
+closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x'
+closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x'
+closeIshT' _ STEither{} _ _ = False
+closeIshT' _ (STMaybe _) Nothing Nothing = True
+closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x'
+closeIshT' _ STMaybe{} _ _ = False
+closeIshT' h (STArr _ a) arr1 arr2 =
+ arrayShape arr1 == arrayShape arr2 &&
+ and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2))
+closeIshT' _ (STScal STI32) i j = i == j
+closeIshT' _ (STScal STI64) i j = i == j
+closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
+closeIshT' h (STScal STF64) x y = closeIsh' h x y
+closeIshT' _ (STScal STBool) x y = x == y
+closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
+
+closeIshT :: STy t -> Rep t -> Rep t -> Bool
+closeIshT = closeIshT' 1e-5
+
data a :$ b = a :$ b deriving (Show) ; infixl :$
-- An empty name means "no restrictions".
@@ -139,7 +164,7 @@ genShape = \n tpl -> do
sh <- genShapeNaive n tpl
let sz = shapeSize sh
factor = sz `div` 100 + 1
- return (shapeDiv sh factor)
+ return (shapeDiv sh tpl factor)
where
genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShapeNaive SZ () = return ShNil
@@ -156,11 +181,12 @@ genShape = \n tpl -> do
Just dim -> return dim
genDim :: Int -> StateT (Map String Int) Gen Int
- genDim lo = Gen.integral (Range.linear lo 10)
+ genDim lo = Gen.integral (Range.linear lo (lo+10))
- shapeDiv :: Shape n -> Int -> Shape n
- shapeDiv ShNil _ = ShNil
- shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)
+ shapeDiv :: Shape n -> DimNames n -> Int -> Shape n
+ shapeDiv ShNil _ _ = ShNil
+ shapeDiv (ShNil `ShCons` n) (C _ lo) f = ShNil `ShCons` (max lo (n `div` f))
+ shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f))
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh =
@@ -189,20 +215,42 @@ genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
-adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
+data TypedValue t = TypedValue (STy t) (Rep t)
+instance Show (TypedValue t) where
+ showsPrec d (TypedValue t x) = showValue d t x
+
+compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
+compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr
+
+compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree
+compileTestTp name tmpl expr = compileTestGen name expr (evalStateT (genEnv knownEnv tmpl) mempty)
+
+compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree
+compileTestGen name expr envGenerator =
+ let env = knownEnv
+ t = typeOf expr
+ in withCompiled env expr $ \fun ->
+ testProperty name $ property $ do
+ input <- forAllWith (showEnv env) envGenerator
+ let resI = interpretOpen False input expr
+ resC <- liftIO $ fun input
+ let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
+ diff (TypedValue t resI) cmp (TypedValue t resC)
+
+adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree
adTest name = adTestCon name (const True)
-adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree
+adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree
adTestCon name constr term =
let env = knownEnv
in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))
adTestTp :: forall env. KnownEnv env
- => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree
+ => TestName -> TemplateE env -> Ex env R -> TestTree
adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty)
adTestGen :: forall env. KnownEnv env
- => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
+ => TestName -> Ex env R -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
let env = knownEnv @env
exprS = simplifyFix expr
@@ -210,11 +258,11 @@ adTestGen name expr envGenerator =
withCompiled env (simplifyFix expr) $ \primalSfun ->
testGroupCollapse name
[adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
- ,adTestGenFwd env envGenerator expr exprS
+ ,adTestGenFwd env envGenerator exprS
,adTestGenChad env envGenerator expr exprS primalSfun]
adTestGenPrimal :: SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> Ex env R -> Ex env R
-> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
-> TestTree
adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
@@ -230,32 +278,33 @@ adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
adTestGenFwd :: SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> Ex env R
-> TestTree
-adTestGenFwd env envGenerator expr exprS =
+adTestGenFwd env envGenerator exprS =
withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
testProperty "compile fwdAD" $ property $ do
input <- forAllWith (showEnv env) envGenerator
dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
- let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
+ let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN exprS)
(outDNC1, outDNC2) <- liftIO $ dnfun dinput
diff outDNI1 (closeIsh' 1e-8) outDNC1
diff outDNI2 (closeIsh' 1e-8) outDNC2
adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> Ex env R -> Ex env R
-> (SList Value env -> IO Double)
-> TestTree
adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+ let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
+ dtermChadS = simplifyFix dtermChad0
+ dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
+ dtermSChadS = simplifyFix dtermSChad0
+ in
withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
+ withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS ->
testProperty "chad" $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
- let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
- dtermChadS = simplifyFix dtermChad0
- let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
- dtermSChadS = simplifyFix dtermSChad0
-
-- pack Text for less GC pressure (these values are retained for some reason)
diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
@@ -268,44 +317,52 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
- let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
- (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
(outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
(outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
- scChad = tanEScalars env $ toTanE env input gradChad0
- scChadS = tanEScalars env $ toTanE env input gradChadS
- scSChad = tanEScalars env $ toTanE env input gradSChad0
+ scChad = tanEScalars env $ toTanE env input gradChad0
+ scChadS = tanEScalars env $ toTanE env input gradChadS
+ scSChad = tanEScalars env $ toTanE env input gradSChad0
scSChadS = tanEScalars env $ toTanE env input gradSChadS
+ (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> liftIO (dcompSChadS input)
+ let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS
+
-- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
-- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+ annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff outCompSChadS closeIsh outPrimal
+ -- TODO: use closeIshT
+ let closeIshList x y = and (zipWith closeIsh x y)
+ diff scChad closeIshList scFwd
+ diff scChadS closeIshList scFwd
+ diff scSChad closeIshList scFwd
+ diff scSChadS closeIshList scFwd
+ diff scCompSChadS closeIshList scFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
-term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_build1_sum :: Ex '[TArr N1 R] R
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
-term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
+term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
fst_ #q * #x + snd_ #q * fst_ #p
-term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_sparse :: Ex '[TArr N1 R] R
term_sparse = fromNamed $ lambda #inp $ body $
let_ #n (snd_ (shape #inp)) $
let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
@@ -314,7 +371,7 @@ term_sparse = fromNamed $ lambda #inp $ body $
let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
-term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_regression_simpl1 :: Ex '[TArr N1 R] R
term_regression_simpl1 = fromNamed $ lambda #q $ body $
idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
let_ #j (snd_ #idx) $
@@ -322,7 +379,7 @@ term_regression_simpl1 = fromNamed $ lambda #q $ body $
(#q ! pair nil 0)
(if_ (#j .== #j) 1.0 2.0)
-term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
+term_mulmatvec :: Ex [TArr N1 R, TArr N2 R] R
term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
idx0 $ sum1i $
let_ #hei (snd_ (fst_ (shape #mat))) $
@@ -331,8 +388,31 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec
idx0 (sum1i (build1 #wid $ #j :->
#mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
-tests :: TestTree
-tests = testGroup "AD"
+tests_Compile :: TestTree
+tests_Compile = testGroup "Compile"
+ [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @R 0.0 $ #ac :->
+ if_ #b (accum SAPHere nil #x #ac)
+ nil
+
+ ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @(TPair R R) nothing $ #ac :->
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
+ let_ #_ (accum SAPHere nil #x #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
+ nil
+
+ ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $
+ let_ #len (snd_ (shape #x)) $
+ with @(TArr N1 R) nothing $ #ac :->
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac)
+ nil) $
+ let_ #_ (accum SAPHere nil (just #x) #ac) $
+ nil
+ ]
+
+tests_AD :: TestTree
+tests_AD = testGroup "AD"
[adTest "id" $ fromNamed $ lambda #x $ body $ #x
,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
@@ -361,7 +441,7 @@ tests = testGroup "AD"
,adTest "pairs" term_pairs
- ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $
+ ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0
,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
@@ -375,14 +455,14 @@ tests = testGroup "AD"
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
- fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ fromNamed $ lambda @(TArr N2 R) #x $ body $
idx0 $ sum1i $ maximum1i #x
,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
- fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ fromNamed $ lambda @(TArr N2 R) #x $ body $
idx0 $ sum1i $ minimum1i #x
- ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
+ ,adTest "unused" $ fromNamed $ lambda @(TArr N1 R) #x $ body $
let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
42
@@ -391,6 +471,15 @@ tests = testGroup "AD"
-- Regression test for a simplifier bug (89b78d4)
,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1
+ -- Regression test for refcounts when indexing in nested arrays
+ ,adTestTp "regression-idx1" (C "" 1 :$ C "" 1) $
+ fromNamed $ lambda @(TArr N2 R) #L $ body $
+ if_ (const_ @TI64 1 Language..> 0)
+ (idx0 $ sum1i (build1 1 $ #_ :->
+ idx0 (sum1i (build1 1 $ #_ :->
+ #L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0))))
+ 42
+
,adTestGen "neural" Example.neural genNeural
,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural
@@ -449,4 +538,6 @@ tests = testGroup "AD"
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
main :: IO ()
-main = defaultMain tests
+main = defaultMain $ testGroup "All"
+ [tests_Compile
+ ,tests_AD]