summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs55
1 files changed, 40 insertions, 15 deletions
diff --git a/test/Main.hs b/test/Main.hs
index 5fa1d46..afbd79b 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
@@ -117,18 +118,22 @@ closeIshT = closeIshT' 1e-5
data a :$ b = a :$ b deriving (Show) ; infixl :$
--- An empty name means "no restrictions".
-data TplConstr = C String -- ^ name; @""@ means anonymous
- Int -- ^ minimum value to generate
+-- | The type index is just a marker that helps typed holes show what (type of)
+-- argument this template constraint belongs to.
+data TplConstr a = C String -- ^ name; @""@ means anonymous
+ Int -- ^ minimum value to generate
+ | NC -- ^ no constraints
type family DimNames n where
DimNames Z = ()
- DimNames (S Z) = TplConstr
- DimNames (S n) = DimNames n :$ TplConstr
+ DimNames (S Z) = TplConstr (S Z)
+ DimNames (S n) = DimNames n :$ TplConstr (S n)
type family Tpl t where
Tpl (TArr n t) = DimNames n
Tpl (TPair a b) = (Tpl a, Tpl b)
+ Tpl (TScal TI32) = TplConstr TI32
+ Tpl (TScal TI64) = TplConstr TI64
-- If you add equations here, don't forget to update genValue! It currently
-- just emptyTpl's things out.
Tpl _ = ()
@@ -142,13 +147,17 @@ type family TemplateE env where
emptyDimNames :: SNat n -> DimNames n
emptyDimNames SZ = ()
-emptyDimNames (SS SZ) = C "" 0
-emptyDimNames (SS n@SS{}) = emptyDimNames n :$ C "" 0
+emptyDimNames (SS SZ) = NC
+emptyDimNames (SS n@SS{}) = emptyDimNames n :$ NC
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
-emptyTpl (STScal _) = ()
+emptyTpl (STScal STI32) = NC
+emptyTpl (STScal STI64) = NC
+emptyTpl (STScal STF32) = ()
+emptyTpl (STScal STF64) = ()
+emptyTpl (STScal STBool) = ()
emptyTpl _ = error "too lazy"
emptyTemplateE :: SList STy env -> TemplateE env
@@ -168,7 +177,8 @@ genShape = \n tpl -> do
genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name
- genNamedDim :: TplConstr -> StateT (Map String Int) Gen Int
+ genNamedDim :: TplConstr a -> StateT (Map String Int) Gen Int
+ genNamedDim NC = genDim 0
genNamedDim (C "" lo) = genDim lo
genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
Nothing -> do
@@ -182,15 +192,17 @@ genShape = \n tpl -> do
shapeDiv :: Shape n -> DimNames n -> Int -> Shape n
shapeDiv ShNil _ _ = ShNil
- shapeDiv (ShNil `ShCons` n) (C _ lo) f = ShNil `ShCons` (max lo (n `div` f))
+ shapeDiv (ShNil `ShCons` n) ( C _ lo) f = ShNil `ShCons` (max lo (n `div` f))
shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f))
+ shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f)
+ shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f)
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh =
Value <$> arrayGenerateLinM sh (\_ ->
unValue <$> evalStateT (genValue t (emptyTpl t)) mempty)
-genValue :: STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
+genValue :: forall t. STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue topty tpl = case topty of
STNil -> return (Value ())
STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
@@ -202,10 +214,23 @@ genValue topty tpl = case topty of
STScal sty -> case sty of
STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
- STI32 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
- STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
+ STI32 -> genInt
+ STI64 -> genInt
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
+ where
+ genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
+ genInt = do
+ let gen lo = Gen.integral (Range.linearFrom 0 lo (max 10 (lo + 10)))
+ val <- case tpl of
+ NC -> gen (-10)
+ C name lo -> gets (Map.lookup name) >>= \case
+ Nothing -> do
+ val <- fromIntegral @Int @(Rep t) <$> gen lo
+ modify (Map.insert name (fromIntegral @(Rep t) @Int val))
+ return val
+ Just val -> return (fromIntegral @Int @(Rep t) val)
+ return (Value val)
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = return SNil
@@ -538,7 +563,7 @@ tests_AD = testGroup "AD"
42
,adTest "arr-rebind-simple" term_arr_rebind_simple
- ,adTest "arr-rebind" term_arr_rebind
+ ,adTestTp "arr-rebind" (NC :& C "" 0) term_arr_rebind
,adTestGen "neural" Example.neural gen_neural
@@ -549,7 +574,7 @@ tests_AD = testGroup "AD"
let_ #m (maximum1i #vec) $
log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m
- ,adTestTp "mulmatvec" ((C "" 0 :$ C "n" 0) :& C "n" 0) term_mulmatvec
+ ,adTestTp "mulmatvec" ((NC :$ C "n" 0) :& C "n" 0) term_mulmatvec
,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm