From f2cec69969a68e8feed3dceacef5186b1debdda5 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Tue, 16 Dec 2025 09:51:51 +0100 Subject: Make ShR a newtype over ShX --- bench/Main.hs | 4 +- src/Data/Array/Nested/Convert.hs | 3 + src/Data/Array/Nested/Lemmas.hs | 14 +++ src/Data/Array/Nested/Mixed/Shape.hs | 21 ++-- src/Data/Array/Nested/Permutation.hs | 8 +- src/Data/Array/Nested/Ranked/Base.hs | 6 +- src/Data/Array/Nested/Ranked/Shape.hs | 176 ++++++++++++++++++++++++---------- src/Data/Array/Nested/Types.hs | 2 +- test/Gen.hs | 5 +- test/Tests/C.hs | 39 ++++---- 10 files changed, 182 insertions(+), 96 deletions(-) diff --git a/bench/Main.hs b/bench/Main.hs index 2058e77..8fe0fdc 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -9,7 +9,6 @@ import Control.Monad (when) import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as RG import Data.Array.Internal.RankedS qualified as RS -import Data.Foldable (toList) import Data.Vector.Storable qualified as VS import Numeric.LinearAlgebra qualified as LA import Test.Tasty.Bench @@ -19,6 +18,7 @@ import Data.Array.Nested import Data.Array.Nested.Mixed (Mixed(M_Primitive), mliftPrim, mliftPrim2, toPrimitive) import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked (liftRanked1, liftRanked2) +import Data.Array.Nested.Ranked.Shape import Data.Array.Strided.Arith.Internal qualified as Arith import Data.Array.XArray (XArray(..)) @@ -40,7 +40,7 @@ main_tests = defaultMain let showSh l = showListWith (\n -> let ln = round (logBase 10 (fromIntegral n :: Double)) :: Int in if n > 1 && n == 10 ^ ln then showString ("1e" ++ show ln) else shows n) l "" - in bench (name ++ " " ++ showSh (toList (rshape inp1)) ++ + in bench (name ++ " " ++ showSh (shrToList (rshape inp1)) ++ " str " ++ showSh (stridesOf inp1) ++ " " ++ showSh (stridesOf inp2)) $ nf (\(a,b) -> rsumAllPrim (rdot1Inner a b)) (inp1, inp2) diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 3706105..d4d1cea 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -57,6 +57,9 @@ import Data.Array.Nested.Types -- * To ranked +-- TODO: change all those unsafeCoerces into coerces by defining shaped +-- and ranekd index types as newtypes of the mixed index type +-- and similarly for the sized lists ixrFromIxS :: IxS sh i -> IxR (Rank sh) i ixrFromIxS = unsafeCoerce diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index e089479..fa5611b 100644 --- a/src/Data/Array/Nested/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -56,6 +56,20 @@ lemReplicatePlusApp sn _ _ = go sn -} lemReplicatePlusApp _ _ _ = unsafeCoerceRefl +lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0 +lemReplicateEmpty _ Refl = unsafeCoerceRefl + +-- TODO: make less ad-hoc and rename these three: +lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1 +lemReplicateCons _ _ Refl = unsafeCoerceRefl + +lemReplicateCons2 :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> sh :~: Replicate (Rank sh) Nothing +lemReplicateCons2 _ _ Refl = unsafeCoerceRefl + +lemReplicateSucc2 :: forall n1 n proxy. + proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing +lemReplicateSucc2 _ _ = unsafeCoerceRefl + lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 7c79f8b..5ffd40c 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -549,11 +549,11 @@ shxFromList topssh topl = go topssh topl {-# INLINEABLE shxToList #-} shxToList :: IShX sh -> [Int] -shxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> +shxToList sh0 = build (\(cons :: i -> is -> is) (nil :: is) -> let go :: IShX sh -> is go ZSX = nil go (smn :$% sh) = fromSMayNat' smn `cons` go sh - in go list) + in go sh0) -- If it ever matters for performance, this is unsafeCoercible. shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i @@ -578,6 +578,10 @@ shxHead (ShX list) = listhHead list shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listhTail list) +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh + shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i shxDropSSX = coerce (listhDrop @i @()) @@ -594,10 +598,6 @@ shxInit = coerce (listhInit @i) shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) shxLast = coerce (listhLast @i) -shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i -shxTakeSSX _ ZKX _ = ZSX -shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh - {-# INLINE shxZipWith #-} shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k @@ -690,14 +690,13 @@ ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') ssxEqType = testEquality ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' +ssxAppend = coerce (listhAppend @_ @()) ssxHead :: StaticShX (n : sh) -> SMayNat () n ssxHead (StaticShX list) = listhHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh +ssxTail (StaticShX list) = StaticShX (listhTail list) ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh ssxTakeIx _ (IxX ZX) _ = ZKX @@ -795,8 +794,8 @@ instance KnownShX sh => IsList (IxX sh i) where toList = Foldable.toList -- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int +instance KnownShX sh => IsList (IShX sh) where + type Item (IShX sh) = Int fromList = shxFromList (knownShX @sh) toList = shxToList diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 2e0c1ca..c3d2075 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -214,8 +214,8 @@ ssxDropLen = coerce (listhDropLen @()) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) ssxPermute = coerce (listhPermute @()) -ssxIndex :: SNat i -> StaticShX sh -> SMayNat () (Index i sh) -ssxIndex i = coerce (listhIndex @() i) +ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh) +ssxIndex k = coerce (listhIndex @() k) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) ssxPermutePrefix = coerce (listhPermutePrefix @()) @@ -229,8 +229,8 @@ shxDropLen = coerce (listhDropLen @Int) shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) shxPermute = coerce (listhPermute @Int) -shxIndex :: SNat i -> IShX sh -> SMayNat Int (Index i sh) -shxIndex i = coerce (listhIndex @Int i) +shxIndex :: forall k sh i. SNat k -> ShX sh i -> SMayNat i (Index k sh) +shxIndex k = coerce (listhIndex @i k) shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) shxPermutePrefix = coerce (listhPermutePrefix @Int) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 97a5f6f..5c696f3 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -32,10 +32,6 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -#ifndef OXAR_DEFAULT_SHOW_INSTANCES -import Data.Foldable (toList) -#endif - import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape @@ -65,7 +61,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) #ifndef OXAR_DEFAULT_SHOW_INSTANCES instance (Show a, Elt a) => Show (Ranked n a) where showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) + let sh = show (shrToList (rshape arr)) in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr #endif diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 36f49dc..6ce0f4f 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -41,7 +41,9 @@ import GHC.TypeLits import GHC.TypeNats qualified as TN import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Mixed.Shape.Internal +import Data.Array.Nested.Permutation import Data.Array.Nested.Types @@ -180,7 +182,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) = listrZipWith _ _ _ = error "listrZipWith: impossible pattern needlessly required" -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) +listrSplitAt SZ sh = (ZR, sh) +listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) +listrSplitAt SS{} ZR = error "m' + 1 <= 0" + +listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> case listrRank sh of { shlen@SNat -> @@ -192,11 +199,6 @@ listrPermutePrefix = \perm sh -> ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" } where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i applyPermRFull _ ZR _ = ZR applyPermRFull sm@SNat (i ::: perm) l = @@ -282,7 +284,7 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) -- | Given a multidimensional index, get the corresponding linear @@ -303,18 +305,34 @@ ixrToLinear = \sh i -> go sh i 0 type role ShR nominal representational type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, NFData, Functor, Foldable) +newtype ShR n i = ShR (ShX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor) pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR +pattern ZSR <- ShR (matchZSR @n -> Just Refl) + where ZSR = ShR ZSX + +matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZSR _ = Nothing pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) +pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i)) + where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl + = ShR (SUnknown i :$% shl) + +data UnconsShRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i +shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1) +shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i))) + | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl + = Just (UnconsShRRes (ShR sh') x) +shrUncons (ShR _) = Nothing + infixr 3 :$: {-# COMPLETE ZSR, (:$:) #-} @@ -325,67 +343,125 @@ type IShR n = ShR n Int deriving instance Show i => Show (ShR n i) #else instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l + showsPrec d (ShR l) = showsPrec d l #endif -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' +shrEqRank ZSR ZSR = Just Refl +shrEqRank (_ :$: sh) (_ :$: sh') + | Just Refl <- shrEqRank sh sh' + = Just Refl +shrEqRank _ _ = Nothing -- | This compares the shapes for value equality. shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' +shrEqual ZSR ZSR = Just Refl +shrEqual (i :$: sh) (i' :$: sh') + | Just Refl <- shrEqual sh sh' + , i == i' + = Just Refl +shrEqual _ _ = Nothing shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l +shrLength (ShR l) = shxLength l -- | This function can also be used to conjure up a 'KnownNat' dictionary; -- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern -- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh +shrRank :: forall n i. ShR n i -> SNat n +shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh -- | The number of elements in an array described by this shape. shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh +shrSize (ShR sh) = shxSize sh -shrFromList :: forall n i. SNat n -> [i] -> ShR n i -shrFromList = coerce (listrFromList @_ @i) +shrFromList :: SNat n -> [Int] -> IShR n +shrFromList snat = coerce (shxFromList (ssxReplicate snat)) {-# INLINEABLE shrToList #-} -shrToList :: forall n i. ShR n i -> [i] -shrToList = coerce (listrToList @_ @i) - -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list - -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) - -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) - -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list +shrToList :: IShR n -> [Int] +shrToList = coerce shxToList + +shrHead :: forall n i. ShR (n + 1) i -> i +shrHead (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxHead @Nothing @(Replicate n Nothing) sh of + SUnknown i -> i + +shrTail :: forall n i. ShR (n + 1) i -> ShR n i +shrTail + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (shxTail @_ @_ @i) + +shrInit :: forall n i. ShR (n + 1) i -> ShR n i +shrInit + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = -- TODO: change this and all other unsafeCoerceRefl to lemmas: + gcastWith (unsafeCoerceRefl + :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $ + coerce (shxInit @_ @_ @i) + +shrLast :: forall n i. ShR (n + 1) i -> i +shrLast (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxLast sh of + SUnknown i -> i + SKnown{} -> error "shrLast: impossible SKnown" -- | Performs a runtime check that the lengths are identical. shrCast :: SNat n' -> ShR n i -> ShR n' i -shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) +shrCast SZ ZSR = ZSR +shrCast (SS n) (i :$: sh) = i :$: shrCast n sh +shrCast _ _ = error "shrCast: ranks don't match" shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 +shrAppend = + -- lemReplicatePlusApp requires an SNat + gcastWith (unsafeCoerceRefl + :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $ + coerce (shxAppend @_ @_ @i) {-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 - -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) +shrZipWith _ ZSR ZSR = ZSR +shrZipWith f (i :$: irest) (j :$: jrest) = + f i j :$: shrZipWith f irest jrest +shrZipWith _ _ _ = + error "shrZipWith: impossible pattern needlessly required" + +shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i) +shrSplitAt SZ sh = (ZSR, sh) +shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh) +shrSplitAt SS{} ZSR = error "m' + 1 <= 0" + +shrIndex :: forall k sh i. SNat k -> ShR sh i -> i +shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of + SUnknown i -> i + SKnown{} -> error "shrIndex: impossible SKnown" + +-- Copy-pasted from listrPermutePrefix, probably unavoidably. +shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i +shrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case shrRank sh of { shlen@SNat -> + let sperm = shrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } + where + applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i + applyPermRFull _ ZSR _ = ZSR + applyPermRFull sm@SNat (i :$: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> shrIndex si l :$: applyPermRFull sm perm l + EQI -> shrIndex si l :$: applyPermRFull sm perm l + GTI -> error "shrPermutePrefix: Index in permutation out of range" shrEnum :: IShR sh -> [IIxR sh] shrEnum = shrEnum' @@ -417,17 +493,17 @@ instance KnownNat n => IsList (IxR n i) where toList = Foldable.toList -- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +instance KnownNat n => IsList (IShR n) where + type Item (IShR n) = Int + fromList = shrFromList (SNat @n) + toList = shrToList -- * Internal helper functions listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i listrCastWithName _ SZ ZR = ZR -listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l listrCastWithName name _ _ = error $ name ++ ": ranks don't match" $(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |]) diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index a43ae0c..5b084e9 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -110,7 +110,7 @@ type family Replicate n a where Replicate n a = a : Replicate (n - 1) a lemReplicateSucc :: forall a n proxy. - proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a + proxy n -> a : Replicate n a :~: Replicate (n + 1) a lemReplicateSucc _ = unsafeCoerceRefl type family MapJust l = r | r -> l where diff --git a/test/Gen.hs b/test/Gen.hs index 952e8db..789a59c 100644 --- a/test/Gen.hs +++ b/test/Gen.hs @@ -11,7 +11,6 @@ module Gen where import Data.ByteString qualified as BS -import Data.Foldable (toList) import Data.Type.Equality import Data.Type.Ord import Data.Vector.Storable qualified as VS @@ -46,7 +45,7 @@ genLowBiased (lo, hi) = do return (lo + x * x * x * (hi - lo)) shuffleShR :: IShR n -> Gen (IShR n) -shuffleShR = \sh -> go (length sh) (toList sh) sh +shuffleShR = \sh -> go (shrLength sh) (shrToList sh) sh where go :: Int -> [Int] -> IShR n -> Gen (IShR n) go _ _ ZSR = return ZSR @@ -78,7 +77,7 @@ genShRwithTarget targetMax sn = do dims <- genDims m (if dim == 0 then 0 else tgt `div` dim) return (dim :$: dims) dims <- genDims sn targetSize - let maxdim = maximum dims + let maxdim = maximum $ shrToList dims cap = binarySearch (`div` 2) 1 maxdim (\cap' -> shrSize (min cap' <$> dims) <= targetSize) shuffleShR (min cap <$> dims) diff --git a/test/Tests/C.hs b/test/Tests/C.hs index 0656107..8703957 100644 --- a/test/Tests/C.hs +++ b/test/Tests/C.hs @@ -15,7 +15,6 @@ module Tests.C where import Control.Monad import Data.Array.RankedS qualified as OR -import Data.Foldable (toList) import Data.Functor.Const import Data.Type.Equality import Foreign @@ -50,10 +49,10 @@ prop_sum_nonempty = property $ genRank $ \outrank@(SNat @n) -> do -- Test nonempty _results_. The first dimension of the input is allowed to be 0, because then OR.rerank doesn't fail yet. let inrank = SNat @(n + 1) sh <- forAll $ genShR inrank - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - guard (all (> 0) (shrTail sh)) -- only constrain the tail - arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList sh) <$> - genStorables (Range.singleton (product sh)) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (shrSize sh)) + guard (all (> 0) (shrToList $ shrTail sh)) -- only constrain the tail + arr <- forAllT $ OR.fromVector @Double @(n + 1) (shrToList sh) <$> + genStorables (Range.singleton (shrSize sh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) let rarr = rfromOrthotope inrank arr almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) @@ -68,9 +67,9 @@ prop_sum_empty = property $ genRank $ \outrankm1@(SNat @nm1) -> do sht <- shuffleShR (0 :$: shtt) -- n n <- Gen.int (Range.linear 0 20) return (n :$: sht) -- n + 1 - guard (0 `elem` shrTail sh) - -- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh)) - let arr = OR.fromList @(n + 1) @Double (toList sh) [] + guard (0 `elem` (shrToList $ shrTail sh)) + -- traceM ("sh: " ++ show sh ++ " -> " ++ show (shrSize sh)) + let arr = OR.fromList @(n + 1) @Double (shrToList sh) [] let rarr = rfromOrthotope inrank arr OR.toList (rtoOrthotope (rsumOuter1Prim rarr)) === [] @@ -78,10 +77,10 @@ prop_sum_lasteq1 :: Property prop_sum_lasteq1 = property $ genRank $ \outrank@(SNat @n) -> do let inrank = SNat @(n + 1) outsh <- forAll $ genShR outrank - guard (all (> 0) outsh) + guard (all (> 0) $ shrToList outsh) let insh = shrAppend outsh (1 :$: ZSR) - arr <- forAllT $ OR.fromVector @Double @(n + 1) (toList insh) <$> - genStorables (Range.singleton (product insh)) + arr <- forAllT $ OR.fromVector @Double @(n + 1) (shrToList insh) <$> + genStorables (Range.singleton (shrSize insh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) let rarr = rfromOrthotope inrank arr almostEq fineTol (rtoOrthotope (rsumOuter1Prim rarr)) (orSumOuter1 outrank arr) @@ -99,12 +98,12 @@ prop_sum_replicated doTranspose = property $ label (LabelName ("rankdiff " ++ show (fromSNat' inrank2 - fromSNat' inrank1))) label (LabelName ("size sh1 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh1) :: Double)) :: Int))) label (LabelName ("size sh3 10^" ++ show (floor (logBase 10 (fromIntegral (shrSize sh3) :: Double)) :: Int))) - guard (all (> 0) sh3) + guard (all (> 0) $ shrToList sh3) arr <- forAllT $ - OR.stretch (toList sh3) - . OR.reshape (toList sh2) - . OR.fromVector @Double @m (toList sh1) <$> - genStorables (Range.singleton (product sh1)) + OR.stretch (shrToList sh3) + . OR.reshape (shrToList sh2) + . OR.fromVector @Double @m (shrToList sh1) <$> + genStorables (Range.singleton (shrSize sh1)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) arrTrans <- if doTranspose then do perm <- forAll $ genPermR (fromSNat' inrank2) @@ -121,9 +120,9 @@ prop_negate_with :: forall f b. Show b prop_negate_with genRank' genB preproc = property $ genRank' $ \extra rank@(SNat @n) -> do sh <- forAll $ genShR rank - guard (all (> 0) sh) - arr <- forAllT $ OR.fromVector @Double @n (toList sh) <$> - genStorables (Range.singleton (product sh)) + guard (all (> 0) $ shrToList sh) + arr <- forAllT $ OR.fromVector @Double @n (shrToList sh) <$> + genStorables (Range.singleton (shrSize sh)) (\w -> fromIntegral w / fromIntegral (maxBound :: Word64)) bval <- forAll $ genB extra sh let arr' = preproc extra bval arr @@ -156,7 +155,7 @@ tests = testGroup "C" (\_ sh -> do let genPair n = do lo <- Gen.integral (Range.constant 0 (n-1)) len <- Gen.integral (Range.constant 0 (n-lo-1)) return (lo, len) - pairs <- mapM genPair (toList sh) + pairs <- mapM genPair (shrToList sh) return pairs) (\_ -> OR.slice) ] -- cgit v1.2.3-70-g09d2