aboutsummaryrefslogtreecommitdiff
path: root/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data')
-rw-r--r--src/Data/Array/Nested/Convert.hs3
-rw-r--r--src/Data/Array/Nested/Lemmas.hs14
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs21
-rw-r--r--src/Data/Array/Nested/Permutation.hs8
-rw-r--r--src/Data/Array/Nested/Ranked/Base.hs6
-rw-r--r--src/Data/Array/Nested/Ranked/Shape.hs166
-rw-r--r--src/Data/Array/Nested/Types.hs2
7 files changed, 154 insertions, 66 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
index 3706105..d4d1cea 100644
--- a/src/Data/Array/Nested/Convert.hs
+++ b/src/Data/Array/Nested/Convert.hs
@@ -57,6 +57,9 @@ import Data.Array.Nested.Types
-- * To ranked
+-- TODO: change all those unsafeCoerces into coerces by defining shaped
+-- and ranekd index types as newtypes of the mixed index type
+-- and similarly for the sized lists
ixrFromIxS :: IxS sh i -> IxR (Rank sh) i
ixrFromIxS = unsafeCoerce
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
index e089479..fa5611b 100644
--- a/src/Data/Array/Nested/Lemmas.hs
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -56,6 +56,20 @@ lemReplicatePlusApp sn _ _ = go sn
-}
lemReplicatePlusApp _ _ _ = unsafeCoerceRefl
+lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0
+lemReplicateEmpty _ Refl = unsafeCoerceRefl
+
+-- TODO: make less ad-hoc and rename these three:
+lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1
+lemReplicateCons _ _ Refl = unsafeCoerceRefl
+
+lemReplicateCons2 :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> sh :~: Replicate (Rank sh) Nothing
+lemReplicateCons2 _ _ Refl = unsafeCoerceRefl
+
+lemReplicateSucc2 :: forall n1 n proxy.
+ proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing
+lemReplicateSucc2 _ _ = unsafeCoerceRefl
+
lemDropLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
-> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index 7c79f8b..5ffd40c 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -549,11 +549,11 @@ shxFromList topssh topl = go topssh topl
{-# INLINEABLE shxToList #-}
shxToList :: IShX sh -> [Int]
-shxToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
+shxToList sh0 = build (\(cons :: i -> is -> is) (nil :: is) ->
let go :: IShX sh -> is
go ZSX = nil
go (smn :$% sh) = fromSMayNat' smn `cons` go sh
- in go list)
+ in go sh0)
-- If it ever matters for performance, this is unsafeCoercible.
shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i
@@ -578,6 +578,10 @@ shxHead (ShX list) = listhHead list
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listhTail list)
+shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
+shxTakeSSX _ ZKX _ = ZSX
+shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
+
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
shxDropSSX = coerce (listhDrop @i @())
@@ -594,10 +598,6 @@ shxInit = coerce (listhInit @i)
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh))
shxLast = coerce (listhLast @i)
-shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
-shxTakeSSX _ ZKX _ = ZSX
-shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
-
{-# INLINE shxZipWith #-}
shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n)
-> ShX sh i -> ShX sh j -> ShX sh k
@@ -690,14 +690,13 @@ ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxEqType = testEquality
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
-ssxAppend ZKX sh' = sh'
-ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+ssxAppend = coerce (listhAppend @_ @())
ssxHead :: StaticShX (n : sh) -> SMayNat () n
ssxHead (StaticShX list) = listhHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
-ssxTail (_ :!% ssh) = ssh
+ssxTail (StaticShX list) = StaticShX (listhTail list)
ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
ssxTakeIx _ (IxX ZX) _ = ZKX
@@ -795,8 +794,8 @@ instance KnownShX sh => IsList (IxX sh i) where
toList = Foldable.toList
-- | Untyped: length and known dimensions are checked (at runtime).
-instance KnownShX sh => IsList (ShX sh Int) where
- type Item (ShX sh Int) = Int
+instance KnownShX sh => IsList (IShX sh) where
+ type Item (IShX sh) = Int
fromList = shxFromList (knownShX @sh)
toList = shxToList
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 2e0c1ca..c3d2075 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -214,8 +214,8 @@ ssxDropLen = coerce (listhDropLen @())
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listhPermute @())
-ssxIndex :: SNat i -> StaticShX sh -> SMayNat () (Index i sh)
-ssxIndex i = coerce (listhIndex @() i)
+ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh)
+ssxIndex k = coerce (listhIndex @() k)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
ssxPermutePrefix = coerce (listhPermutePrefix @())
@@ -229,8 +229,8 @@ shxDropLen = coerce (listhDropLen @Int)
shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh)
shxPermute = coerce (listhPermute @Int)
-shxIndex :: SNat i -> IShX sh -> SMayNat Int (Index i sh)
-shxIndex i = coerce (listhIndex @Int i)
+shxIndex :: forall k sh i. SNat k -> ShX sh i -> SMayNat i (Index k sh)
+shxIndex k = coerce (listhIndex @i k)
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
shxPermutePrefix = coerce (listhPermutePrefix @Int)
diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs
index 97a5f6f..5c696f3 100644
--- a/src/Data/Array/Nested/Ranked/Base.hs
+++ b/src/Data/Array/Nested/Ranked/Base.hs
@@ -32,10 +32,6 @@ import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp)
import GHC.Generics (Generic)
import GHC.TypeLits
-#ifndef OXAR_DEFAULT_SHOW_INSTANCES
-import Data.Foldable (toList)
-#endif
-
import Data.Array.Nested.Lemmas
import Data.Array.Nested.Mixed
import Data.Array.Nested.Mixed.Shape
@@ -65,7 +61,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a)
#ifndef OXAR_DEFAULT_SHOW_INSTANCES
instance (Show a, Elt a) => Show (Ranked n a) where
showsPrec d arr@(Ranked marr) =
- let sh = show (toList (rshape arr))
+ let sh = show (shrToList (rshape arr))
in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr
#endif
diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs
index 36f49dc..6ce0f4f 100644
--- a/src/Data/Array/Nested/Ranked/Shape.hs
+++ b/src/Data/Array/Nested/Ranked/Shape.hs
@@ -41,7 +41,9 @@ import GHC.TypeLits
import GHC.TypeNats qualified as TN
import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Mixed.Shape.Internal
+import Data.Array.Nested.Permutation
import Data.Array.Nested.Types
@@ -180,7 +182,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) =
listrZipWith _ _ _ =
error "listrZipWith: impossible pattern needlessly required"
-listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+listrSplitAt SZ sh = (ZR, sh)
+listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
case listrRank sh of { shlen@SNat ->
@@ -192,11 +199,6 @@ listrPermutePrefix = \perm sh ->
++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
}
where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
applyPermRFull _ ZR _ = ZR
applyPermRFull sm@SNat (i ::: perm) l =
@@ -282,7 +284,7 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2
-ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)
-- | Given a multidimensional index, get the corresponding linear
@@ -303,18 +305,34 @@ ixrToLinear = \sh i -> go sh i 0
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord, NFData, Functor, Foldable)
+newtype ShR n i = ShR (ShX (Replicate n Nothing) i)
+ deriving (Eq, Ord, NFData, Functor)
pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
+pattern ZSR <- ShR (matchZSR @n -> Just Refl)
+ where ZSR = ShR ZSX
+
+matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0)
+matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl
+matchZSR _ = Nothing
pattern (:$:)
:: forall {n1} {i}.
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: ShR sh = ShR (i ::: sh)
+pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i))
+ where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl
+ = ShR (SUnknown i :$% shl)
+
+data UnconsShRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i
+shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1)
+shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i)))
+ | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl
+ , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl
+ = Just (UnconsShRRes (ShR sh') x)
+shrUncons (ShR _) = Nothing
+
infixr 3 :$:
{-# COMPLETE ZSR, (:$:) #-}
@@ -325,67 +343,125 @@ type IShR n = ShR n Int
deriving instance Show i => Show (ShR n i)
#else
instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = listrShow shows l
+ showsPrec d (ShR l) = showsPrec d l
#endif
-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+shrEqRank ZSR ZSR = Just Refl
+shrEqRank (_ :$: sh) (_ :$: sh')
+ | Just Refl <- shrEqRank sh sh'
+ = Just Refl
+shrEqRank _ _ = Nothing
-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
-shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+shrEqual ZSR ZSR = Just Refl
+shrEqual (i :$: sh) (i' :$: sh')
+ | Just Refl <- shrEqual sh sh'
+ , i == i'
+ = Just Refl
+shrEqual _ _ = Nothing
shrLength :: ShR sh i -> Int
-shrLength (ShR l) = listrLength l
+shrLength (ShR l) = shxLength l
-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
-shrRank :: ShR n i -> SNat n
-shrRank (ShR sh) = listrRank sh
+shrRank :: forall n i. ShR n i -> SNat n
+shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh
-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
-shrSize ZSR = 1
-shrSize (n :$: sh) = n * shrSize sh
+shrSize (ShR sh) = shxSize sh
-shrFromList :: forall n i. SNat n -> [i] -> ShR n i
-shrFromList = coerce (listrFromList @_ @i)
+shrFromList :: SNat n -> [Int] -> IShR n
+shrFromList snat = coerce (shxFromList (ssxReplicate snat))
{-# INLINEABLE shrToList #-}
-shrToList :: forall n i. ShR n i -> [i]
-shrToList = coerce (listrToList @_ @i)
+shrToList :: IShR n -> [Int]
+shrToList = coerce shxToList
-shrHead :: ShR (n + 1) i -> i
-shrHead (ShR list) = listrHead list
+shrHead :: forall n i. ShR (n + 1) i -> i
+shrHead (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxHead @Nothing @(Replicate n Nothing) sh of
+ SUnknown i -> i
-shrTail :: ShR (n + 1) i -> ShR n i
-shrTail (ShR list) = ShR (listrTail list)
+shrTail :: forall n i. ShR (n + 1) i -> ShR n i
+shrTail
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = coerce (shxTail @_ @_ @i)
-shrInit :: ShR (n + 1) i -> ShR n i
-shrInit (ShR list) = ShR (listrInit list)
+shrInit :: forall n i. ShR (n + 1) i -> ShR n i
+shrInit
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = -- TODO: change this and all other unsafeCoerceRefl to lemmas:
+ gcastWith (unsafeCoerceRefl
+ :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $
+ coerce (shxInit @_ @_ @i)
-shrLast :: ShR (n + 1) i -> i
-shrLast (ShR list) = listrLast list
+shrLast :: forall n i. ShR (n + 1) i -> i
+shrLast (ShR sh)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n)
+ = case shxLast sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrLast: impossible SKnown"
-- | Performs a runtime check that the lengths are identical.
shrCast :: SNat n' -> ShR n i -> ShR n' i
-shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh)
+shrCast SZ ZSR = ZSR
+shrCast (SS n) (i :$: sh) = i :$: shrCast n sh
+shrCast _ _ = error "shrCast: ranks don't match"
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
-shrAppend = coerce (listrAppend @_ @i)
-
-shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
-shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2
+shrAppend =
+ -- lemReplicatePlusApp requires an SNat
+ gcastWith (unsafeCoerceRefl
+ :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $
+ coerce (shxAppend @_ @_ @i)
{-# INLINE shrZipWith #-}
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
-shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2
+shrZipWith _ ZSR ZSR = ZSR
+shrZipWith f (i :$: irest) (j :$: jrest) =
+ f i j :$: shrZipWith f irest jrest
+shrZipWith _ _ _ =
+ error "shrZipWith: impossible pattern needlessly required"
+
+shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i)
+shrSplitAt SZ sh = (ZSR, sh)
+shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh)
+shrSplitAt SS{} ZSR = error "m' + 1 <= 0"
-shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
-shrPermutePrefix = coerce (listrPermutePrefix @i)
+shrIndex :: forall k sh i. SNat k -> ShR sh i -> i
+shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of
+ SUnknown i -> i
+ SKnown{} -> error "shrIndex: impossible SKnown"
+
+-- Copy-pasted from listrPermutePrefix, probably unavoidably.
+shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i
+shrPermutePrefix = \perm sh ->
+ TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat ->
+ case shrRank sh of { shlen@SNat ->
+ let sperm = shrFromList permlen perm in
+ case cmpNat permlen shlen of
+ LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ }
+ where
+ applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i
+ applyPermRFull _ ZSR _ = ZSR
+ applyPermRFull sm@SNat (i :$: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> shrIndex si l :$: applyPermRFull sm perm l
+ EQI -> shrIndex si l :$: applyPermRFull sm perm l
+ GTI -> error "shrPermutePrefix: Index in permutation out of range"
shrEnum :: IShR sh -> [IIxR sh]
shrEnum = shrEnum'
@@ -417,17 +493,17 @@ instance KnownNat n => IsList (IxR n i) where
toList = Foldable.toList
-- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
+instance KnownNat n => IsList (IShR n) where
+ type Item (IShR n) = Int
+ fromList = shrFromList (SNat @n)
+ toList = shrToList
-- * Internal helper functions
listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
listrCastWithName _ SZ ZR = ZR
-listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx
+listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l
listrCastWithName name _ _ = error $ name ++ ": ranks don't match"
$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |])
diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs
index a43ae0c..5b084e9 100644
--- a/src/Data/Array/Nested/Types.hs
+++ b/src/Data/Array/Nested/Types.hs
@@ -110,7 +110,7 @@ type family Replicate n a where
Replicate n a = a : Replicate (n - 1) a
lemReplicateSucc :: forall a n proxy.
- proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a
+ proxy n -> a : Replicate n a :~: Replicate (n + 1) a
lemReplicateSucc _ = unsafeCoerceRefl
type family MapJust l = r | r -> l where