summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-29 20:37:06 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-29 20:38:05 +0200
commitd0eb9a1edfb4233d557d954f46685f25382234d8 (patch)
tree04eb5a746258fcaa2a3b98228c6eadb2b0178ba3 /test/Main.hs
parent4ad7eaba73d5fda8ff5028d1e53966f728d704d3 (diff)
Reorder TLEither to after TEither
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs20
1 files changed, 10 insertions, 10 deletions
diff --git a/test/Main.hs b/test/Main.hs
index f5e4a3c..1b83a2e 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -85,6 +85,9 @@ extendDN STNil () = pure ()
extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y
extendDN (STEither a _) (Left x) = Left <$> extendDN a x
extendDN (STEither _ b) (Right y) = Right <$> extendDN b y
+extendDN (STLEither _ _) Nothing = pure Nothing
+extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x
+extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y
extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe t) (Just x) = Just <$> extendDN t x
extendDN (STArr _ t) arr = traverse (extendDN t) arr
@@ -95,9 +98,6 @@ extendDN (STScal sty) x = case sty of
STI64 -> pure x
STBool -> pure x
extendDN (STAccum _) _ = error "Accumulators not supported in input program"
-extendDN (STLEither _ _) Nothing = pure Nothing
-extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x
-extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
@@ -116,6 +116,10 @@ closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h
closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x'
closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x'
closeIshT' _ STEither{} _ _ = False
+closeIshT' _ (STLEither _ _) Nothing Nothing = True
+closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x'
+closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y'
+closeIshT' _ STLEither{} _ _ = False
closeIshT' _ (STMaybe _) Nothing Nothing = True
closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x'
closeIshT' _ STMaybe{} _ _ = False
@@ -128,10 +132,6 @@ closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
closeIshT' h (STScal STF64) x y = closeIsh' h x y
closeIshT' _ (STScal STBool) x y = x == y
closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
-closeIshT' _ (STLEither _ _) Nothing Nothing = True
-closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x'
-closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y'
-closeIshT' _ STLEither{} _ _ = False
closeIshT :: STy t -> Rep t -> Rep t -> Bool
closeIshT = closeIshT' 1e-5
@@ -233,6 +233,9 @@ genValue topty tpl = case topty of
STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
,liftV Right <$> genValue b (emptyTpl b)]
+ STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
+ ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
+ ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
STMaybe t -> Gen.choice [return (Value Nothing)
,liftV Just <$> genValue t (emptyTpl t)]
STArr n t -> genShape n tpl >>= lift . genArray t
@@ -243,9 +246,6 @@ genValue topty tpl = case topty of
STI64 -> genInt
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
- STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
- ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
- ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
where
genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
genInt = do