summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-13 22:47:42 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-13 22:47:42 +0200
commite4e23a33f77d250af1e9b6614cf249128ba1510a (patch)
tree34bb40910003749becbaf8005a7b7ca62024fff2
parent7c9865354442326d55094087ad6a74b6e96341fb (diff)
Shape/index hygiene
-rw-r--r--cabal.project5
-rw-r--r--src/Data/Array/Mixed.hs246
-rw-r--r--src/Data/Array/Nested.hs2
-rw-r--r--src/Data/Array/Nested/Internal.hs239
-rw-r--r--src/Data/INat.hs1
-rw-r--r--test/Main.hs6
6 files changed, 264 insertions, 235 deletions
diff --git a/cabal.project b/cabal.project
index a13761a..697d3bd 100644
--- a/cabal.project
+++ b/cabal.project
@@ -1,2 +1,5 @@
packages: .
-with-compiler: ghc-9.6.4
+with-compiler: ghc-9.8.2
+
+allow-newer:
+ orthotope:deepseq
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 47027cb..94b7cdf 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -1,6 +1,10 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -32,6 +36,9 @@ pattern GHC_SNat :: () => KnownNat n => SNat n
pattern GHC_SNat = SNat
{-# COMPLETE GHC_SNat #-}
+fromSNat' :: SNat n -> Int
+fromSNat' = fromIntegral . fromSNat
+
-- | Type-level list append.
type family l1 ++ l2 where
@@ -44,9 +51,6 @@ lemAppNil = unsafeCoerce Refl
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
--- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype
--- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)?
-
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
ZIX :: IxX '[] i
@@ -55,31 +59,48 @@ data IxX sh i where
deriving instance Show i => Show (IxX sh i)
deriving instance Eq i => Eq (IxX sh i)
deriving instance Ord i => Ord (IxX sh i)
+deriving instance Functor (IxX sh)
+deriving instance Foldable (IxX sh)
infixr 3 :.@
infixr 3 :.?
type IIxX sh = IxX sh Int
--- | The part of a shape that is statically known.
-type StaticShapeX :: [Maybe Nat] -> Type
-data StaticShapeX sh where
- ZSX :: StaticShapeX '[]
- (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
- (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
-deriving instance Show (StaticShapeX sh)
+type ShX :: [Maybe Nat] -> Type -> Type
+data ShX sh i where
+ ZSX :: ShX '[] i
+ (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i
+ (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i
+deriving instance Show i => Show (ShX sh i)
+deriving instance Eq i => Eq (ShX sh i)
+deriving instance Ord i => Ord (ShX sh i)
+deriving instance Functor (ShX sh)
+deriving instance Foldable (ShX sh)
infixr 3 :$@
infixr 3 :$?
+type IShX sh = ShX sh Int
+
+-- | The part of a shape that is statically known.
+type StaticShX :: [Maybe Nat] -> Type
+data StaticShX sh where
+ ZKSX :: StaticShX '[]
+ (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh)
+ (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh)
+deriving instance Show (StaticShX sh)
+infixr 3 :!$@
+infixr 3 :!$?
+
-- | Evidence for the static part of a shape.
type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
- knownShapeX :: StaticShapeX sh
+ knownShapeX :: StaticShX sh
instance KnownShapeX '[] where
- knownShapeX = ZSX
+ knownShapeX = ZKSX
instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = natSing :$@ knownShapeX
+ knownShapeX = natSing :!$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
- knownShapeX = () :$? knownShapeX
+ knownShapeX = () :!$? knownShapeX
type family Rank sh where
Rank '[] = Z
@@ -89,138 +110,149 @@ type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
deriving (Show)
-zeroIxX :: StaticShapeX sh -> IIxX sh
-zeroIxX ZSX = ZIX
-zeroIxX (_ :$@ ssh) = 0 :.@ zeroIxX ssh
-zeroIxX (_ :$? ssh) = 0 :.? zeroIxX ssh
+zeroIxX :: StaticShX sh -> IIxX sh
+zeroIxX ZKSX = ZIX
+zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh
+zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh
+
+zeroIxX' :: IShX sh -> IIxX sh
+zeroIxX' ZSX = ZIX
+zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh
+zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh
-zeroIxX' :: IIxX sh -> IIxX sh
-zeroIxX' ZIX = ZIX
-zeroIxX' (_ :.@ sh) = 0 :.@ zeroIxX' sh
-zeroIxX' (_ :.? sh) = 0 :.? zeroIxX' sh
+-- This is a weird operation, so it has a long name
+completeShXzeros :: StaticShX sh -> IShX sh
+completeShXzeros ZKSX = ZSX
+completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh
+completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend ZIX idx' = idx'
ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'
ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx'
+shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
+shAppend ZSX sh' = sh'
+shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh'
+shAppend (n :$? sh) sh' = n :$? shAppend sh sh'
+
ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
ixDrop sh ZIX = sh
ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx
ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx
-ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
-ssxAppend ZSX sh' = sh'
-ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh'
-ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh'
+ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
+ssxAppend ZKSX sh' = sh'
+ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh'
+ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh'
-shapeSize :: IIxX sh -> Int
-shapeSize ZIX = 1
-shapeSize (n :.@ sh) = n * shapeSize sh
-shapeSize (n :.? sh) = n * shapeSize sh
+shapeSize :: IShX sh -> Int
+shapeSize ZSX = 1
+shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh
+shapeSize (n :$? sh) = n * shapeSize sh
-- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh)
-ssxToShape' ZSX = Just ZIX
-ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) :.@) <$> ssxToShape' sh
-ssxToShape' (_ :$? _) = Nothing
+ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
+ssxToShape' ZKSX = Just ZSX
+ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
+ssxToShape' (_ :!$? _) = Nothing
-fromLinearIdx :: IIxX sh -> Int -> IIxX sh
+fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
_ -> error $ "fromLinearIdx: out of range (" ++ show i ++
" in array of shape " ++ show sh ++ ")"
where
-- returns (index in subarray, remaining index in enclosing array)
- go :: IIxX sh -> Int -> (IIxX sh, Int)
- go ZIX i = (ZIX, i)
- go (n :.@ sh) i =
+ go :: IShX sh -> Int -> (IIxX sh, Int)
+ go ZSX i = (ZIX, i)
+ go (n :$@ sh) i =
let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` n
+ (upi, locali) = i' `quotRem` fromSNat' n
in (locali :.@ idx, upi)
- go (n :.? sh) i =
+ go (n :$? sh) i =
let (idx, i') = go sh i
(upi, locali) = i' `quotRem` n
in (locali :.? idx, upi)
-toLinearIdx :: IIxX sh -> IIxX sh -> Int
+toLinearIdx :: IShX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
where
-- returns (index in subarray, size of subarray)
- go :: IIxX sh -> IIxX sh -> (Int, Int)
- go ZIX ZIX = (0, 1)
- go (n :.@ sh) (i :.@ ix) =
+ go :: IShX sh -> IIxX sh -> (Int, Int)
+ go ZSX ZIX = (0, 1)
+ go (n :$@ sh) (i :.@ ix) =
let (lidx, sz) = go sh ix
- in (sz * i + lidx, n * sz)
- go (n :.? sh) (i :.? ix) =
+ in (sz * i + lidx, fromSNat' n * sz)
+ go (n :$? sh) (i :.? ix) =
let (lidx, sz) = go sh ix
in (sz * i + lidx, n * sz)
-enumShape :: IIxX sh -> [IIxX sh]
+enumShape :: IShX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
where
- go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZIX f = (f ZIX :)
- go (n :.@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. n-1]]
- go (n :.? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]
-
-shapeLshape :: IIxX sh -> S.ShapeL
-shapeLshape ZIX = []
-shapeLshape (n :.@ sh) = n : shapeLshape sh
-shapeLshape (n :.? sh) = n : shapeLshape sh
-
-ssxLength :: StaticShapeX sh -> Int
-ssxLength ZSX = 0
-ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
-ssxLength (_ :$? ssh) = 1 + ssxLength ssh
-
-ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
-ssxIotaFrom _ ZSX = []
-ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
-ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh
-
-lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
+ go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
+ go ZSX f = (f ZIX :)
+ go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]]
+ go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]
+
+shapeLshape :: IShX sh -> S.ShapeL
+shapeLshape ZSX = []
+shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh
+shapeLshape (n :$? sh) = n : shapeLshape sh
+
+ssxLength :: StaticShX sh -> Int
+ssxLength ZKSX = 0
+ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh
+ssxLength (_ :!$? ssh) = 1 + ssxLength ssh
+
+ssxIotaFrom :: Int -> StaticShX sh -> [Int]
+ssxIotaFrom _ ZKSX = []
+ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
+
+lemRankApp :: StaticShX sh1 -> StaticShX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
-lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
+lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank ZIX = Dict
-lemKnownINatRank (_ :.@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ :.? sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh)
+lemKnownINatRank ZSX = Dict
+lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX ZSX = Dict
-lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh)
+lemKnownINatRankSSX ZKSX = Dict
+lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
-lemKnownShapeX ZSX = Dict
-lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
+lemKnownShapeX ZKSX = Dict
+lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
-lemAppKnownShapeX ZSX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh'
+lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
+lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh'
+lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
= Dict
-lemAppKnownShapeX (() :$? ssh) ssh'
+lemAppKnownShapeX (() :!$? ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
= Dict
-shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh
+shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh
shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
- go :: StaticShapeX sh' -> [Int] -> IIxX sh'
- go ZSX [] = ZIX
- go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) :.@ go ssh l
- go (() :$? ssh) (n : l) = n :.? go ssh l
+ go :: StaticShX sh' -> [Int] -> IShX sh'
+ go ZKSX [] = ZSX
+ go (n :!$@ ssh) (_ : l) = n :$@ go ssh l
+ go (() :!$? ssh) (n : l) = n :$? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a
+fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
@@ -235,16 +267,16 @@ scalar = XArray . S.scalar
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
-constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a
+constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.constant (shapeLshape sh) x)
-generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a
+generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
--- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a)
+-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownINatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
@@ -274,7 +306,7 @@ append (XArray a) (XArray b)
rerank :: forall sh sh1 sh2 a b.
(Storable a, Storable b)
- => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
@@ -294,14 +326,14 @@ rerank ssh ssh1 ssh2 f (XArray arr)
rerankTop :: forall sh sh1 sh2 a b.
(Storable a, Storable b)
- => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh
+ => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
rerank2 :: forall sh sh1 sh2 a b c.
(Storable a, Storable b, Storable c)
- => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
@@ -327,7 +359,7 @@ transpose perm (XArray arr)
= XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
- StaticShapeX sh1 -> StaticShapeX sh2
+ StaticShX sh1 -> StaticShX sh2
-> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
| Refl <- lemRankApp ssh1 ssh2
@@ -344,24 +376,24 @@ sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr
sumInner :: forall sh sh' a. (Storable a, Num a)
- => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
+ => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
| Refl <- lemAppNil @sh
- = rerank ssh ssh' ZSX (scalar . sumFull)
+ = rerank ssh ssh' ZKSX (scalar . sumFull)
sumOuter :: forall sh sh' a. (Storable a, Num a)
- => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
+ => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh'
| Refl <- lemAppNil @sh
= sumInner ssh' ssh . transpose2 ssh ssh'
fromList :: forall n sh a. Storable a
- => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
+ => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList ssh l
| Dict <- lemKnownINatRankSSX ssh
, Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
= case ssh of
- m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) ->
+ m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
"does not match the type (" ++ show (natVal m) ++ ")"
_ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
@@ -369,5 +401,13 @@ fromList ssh l
toList :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList (XArray arr) = coerce (ORB.toList (S.unravel arr))
+-- | Throws if the given shape is not, in fact, empty.
+empty :: forall sh a. Storable a => IShX sh -> XArray sh a
+empty sh
+ | Dict <- lemKnownINatRank sh
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ = XArray (S.constant (shapeLshape sh)
+ (error "Data.Array.Mixed.empty: shape was not empty"))
+
slice :: [(Int, Int)] -> XArray sh a -> XArray sh a
slice ivs (XArray arr) = XArray (S.slice ivs arr)
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 370e30a..145598e 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -28,7 +28,7 @@ module Data.Array.Nested (
-- * Mixed arrays
Mixed,
IxX(..), IIxX,
- KnownShapeX(..), StaticShapeX(..),
+ KnownShapeX(..), StaticShX(..),
mgenerate, mtranspose, mappend, mfromVector, munScalar,
mconstant, mfromList1, mtoList1, mslice,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 49ed7cb..2f1e79e 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -8,12 +8,12 @@
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -23,7 +23,7 @@
{-|
TODO:
-* We should be more consistent in whether functions take a 'StaticShapeX'
+* We should be more consistent in whether functions take a 'StaticShX'
argument or a 'KnownShapeX' constraint.
* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
@@ -51,7 +51,7 @@ import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)
import qualified Data.Array.Mixed as X
import Data.INat
@@ -100,9 +100,9 @@ type family MapJust l where
lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
where
- go :: SINat m -> StaticShapeX (Replicate m Nothing)
- go SZ = ZSX
- go (SS n) = () :$? go n
+ go :: SINat m -> StaticShX (Replicate m Nothing)
+ go SZ = ZKSX
+ go (SS n) = () :!$? go n
lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate _ = go (inatSing @n)
@@ -119,10 +119,10 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n)
go SZ = Refl
go (SS n) | Refl <- go n = Refl
-ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh')
-ixAppSplit _ ZSX idx = (ZIX, idx)
-ixAppSplit p (_ :$@ ssh) (i :.@ idx) = first (i :.@) (ixAppSplit p ssh idx)
-ixAppSplit p (_ :$? ssh) (i :.? idx) = first (i :.?) (ixAppSplit p ssh idx)
+shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
+shAppSplit _ ZKSX idx = (ZSX, idx)
+shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx)
+shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx)
-- | Wrapper type used as a tag to attach instances on. The instances on arrays
@@ -184,7 +184,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MV
data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
-- etc.
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
-- | Tree giving the shape of every array component.
@@ -196,9 +196,9 @@ type family ShapeTree a where
ShapeTree () = ()
ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a)
- ShapeTree (Ranked n a) = (IIxR n, ShapeTree a)
- ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a)
+ ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a)
+ ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
+ ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
@@ -207,7 +207,7 @@ type family ShapeTree a where
class Elt a where
-- ====== PUBLIC METHODS ====== --
- mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh
+ mshape :: KnownShapeX sh => Mixed sh a -> IShX sh
mindex :: Mixed sh a -> IIxX sh -> a
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
@@ -240,15 +240,12 @@ class Elt a where
-> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
-- ====== PRIVATE METHODS ====== --
- -- Remember I said that this module needed better management of exports?
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IIxX sh -> Mixed sh a
+ memptyArray :: IShX sh -> Mixed sh a
mshapeTree :: a -> ShapeTree a
- mshapeTreeZero :: Proxy a -> ShapeTree a
-
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
@@ -257,20 +254,20 @@ class Elt a where
-- | Create uninitialised vectors for this array type, given the shape of
-- this vector and an example for the contents.
- mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a)
+ mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- Arrays of scalars are basically just arrays of scalars.
@@ -299,9 +296,8 @@ instance Storable a => Elt (Primitive a) where
, Refl <- X.lemAppNil @sh3
= M_Primitive (f Proxy a b)
- memptyArray sh = M_Primitive (X.generate sh (error $ "memptyArray Int: shape was not empty (" ++ show sh ++ ")"))
+ memptyArray sh = M_Primitive (X.empty sh)
mshapeTree _ = ()
- mshapeTreeZero _ = ()
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
@@ -312,7 +308,7 @@ instance Storable a => Elt (Primitive a) where
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
:: forall sh' sh s. KnownShapeX sh'
- => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))
VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
@@ -338,7 +334,6 @@ instance (Elt a, Elt b) => Elt (a, b) where
memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
- mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b))
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
@@ -357,10 +352,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
- mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh
+ mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest arr)
| Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
- = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))
+ = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))
mindex (M_Nest arr) i = mindexPartial arr i
@@ -410,12 +405,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
, Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))
= f (Proxy @(sh' ++ shT))
- memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh'))))
+ memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh'))))
mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh'))))
- mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a))
-
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
@@ -424,32 +417,26 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mvecsUnsafeNew sh example
| X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example))
(mindex example (X.zeroIxX (knownShapeX @sh')))
where
sh' = mshape example
- mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a)
+ mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a)
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
| Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
, Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
-
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
+ = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs
--- | Check whether a given shape corresponds on the statically-known components of the shape.
-checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool
-checkBounds ZIX ZSX = True
-checkBounds (n :.@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
-checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh'
-- | Create an array given a size and a function that computes the element at a
-- given index.
@@ -468,31 +455,25 @@ checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh'
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
-mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f
- -- TODO: Do we need this checkBounds check elsewhere as well?
- | not (checkBounds sh (knownShapeX @sh)) =
- error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
- -- If the shape is empty, there is no first element, so we should not try to
- -- generate it.
- | X.shapeSize sh == 0 = memptyArray sh
- | otherwise =
- let firstidx = X.zeroIxX' sh
- firstelem = f (X.zeroIxX' sh)
- shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
- then memptyArray sh
- else runST $ do
- vecs <- mvecsUnsafeNew sh firstelem
- mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
- forM_ (tail (X.enumShape sh)) $ \idx -> do
- let val = f idx
- when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
- error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
- mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
+mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgenerate sh f = case X.enumShape sh of
+ [] -> memptyArray sh
+ firstidx : restidxs ->
+ let firstelem = f (X.zeroIxX' sh)
+ shapetree = mshapeTree firstelem
+ in if mshapeTreeEmpty (Proxy @a) shapetree
+ then memptyArray sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+ -- scalar this array copying inefficient. Should improve this.
+ forM_ restidxs $ \idx -> do
+ let val = f idx
+ when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
+ error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
+ mvecsWrite sh idx val vecs
+ mvecsFreeze sh vecs
mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
mtranspose perm =
@@ -506,12 +487,8 @@ mappend = mlift2 go
=> Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
-mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVector sh v
- | not (checkBounds sh (knownShapeX @sh)) =
- error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
- | otherwise =
- M_Primitive (X.fromVector sh v)
+mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVector sh v = M_Primitive (X.fromVector sh v)
mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a
mfromList1 = mfromList . fmap mscalar
@@ -522,17 +499,13 @@ mtoList1 = map munScalar . mtoList
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr ZIX
-mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a)
-mconstantP sh x
- | not (checkBounds sh (knownShapeX @sh)) =
- error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
- | otherwise =
- M_Primitive (X.constant sh x)
+mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a)
+mconstantP sh x = M_Primitive (X.constant sh x)
-- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'
-- is defined.
mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
- => IIxX sh -> a -> Mixed sh a
+ => IShX sh -> a -> Mixed sh a
mconstant sh x = coerce (mconstantP sh x)
mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a
@@ -648,7 +621,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
= coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 f arr1 arr2
- memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a)
+ memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
memptyArray i
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
@@ -657,9 +630,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
mshapeTree (Ranked arr)
| Refl <- lemRankReplicate (Proxy @n)
, Dict <- lemKnownReplicate (Proxy @n)
- = first ixCvtXR (mshapeTree arr)
-
- mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a))
+ = first shCvtXR (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -675,7 +646,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
| Dict <- lemKnownReplicate (Proxy @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
- mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs
| Dict <- lemKnownReplicate (Proxy @n)
= mvecsWrite sh idx arr
@@ -683,7 +654,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
vecs)
mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
- => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
-> MixedVecs s (sh ++ sh') (Ranked n a)
-> ST s ()
mvecsWritePartial sh idx arr vecs
@@ -696,7 +667,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
@(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
vecs)
- mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
mvecsFreeze sh vecs
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
@@ -712,6 +683,8 @@ data ShS sh where
ZSS :: ShS '[]
(:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh)
deriving instance Show (ShS sh)
+deriving instance Eq (ShS sh)
+deriving instance Ord (ShS sh)
infixr 3 :$$
-- | A statically-known shape of a shape-typed array.
@@ -726,9 +699,9 @@ sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict
lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
where
- go :: ShS sh' -> StaticShapeX (MapJust sh')
- go ZSS = ZSX
- go (n :$$ sh) = n :$@ go sh
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKSX
+ go (n :$$ sh) = n :!$@ go sh
lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
-> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
@@ -777,7 +750,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
= coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 f arr1 arr2
- memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
memptyArray i
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
@@ -785,9 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshapeTree (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
- = first (ixCvtXS (knownShape @sh)) (mshapeTree arr)
-
- mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a))
+ = first (shCvtXS (knownShape @sh)) (mshapeTree arr)
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -803,7 +774,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
- mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs
| Dict <- lemKnownMapJust (Proxy @sh)
= mvecsWrite sh idx arr
@@ -811,7 +782,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
vecs)
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
-> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
-> ST s ()
mvecsWritePartial sh idx arr vecs
@@ -824,7 +795,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
@(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
vecs)
- mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
mvecsFreeze sh vecs
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a))
@@ -927,6 +898,8 @@ newtype ShR n i = ShR (ListR n i)
deriving (Show, Eq, Ord)
deriving newtype (Functor, Foldable)
+type IShR n = ShR n Int
+
pattern ZSR :: forall n i. () => n ~ Z => ShR n i
pattern ZSR = ShR ZR
@@ -957,20 +930,29 @@ ixCvtXR ZIX = ZIR
ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx
ixCvtXR (n :.? idx) = n :.: ixCvtXR idx
+shCvtXR :: IShX sh -> IShR (X.Rank sh)
+shCvtXR ZSX = ZSR
+shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx
+shCvtXR (n :$? idx) = n :$: shCvtXR idx
+
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (n :.: idx) = n :.? ixCvtRX idx
-shapeSizeR :: IIxR n -> Int
-shapeSizeR ZIR = 1
-shapeSizeR (n :.: sh) = n * shapeSizeR sh
+shCvtRX :: IShR n -> IShX (Replicate n Nothing)
+shCvtRX ZSR = ZSX
+shCvtRX (n :$: idx) = n :$? shCvtRX idx
+
+shapeSizeR :: IShR n -> Int
+shapeSizeR ZSR = 1
+shapeSizeR (n :$: sh) = n * shapeSizeR sh
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n
+rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
- = ixCvtXR (mshape arr)
+ = shCvtXR (mshape arr)
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
@@ -983,12 +965,12 @@ rindexPartial (Ranked arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a
+rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
- | Dict <- knownIxR sh
+ | Dict <- knownShR sh
, Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
- = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
+ = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-- | See the documentation of 'mlift'.
rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
@@ -1005,7 +987,7 @@ rsumOuter1 (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked
. coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
- . X.sumOuter (() :$? ZSX) (knownShapeX @(Replicate n Nothing))
+ . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
. coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
$ arr
@@ -1021,10 +1003,10 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
rscalar :: Elt a => a -> Ranked I0 a
rscalar x = Ranked (mscalar x)
-rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVector sh v
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mfromVector (ixCvtRX sh) v)
+ = Ranked (mfromVector (shCvtRX sh) v)
rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
rfromList l
@@ -1043,13 +1025,13 @@ rtoList1 = map runScalar . rtoList
runScalar :: Elt a => Ranked I0 a -> a
runScalar arr = rindex arr ZIR
-rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
rconstantP sh x
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mconstantP (ixCvtRX sh) x)
+ = Ranked (mconstantP (shCvtRX sh) x)
rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
- => IIxR n -> a -> Ranked n a
+ => IShR n -> a -> Ranked n a
rconstant sh x = coerce (rconstantP sh x)
rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
@@ -1141,27 +1123,30 @@ zeroIxS :: ShS sh -> IIxS sh
zeroIxS ZSS = ZIS
zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh
--- TODO: this function should not exist perhaps
-cvtShSIxS :: ShS sh -> IIxS sh
-cvtShSIxS ZSS = ZIS
-cvtShSIxS (n :$$ sh) = fromIntegral (fromSNat n) :.$ cvtShSIxS sh
-
ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ZSS ZIX = ZIS
ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx
+shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh
+shCvtXS ZSS ZSX = ZSS
+shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx
+
ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX ZIS = ZIX
ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh
-shapeSizeS :: IIxS sh -> Int
-shapeSizeS ZIS = 1
-shapeSizeS (n :.$ sh) = n * shapeSizeS sh
+shCvtSX :: ShS sh -> IShX (MapJust sh)
+shCvtSX ZSS = ZSX
+shCvtSX (n :$$ sh) = n :$@ shCvtSX sh
+
+shapeSizeS :: ShS sh -> Int
+shapeSizeS ZSS = 1
+shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh
-- | This does not touch the passed array, all information comes from 'KnownShape'.
-sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh
-sshape _ = cvtShSIxS (knownShape @sh)
+sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh
+sshape _ = knownShape @sh
sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
@@ -1177,7 +1162,7 @@ sindexPartial (Shaped arr) idx =
sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a
sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mgenerate (ixCvtSX (cvtShSIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))
+ = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh)))
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
@@ -1194,7 +1179,7 @@ ssumOuter1 (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped
. coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a))
- . X.sumOuter (natSing @n :$@ ZSX) (knownShapeX @(MapJust sh))
+ . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh))
. coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)
$ arr
@@ -1213,7 +1198,7 @@ sscalar x = Shaped (mscalar x)
sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
sfromVector v
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mfromVector (ixCvtSX (cvtShSIxS (knownShape @sh))) v)
+ = Shaped (mfromVector (shCvtSX (knownShape @sh)) v)
sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
=> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
@@ -1236,7 +1221,7 @@ sunScalar arr = sindex arr ZIS
sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)
sconstantP x
| Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mconstantP (ixCvtSX (cvtShSIxS (knownShape @sh))) x)
+ = Shaped (mconstantP (shCvtSX (knownShape @sh)) x)
sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))
=> a -> Shaped sh a
diff --git a/src/Data/INat.hs b/src/Data/INat.hs
index 2d65c53..af8f18b 100644
--- a/src/Data/INat.hs
+++ b/src/Data/INat.hs
@@ -4,6 +4,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
diff --git a/test/Main.hs b/test/Main.hs
index 0a07531..2363813 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -7,7 +7,7 @@ import Data.Array.Nested
arr :: Ranked I2 (Shaped [2, 3] (Double, Int))
-arr = rgenerate (3 :.: 4 :.: ZIR) $ \(i :.: j :.: ZIR) ->
+arr = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
sgenerate @[2, 3] $ \(k :.$ l :.$ ZIS) ->
let s = 24*i + 6*j + 3*k + l
in (fromIntegral s, s)
@@ -16,8 +16,8 @@ foo :: (Double, Int)
foo = arr `rindex` (2 :.: 1 :.: ZIR) `sindex` (1 :.$ 1 :.$ ZIS)
bad :: Ranked I2 (Ranked I1 Double)
-bad = rgenerate (3 :.: 4 :.: ZIR) $ \(i :.: j :.: ZIR) ->
- rgenerate (i :.: ZIR) $ \(k :.: ZIR) ->
+bad = rgenerate (3 :$: 4 :$: ZSR) $ \(i :.: j :.: ZIR) ->
+ rgenerate (i :$: ZSR) $ \(k :.: ZIR) ->
let s = 24*i + 6*j + 3*k
in fromIntegral s