aboutsummaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs483
1 files changed, 349 insertions, 134 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 7dbafab..0a57cbf 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
@@ -11,7 +12,6 @@
module Main where
import Control.Monad.Trans.Class (lift)
-import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
@@ -24,7 +24,7 @@ import qualified Hedgehog.Range as Range
import Test.Framework
import Array
-import AST
+import AST hiding ((.>))
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
@@ -33,6 +33,7 @@ import CHAD.Types.ToTan
import Compile
import qualified Example
import qualified Example.GMM as Example
+import Example.Types
import ForwardAD
import ForwardAD.DualNumbers
import Interpreter
@@ -41,6 +42,18 @@ import Language
import Simplify
+data TypedValue t = TypedValue (STy t) (Rep t)
+instance Show (TypedValue t) where
+ showsPrec d (TypedValue t x) = showValue d t x
+
+data TypedEnv env = TypedEnv (SList STy env) (SList Value env)
+instance Show (TypedEnv env) where
+ show (TypedEnv env xs) = showEnv env xs
+
+unTypedEnv :: TypedEnv env -> SList Value env
+unTypedEnv (TypedEnv _ xs) = xs
+
+
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
@@ -51,26 +64,30 @@ simplifyIters iters env | Dict <- envKnown env =
SimplFix -> simplifyFix
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (D2E env)))
+gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
let dterm = simplifyIters simplIters env $ ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env term
- (out, grad) = interpretOpen False input dterm
+ (out, grad) = interpretOpen False env input dterm
in (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
-- In addition to the gradient, also returns the pretty-printed differentiated term.
-gradientByCHAD' :: SimplIters -> SList STy env -> Ex env (TScal TF64) -> SList Value env -> (String, (Double, SList Value (TanE env)))
+gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
second (second (toTanE env input)) $
gradientByCHAD simplIters env term input
-gradientByForward :: FwdADArtifact env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
+-- | Generate input tangents for this primal
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
extendDN STNil () = pure ()
extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y
extendDN (STEither a _) (Left x) = Left <$> extendDN a x
extendDN (STEither _ b) (Right y) = Right <$> extendDN b y
+extendDN (STLEither _ _) Nothing = pure Nothing
+extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x
+extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y
extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe t) (Just x) = Just <$> extendDN t x
extendDN (STArr _ t) arr = traverse (extendDN t) arr
@@ -93,20 +110,55 @@ closeIsh' h a b =
closeIsh :: Double -> Double -> Bool
closeIsh = closeIsh' 1e-5
+closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool
+closeIshT' _ STNil () () = True
+closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y'
+closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x'
+closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x'
+closeIshT' _ STEither{} _ _ = False
+closeIshT' _ (STLEither _ _) Nothing Nothing = True
+closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x'
+closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y'
+closeIshT' _ STLEither{} _ _ = False
+closeIshT' _ (STMaybe _) Nothing Nothing = True
+closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x'
+closeIshT' _ STMaybe{} _ _ = False
+closeIshT' h (STArr _ a) arr1 arr2 =
+ arrayShape arr1 == arrayShape arr2 &&
+ and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2))
+closeIshT' _ (STScal STI32) i j = i == j
+closeIshT' _ (STScal STI64) i j = i == j
+closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
+closeIshT' h (STScal STF64) x y = closeIsh' h x y
+closeIshT' _ (STScal STBool) x y = x == y
+closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
+
+closeIshT :: STy t -> Rep t -> Rep t -> Bool
+closeIshT = closeIshT' 1e-5
+
+closeIshE :: SList STy t -> SList Value t -> SList Value t -> Bool
+closeIshE SNil SNil SNil = True
+closeIshE (t `SCons` env) (Value x `SCons` xs) (Value y `SCons` ys) =
+ closeIshT t x y && closeIshE env xs ys
+
data a :$ b = a :$ b deriving (Show) ; infixl :$
--- An empty name means "no restrictions".
-data TplConstr = C String -- ^ name; @""@ means anonymous
- Int -- ^ minimum value to generate
+-- | The type index is just a marker that helps typed holes show what (type of)
+-- argument this template constraint belongs to.
+data TplConstr a = C String -- ^ name; @""@ means anonymous
+ Int -- ^ minimum value to generate
+ | NC -- ^ no constraints
type family DimNames n where
DimNames Z = ()
- DimNames (S Z) = TplConstr
- DimNames (S n) = DimNames n :$ TplConstr
+ DimNames (S Z) = TplConstr (S Z)
+ DimNames (S n) = DimNames n :$ TplConstr (S n)
type family Tpl t where
Tpl (TArr n t) = DimNames n
Tpl (TPair a b) = (Tpl a, Tpl b)
+ Tpl (TScal TI32) = TplConstr TI32
+ Tpl (TScal TI64) = TplConstr TI64
-- If you add equations here, don't forget to update genValue! It currently
-- just emptyTpl's things out.
Tpl _ = ()
@@ -120,13 +172,17 @@ type family TemplateE env where
emptyDimNames :: SNat n -> DimNames n
emptyDimNames SZ = ()
-emptyDimNames (SS SZ) = C "" 0
-emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0
+emptyDimNames (SS SZ) = NC
+emptyDimNames (SS n@SS{}) = emptyDimNames n :$ NC
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
-emptyTpl (STScal _) = ()
+emptyTpl (STScal STI32) = NC
+emptyTpl (STScal STI64) = NC
+emptyTpl (STScal STF32) = ()
+emptyTpl (STScal STF64) = ()
+emptyTpl (STScal STBool) = ()
emptyTpl _ = error "too lazy"
emptyTemplateE :: SList STy env -> TemplateE env
@@ -139,14 +195,15 @@ genShape = \n tpl -> do
sh <- genShapeNaive n tpl
let sz = shapeSize sh
factor = sz `div` 100 + 1
- return (shapeDiv sh factor)
+ return (shapeDiv sh tpl factor)
where
genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShapeNaive SZ () = return ShNil
genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name
- genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
+ genNamedDim :: TplConstr a -> StateT (Map String Int) Gen Int
+ genNamedDim NC = genDim 0
genNamedDim (C "" lo) = genDim lo
genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
Nothing -> do
@@ -156,53 +213,90 @@ genShape = \n tpl -> do
Just dim -> return dim
genDim :: Int -> StateT (Map String Int) Gen Int
- genDim lo = Gen.integral (Range.linear lo 10)
+ genDim lo = Gen.integral (Range.linear lo (lo+10))
- shapeDiv :: Shape n -> Int -> Shape n
- shapeDiv ShNil _ = ShNil
- shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)
+ shapeDiv :: Shape n -> DimNames n -> Int -> Shape n
+ shapeDiv ShNil _ _ = ShNil
+ shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` (max lo (n `div` f))
+ shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f))
+ shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f)
+ shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f)
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh =
Value <$> arrayGenerateLinM sh (\_ ->
unValue <$> evalStateT (genValue t (emptyTpl t)) mempty)
-genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
+genValue :: forall t. STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue topty tpl = case topty of
STNil -> return (Value ())
STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
,liftV Right <$> genValue b (emptyTpl b)]
+ STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
+ ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
+ ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
STMaybe t -> Gen.choice [return (Value Nothing)
,liftV Just <$> genValue t (emptyTpl t)]
STArr n t -> genShape n tpl >>= lift . genArray t
STScal sty -> case sty of
STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
- STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
- STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
+ STI32 -> genInt
+ STI64 -> genInt
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
+ where
+ genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
+ genInt = do
+ let gen lo = Gen.integral (Range.linearFrom 0 lo (max 10 (lo + 10)))
+ val <- case tpl of
+ NC -> gen (-10)
+ C name lo -> gets (Map.lookup name) >>= \case
+ Nothing -> do
+ val <- fromIntegral @Int @(Rep t) <$> gen lo
+ modify (Map.insert name (fromIntegral @(Rep t) @Int val))
+ return val
+ Just val -> return (fromIntegral @Int @(Rep t) val)
+ return (Value val)
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
-adTest :: forall env. KnownEnv env => TestName -> Ex env (TScal TF64) -> TestTree
+compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
+compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr
+
+compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree
+compileTestTp name tmpl expr = compileTestGen name expr (evalStateT (genEnv knownEnv tmpl) mempty)
+
+compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree
+compileTestGen name expr envGenerator =
+ let env = knownEnv
+ t = typeOf expr
+ in withCompiled env expr $ \fun ->
+ testProperty name $ property $ do
+ input <- forAllWith (showEnv env) envGenerator
+ let resI = interpretOpen False env input expr
+ resC <- evalIO $ fun input
+ let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
+ diff (TypedValue t resI) cmp (TypedValue t resC)
+
+adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree
adTest name = adTestCon name (const True)
-adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env (TScal TF64) -> TestTree
+adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree
adTestCon name constr term =
let env = knownEnv
in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))
adTestTp :: forall env. KnownEnv env
- => TestName -> TemplateE env -> Ex env (TScal TF64) -> TestTree
+ => TestName -> TemplateE env -> Ex env R -> TestTree
adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty)
adTestGen :: forall env. KnownEnv env
- => TestName -> Ex env (TScal TF64) -> Gen (SList Value env) -> TestTree
+ => TestName -> Ex env R -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
let env = knownEnv @env
exprS = simplifyFix expr
@@ -210,102 +304,182 @@ adTestGen name expr envGenerator =
withCompiled env (simplifyFix expr) $ \primalSfun ->
testGroupCollapse name
[adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
- ,adTestGenFwd env envGenerator expr exprS
- ,adTestGenChad env envGenerator expr exprS primalSfun]
+ ,adTestGenFwd env envGenerator exprS
+ ,testGroup "chad"
+ [adTestGenChad "default" defaultConfig env envGenerator expr exprS primalSfun
+ ,adTestGenChad "accum" (chcSetAccum defaultConfig) env envGenerator expr exprS primalSfun]]
adTestGenPrimal :: SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> Ex env R -> Ex env R
-> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
-> TestTree
adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
testProperty "compile primal" $ property $ do
input <- forAllWith (showEnv env) envGenerator
- let outPrimalI = interpretOpen False input expr
- outPrimalC <- liftIO $ primalfun input
+ let outPrimalI = interpretOpen False env input expr
+ outPrimalC <- evalIO $ primalfun input
diff outPrimalI (closeIsh' 1e-8) outPrimalC
- let outPrimalSI = interpretOpen False input exprS
- outPrimalSC <- liftIO $ primalSfun input
+ let outPrimalSI = interpretOpen False env input exprS
+ outPrimalSC <- evalIO $ primalSfun input
diff outPrimalSI (closeIsh' 1e-8) outPrimalSC
adTestGenFwd :: SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+ -> Ex env R
-> TestTree
-adTestGenFwd env envGenerator expr exprS =
+adTestGenFwd env envGenerator exprS =
withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
testProperty "compile fwdAD" $ property $ do
input <- forAllWith (showEnv env) envGenerator
dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
- let (outDNI1, outDNI2) = interpretOpen False dinput (dfwdDN expr)
- (outDNC1, outDNC2) <- liftIO $ dnfun dinput
+ let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS)
+ (outDNC1, outDNC2) <- evalIO $ dnfun dinput
diff outDNI1 (closeIsh' 1e-8) outDNC1
diff outDNI2 (closeIsh' 1e-8) outDNC2
-adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
- -> Ex env (TScal TF64) -> Ex env (TScal TF64)
+adTestGenChad :: forall env. String -> CHADConfig -> SList STy env -> Gen (SList Value env)
+ -> Ex env R -> Ex env R
-> (SList Value env -> IO Double)
-> TestTree
-adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- envKnown env =
+ let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env expr
+ dtermChadS = simplifyFix dtermChad0
+ dtermChadSUS = simplifyFix $ unMonoid dtermChadS
+ dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' config env exprS
+ dtermSChadS = simplifyFix dtermSChad0
+ dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
+ in
withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
- testProperty "chad" $ property $ do
+ withCompiled env dtermSChadSUS $ \dcompSChadSUS ->
+ testProperty testname $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
- let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
- dtermChadS = simplifyFix dtermChad0
- let dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
- dtermSChadS = simplifyFix dtermSChad0
-
- -- pack Text for less GC pressure (these values are retained for some reason)
+ -- check simplifier convergence; pack Text for less GC pressure (these values are retained for some reason)
diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
+ diff (T.pack (ppExpr env dtermChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermChad0)))
diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))
+ diff (T.pack (ppExpr env dtermSChadSUS)) (==) (T.pack (ppExpr env (simplifyN 20 $ unMonoid $ simplifyN 20 dtermSChad0)))
input <- forAllWith (showEnv env) envGenerator
- outPrimal <- liftIO $ primalSfun input
+ outPrimal <- evalIO $ primalSfun input
let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
unpackGrad = unTup vUnpair (d2e env) . Value
- let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
-
- let (outChad0, gradChad0) = second unpackGrad $ interpretOpen False input dtermChad0
- (outChadS, gradChadS) = second unpackGrad $ interpretOpen False input dtermChadS
- (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False input dtermSChad0
- (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False input dtermSChadS
- scChad = tanEScalars env $ toTanE env input gradChad0
- scChadS = tanEScalars env $ toTanE env input gradChadS
- scSChad = tanEScalars env $ toTanE env input gradSChad0
- scSChadS = tanEScalars env $ toTanE env input gradSChadS
-
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
+ let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
+
+ let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
+ (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
+ (outChadSUS , gradChadSUS) = second unpackGrad $ interpretOpen False env input dtermChadSUS
+ (outSChad0 , gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
+ (outSChadS , gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
+ (outSChadSUS, gradSChadSUS) = second unpackGrad $ interpretOpen False env input dtermSChadSUS
+ tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
+ tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
+ tansChadSUS = TypedEnv (tanenv env) $ toTanE env input gradChadSUS
+ tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
+ tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
+ tansSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradSChadSUS
+
+ (outCompSChadSUS, gradCompSChadSUS) <- second unpackGrad <$> evalIO (dcompSChadSUS input)
+ let tansCompSChadSUS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadSUS
+
+ -- annotate (showEnv (d2e env) gradChad0)
+ -- annotate (showEnv (d2e env) gradChadS)
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
- diff outChad0 closeIsh outPrimal
- diff outChadS closeIsh outPrimal
- diff outSChad0 closeIsh outPrimal
- diff outSChadS closeIsh outPrimal
- diff scChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scChadS (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChad (\x y -> and (zipWith closeIsh x y)) scFwd
- diff scSChadS (\x y -> and (zipWith closeIsh x y)) scFwd
+ annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
+ diff outChad0 closeIsh outPrimal
+ diff outChadS closeIsh outPrimal
+ diff outChadSUS closeIsh outPrimal
+ diff outSChad0 closeIsh outPrimal
+ diff outSChadS closeIsh outPrimal
+ diff outSChadSUS closeIsh outPrimal
+ diff outCompSChadSUS closeIsh outPrimal
+ let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
+ diff tansChad closeIshE' tansFwd
+ diff tansChadS closeIshE' tansFwd
+ diff tansChadSUS closeIshE' tansFwd
+ diff tansSChad closeIshE' tansFwd
+ diff tansSChadS closeIshE' tansFwd
+ diff tansSChadSUS closeIshE' tansFwd
+ diff tansCompSChadSUS closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
-term_build1_sum :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64])
+gen_gmm = do
+ -- The input ranges here are completely arbitrary.
+ let tR = STScal STF64
+ kN <- Gen.integral (Range.linear 1 8)
+ kD <- Gen.integral (Range.linear 1 8)
+ kK <- Gen.integral (Range.linear 1 8)
+ let i2i64 = fromIntegral @Int @Int64
+ valpha <- genArray tR (ShNil `ShCons` kK)
+ vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
+ vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
+ vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
+ vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
+ vm <- Gen.integral (Range.linear 0 5)
+ let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
+ k2 = 0.5 * vgamma * vgamma
+ k3 = 0.42 -- don't feel like multigammaing today
+ return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
+ Value vm `SCons` vX `SCons`
+ vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
+ Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
+ SNil)
+
+gen_neural :: Gen (SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)])
+gen_neural = do
+ let tR = STScal STF64
+ let genLayer nin nout =
+ liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
+ <*> genArray tR (ShNil `ShCons` nout)
+ nin <- Gen.integral (Range.linear 1 10)
+ n1 <- Gen.integral (Range.linear 1 10)
+ n2 <- Gen.integral (Range.linear 1 10)
+ input <- genArray tR (ShNil `ShCons` nin)
+ lay1 <- genLayer nin n1
+ lay2 <- genLayer n1 n2
+ lay3 <- genArray tR (ShNil `ShCons` n2)
+ return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
+
+term_build0 :: Ex '[TArr N0 R] R
+term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $
+ idx0 $
+ build SZ (shape #x) $ #idx :-> #x ! #idx
+
+term_build1_sum :: Ex '[TVec R] R
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
-term_pairs :: Ex [TScal TF64, TScal TF64] (TScal TF64)
+term_build1_idx :: Ex '[TVec R] R
+term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $
+ build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i))
+
+term_idx_coprod :: Ex '[TVec (TEither R R)] R
+term_idx_coprod = fromNamed $ lambda @(TVec (TEither R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ case_ (#x ! pair nil #i)
+ (#a :-> #a * 2)
+ (#b :-> #b * 3)
+
+term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
fst_ #q * #x + snd_ #q * fst_ #p
-term_sparse :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_sparse :: Ex '[TVec R] R
term_sparse = fromNamed $ lambda #inp $ body $
let_ #n (snd_ (shape #inp)) $
let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
@@ -314,7 +488,7 @@ term_sparse = fromNamed $ lambda #inp $ body $
let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)
-term_regression_simpl1 :: Ex '[TArr N1 (TScal TF64)] (TScal TF64)
+term_regression_simpl1 :: Ex '[TVec R] R
term_regression_simpl1 = fromNamed $ lambda #q $ body $
idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
let_ #j (snd_ #idx) $
@@ -322,8 +496,8 @@ term_regression_simpl1 = fromNamed $ lambda #q $ body $
(#q ! pair nil 0)
(if_ (#j .== #j) 1.0 2.0)
-term_mulmatvec :: Ex [TArr N1 (TScal TF64), TArr N2 (TScal TF64)] (TScal TF64)
-term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec $ body $
+term_mulmatvec :: Ex [TVec R, TMat R] R
+term_mulmatvec = fromNamed $ lambda #mat $ lambda #vec $ body $
idx0 $ sum1i $
let_ #hei (snd_ (fst_ (shape #mat))) $
let_ #wid (snd_ (shape #mat)) $
@@ -331,8 +505,59 @@ term_mulmatvec = fromNamed $ lambda @(TArr N2 _) #mat $ lambda @(TArr N1 _) #vec
idx0 (sum1i (build1 #wid $ #j :->
#mat ! pair (pair nil #i) #j * #vec ! pair nil #j))
-tests :: TestTree
-tests = testGroup "AD"
+term_arr_rebind :: Ex '[I64, TVec R] R
+term_arr_rebind = fromNamed $ lambda #a $ lambda #k $ body $
+ let_ #n (if_ (#k .< length_ #a) #k (length_ #a)) $
+ let_ #b (build1 #n (#i :-> #a ! pair nil #i)) $
+ let_ #p (if_ (#n `mod_` 2 .== 1)
+ (pair #a #b)
+ (pair (map_ (#x :-> #x + 1) #a) #b)) $
+ if_ (#n `mod_` 3 .== 1)
+ (idx0 (sum1i (snd_ #p)))
+ (let_ #b' (snd_ #p) $
+ idx0 (sum1i #b') * idx0 (sum1i (map_ (#x :-> 2 * #x) #b')))
+
+-- This simplifies away to a pointless test, but is helpful for debugging what
+-- term_arr_rebind is supposed to test in a REPL
+term_arr_rebind_simple :: Ex '[TVec R] R
+term_arr_rebind_simple = fromNamed $ lambda #a $ body $
+ let_ #b (build1 (length_ #a) (#i :-> 5 * (#a ! pair nil #i))) $
+ let_ #c #b $
+ let_ #d #c $
+ idx0 (sum1i #d)
+
+tests_Compile :: TestTree
+tests_Compile = testGroup "Compile"
+ [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @R 0.0 $ #ac :->
+ if_ #b (accum SAPHere nil #x #ac)
+ nil
+
+ ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @(TPair R R) (pair 0.0 0.0) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
+ let_ #_ (accum SAPHere nil #x #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
+ nil
+
+ ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $
+ let_ #_ (accum SAPHere nil #x #ac) $
+ let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $
+ nil
+
+ ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $
+ let_ #len (snd_ (shape #x)) $
+ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac)
+ nil) $
+ let_ #_ (accum SAPHere nil #x #ac) $
+ nil
+ ]
+
+tests_AD :: TestTree
+tests_AD = testGroup "AD"
[adTest "id" $ fromNamed $ lambda #x $ body $ #x
,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x
@@ -361,28 +586,47 @@ tests = testGroup "AD"
,adTest "pairs" term_pairs
- ,adTest "build0 const" $ fromNamed $ lambda @(TScal TF64) #x $ body $
+ ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0
- ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
- idx0 $
- build SZ (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build0" term_build0
,adTest "build1-sum" term_build1_sum
- ,adTest "build2-sum" $ fromNamed $ lambda @(TArr N2 _) #x $ body $
+ ,adTest "build2-sum" $ fromNamed $ lambda @(TMat _) #x $ body $
idx0 $ sum1i . sum1i $
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build1-idx" term_build1_idx
+
+ ,adTest "idx-pair" $ fromNamed $ lambda @(TVec (TPair R R)) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ let_ #p (#x ! pair nil #i) $
+ 3 * fst_ #p + 2 * snd_ #p
+
+ ,adTest "idx-coprod" $ term_idx_coprod
+
+ ,adTest "idx-coprod-pair" $ fromNamed $ lambda @(TVec R) #arr $ body $
+ let_ #n (snd_ (shape #arr)) $
+ let_ #b (build1 #n (#i :-> let_ #x (#arr ! pair nil #i) $
+ if_ (#x .>= 1) (pair (inl (pair #x (7 * #x))) (2 * #x))
+ (pair (inr (3 * #x)) (exp #x)))) $
+ idx0 $ sum1i $ build1 #n $ #i :->
+ let_ #p (#b ! pair nil #i) $
+ case_ (fst_ #p)
+ (#a :-> fst_ #a * 2 + snd_ #a * snd_ #p)
+ (#b :-> #b * 4)
+
,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
- fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ fromNamed $ lambda @(TMat R) #x $ body $
idx0 $ sum1i $ maximum1i #x
,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
- fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ fromNamed $ lambda @(TMat R) #x $ body $
idx0 $ sum1i $ minimum1i #x
- ,adTest "unused" $ fromNamed $ lambda @(TArr N1 (TScal TF64)) #x $ body $
+ ,adTest "unused" $ fromNamed $ lambda @(TVec R) #x $ body $
let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
42
@@ -391,62 +635,33 @@ tests = testGroup "AD"
-- Regression test for a simplifier bug (89b78d4)
,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1
- ,adTestGen "neural" Example.neural genNeural
+ -- Regression test for refcounts when indexing in nested arrays
+ ,adTestTp "regression-idx1" (C "" 1 :$ C "" 1) $
+ fromNamed $ lambda @(TMat R) #L $ body $
+ if_ (const_ @TI64 1 .> 0)
+ (idx0 $ sum1i (build1 1 $ #_ :->
+ idx0 (sum1i (build1 1 $ #_ :->
+ #L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0))))
+ 42
+
+ ,adTest "arr-rebind-simple" term_arr_rebind_simple
+ ,adTestTp "arr-rebind" (NC :& C "" 0) term_arr_rebind
- ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) genNeural
+ ,adTestGen "neural" Example.neural gen_neural
,adTestTp "logsumexp" (C "" 1) $
- fromNamed $ lambda @(TArr N1 _) #vec $ body $
+ fromNamed $ lambda @(TVec _) #vec $ body $
let_ #m (maximum1i #vec) $
log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
- ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec
-
- ,adTestGen "gmm-wrong" (Example.gmmObjective True) genGMM
+ ,adTestTp "mulmatvec" ((NC :$ C "n" 0) :& C "n" 0) term_mulmatvec
- ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM
+ ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm
- ,adTestGen "gmm" (Example.gmmObjective False) genGMM
-
- ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM
+ ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
]
- where
- genGMM = do
- -- The input ranges here are completely arbitrary.
- let tR = STScal STF64
- kN <- Gen.integral (Range.linear 1 8)
- kD <- Gen.integral (Range.linear 1 8)
- kK <- Gen.integral (Range.linear 1 8)
- let i2i64 = fromIntegral @Int @Int64
- valpha <- genArray tR (ShNil `ShCons` kK)
- vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
- vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
- vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
- vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
- vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
- vm <- Gen.integral (Range.linear 0 5)
- let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
- k2 = 0.5 * vgamma * vgamma
- k3 = 0.42 -- don't feel like multigammaing today
- return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
- Value vm `SCons` vX `SCons`
- vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
- Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
- SNil)
-
- genNeural = do
- let tR = STScal STF64
- let genLayer nin nout =
- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
- <*> genArray tR (ShNil `ShCons` nout)
- nin <- Gen.integral (Range.linear 1 10)
- n1 <- Gen.integral (Range.linear 1 10)
- n2 <- Gen.integral (Range.linear 1 10)
- input <- genArray tR (ShNil `ShCons` nin)
- lay1 <- genLayer nin n1
- lay2 <- genLayer n1 n2
- lay3 <- genArray tR (ShNil `ShCons` n2)
- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
main :: IO ()
-main = defaultMain tests
+main = defaultMain $ testGroup "All"
+ [tests_Compile
+ ,tests_AD]