summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/Main.hs88
1 files changed, 56 insertions, 32 deletions
diff --git a/test/Main.hs b/test/Main.hs
index afbd79b..f5e4a3c 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -42,6 +42,18 @@ import Language
import Simplify
+data TypedValue t = TypedValue (STy t) (Rep t)
+instance Show (TypedValue t) where
+ showsPrec d (TypedValue t x) = showValue d t x
+
+data TypedEnv env = TypedEnv (SList STy env) (SList Value env)
+instance Show (TypedEnv env) where
+ show (TypedEnv env xs) = showEnv env xs
+
+unTypedEnv :: TypedEnv env -> SList Value env
+unTypedEnv (TypedEnv _ xs) = xs
+
+
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
@@ -67,6 +79,7 @@ gradientByCHAD' simplIters env term input =
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
+-- | Generate input tangents for this primal
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
extendDN STNil () = pure ()
extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y
@@ -82,6 +95,9 @@ extendDN (STScal sty) x = case sty of
STI64 -> pure x
STBool -> pure x
extendDN (STAccum _) _ = error "Accumulators not supported in input program"
+extendDN (STLEither _ _) Nothing = pure Nothing
+extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x
+extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
@@ -112,10 +128,19 @@ closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
closeIshT' h (STScal STF64) x y = closeIsh' h x y
closeIshT' _ (STScal STBool) x y = x == y
closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
+closeIshT' _ (STLEither _ _) Nothing Nothing = True
+closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x'
+closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y'
+closeIshT' _ STLEither{} _ _ = False
closeIshT :: STy t -> Rep t -> Rep t -> Bool
closeIshT = closeIshT' 1e-5
+closeIshE :: SList STy t -> SList Value t -> SList Value t -> Bool
+closeIshE SNil SNil SNil = True
+closeIshE (t `SCons` env) (Value x `SCons` xs) (Value y `SCons` ys) =
+ closeIshT t x y && closeIshE env xs ys
+
data a :$ b = a :$ b deriving (Show) ; infixl :$
-- | The type index is just a marker that helps typed holes show what (type of)
@@ -218,6 +243,9 @@ genValue topty tpl = case topty of
STI64 -> genInt
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
+ STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
+ ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
+ ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
where
genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
genInt = do
@@ -237,10 +265,6 @@ genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
-data TypedValue t = TypedValue (STy t) (Rep t)
-instance Show (TypedValue t) where
- showsPrec d (TypedValue t x) = showValue d t x
-
compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr
@@ -337,22 +361,22 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
unpackGrad = unTup vUnpair (d2e env) . Value
- let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
+ let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
(outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
(outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
(outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
- scChad = tanEScalars env $ toTanE env input gradChad0
- scChadS = tanEScalars env $ toTanE env input gradChadS
- scSChad = tanEScalars env $ toTanE env input gradSChad0
- scSChadS = tanEScalars env $ toTanE env input gradSChadS
+ tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
+ tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
+ tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
+ tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
(outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input)
- let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS
+ let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
+ -- annotate (showEnv (d2e env) gradChad0)
+ -- annotate (showEnv (d2e env) gradChadS)
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
@@ -362,13 +386,12 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
diff outSChad0 closeIsh outPrimal
diff outSChadS closeIsh outPrimal
diff outCompSChadS closeIsh outPrimal
- -- TODO: use closeIshT
- let closeIshList x y = and (zipWith closeIsh x y)
- diff scChad closeIshList scFwd
- diff scChadS closeIshList scFwd
- diff scSChad closeIshList scFwd
- diff scSChadS closeIshList scFwd
- diff scCompSChadS closeIshList scFwd
+ let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
+ diff tansChad closeIshE' tansFwd
+ diff tansChadS closeIshE' tansFwd
+ diff tansSChad closeIshE' tansFwd
+ diff tansSChadS closeIshE' tansFwd
+ diff tansCompSChadS closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
@@ -478,18 +501,25 @@ tests_Compile = testGroup "Compile"
nil
,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
- with @(TPair R R) nothing $ #ac :->
- let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
+ with @(TPair R R) (pair 0.0 0.0) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $
+ let_ #_ (accum SAPHere nil #x #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $
+ nil
+
+ ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @(TMaybe (TPair R R)) nothing $ #ac :->
+ let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
+ let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $
nil
- ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $
+ ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $
let_ #len (snd_ (shape #x)) $
- with @(TVec R) nothing $ #ac :->
- let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac)
+ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac)
nil) $
- let_ #_ (accum SAPHere nil (just #x) #ac) $
+ let_ #_ (accum SAPHere nil #x #ac) $
nil
]
@@ -567,8 +597,6 @@ tests_AD = testGroup "AD"
,adTestGen "neural" Example.neural gen_neural
- ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural
-
,adTestTp "logsumexp" (C "" 1) $
fromNamed $ lambda @(TVec _) #vec $ body $
let_ #m (maximum1i #vec) $
@@ -578,11 +606,7 @@ tests_AD = testGroup "AD"
,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm
- ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) gen_gmm
-
,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
-
- ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) gen_gmm
]
main :: IO ()