summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-21 15:45:47 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-21 15:45:47 +0200
commit3a82a91be0f1b18f071cdb35526b2b2d0b8e093f (patch)
tree6a65376a17ef4051446952541bab57149c66fca7
parent181c02b60445310105be126fd6cc2ee9f5dc2c8a (diff)
Make index types useful for horde-ad by parameterising Int
-rw-r--r--src/Data/Array/Mixed.hs68
-rw-r--r--src/Data/Array/Nested.hs6
-rw-r--r--src/Data/Array/Nested/Internal.hs144
3 files changed, 112 insertions, 106 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 164f832..8d20583 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -45,15 +45,17 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
-type IxX :: [Maybe Nat] -> Type
-data IxX sh where
- IZX :: IxX '[]
- (::@) :: Int -> IxX sh -> IxX (Just n : sh)
- (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
-deriving instance Show (IxX sh)
-deriving instance Eq (IxX sh)
-infixr 5 ::@
-infixr 5 ::?
+type IxX :: Type -> [Maybe Nat] -> Type
+data IxX i sh where
+ IZX :: IxX i '[]
+ (::@) :: forall n sh i. i -> IxX i sh -> IxX i (Just n : sh)
+ (::?) :: forall sh i. i -> IxX i sh -> IxX i (Nothing : sh)
+deriving instance Show i => Show (IxX i sh)
+deriving instance Eq i => Eq (IxX i sh)
+infixr 3 ::@
+infixr 3 ::?
+
+type IIxX = IxX Int
-- | The part of a shape that is statically known.
type StaticShapeX :: [Maybe Nat] -> Type
@@ -62,8 +64,8 @@ data StaticShapeX sh where
(:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
(:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)
-infixr 5 :$@
-infixr 5 :$?
+infixr 3 :$@
+infixr 3 :$?
-- | Evidence for the static part of a shape.
type KnownShapeX :: [Maybe Nat] -> Constraint
@@ -84,22 +86,22 @@ type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
deriving (Show)
-zeroIxX :: StaticShapeX sh -> IxX sh
+zeroIxX :: StaticShapeX sh -> IIxX sh
zeroIxX SZX = IZX
zeroIxX (_ :$@ ssh) = 0 ::@ zeroIxX ssh
zeroIxX (_ :$? ssh) = 0 ::? zeroIxX ssh
-zeroIxX' :: IxX sh -> IxX sh
+zeroIxX' :: IIxX sh -> IIxX sh
zeroIxX' IZX = IZX
zeroIxX' (_ ::@ sh) = 0 ::@ zeroIxX' sh
zeroIxX' (_ ::? sh) = 0 ::? zeroIxX' sh
-ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh')
+ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend IZX idx' = idx'
ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx'
ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx'
-ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh'
+ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
ixDrop sh IZX = sh
ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx
ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx
@@ -109,25 +111,25 @@ ssxAppend SZX sh' = sh'
ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh'
ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh'
-shapeSize :: IxX sh -> Int
+shapeSize :: IIxX sh -> Int
shapeSize IZX = 1
shapeSize (n ::@ sh) = n * shapeSize sh
shapeSize (n ::? sh) = n * shapeSize sh
-- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh)
+ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh)
ssxToShape' SZX = Just IZX
ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh
ssxToShape' (_ :$? _) = Nothing
-fromLinearIdx :: IxX sh -> Int -> IxX sh
+fromLinearIdx :: IIxX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
_ -> error $ "fromLinearIdx: out of range (" ++ show i ++
" in array of shape " ++ show sh ++ ")"
where
-- returns (index in subarray, remaining index in enclosing array)
- go :: IxX sh -> Int -> (IxX sh, Int)
+ go :: IIxX sh -> Int -> (IIxX sh, Int)
go IZX i = (IZX, i)
go (n ::@ sh) i =
let (idx, i') = go sh i
@@ -138,11 +140,11 @@ fromLinearIdx = \sh i -> case go sh i of
(upi, locali) = i' `quotRem` n
in (locali ::? idx, upi)
-toLinearIdx :: IxX sh -> IxX sh -> Int
+toLinearIdx :: IIxX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
where
-- returns (index in subarray, size of subarray)
- go :: IxX sh -> IxX sh -> (Int, Int)
+ go :: IIxX sh -> IIxX sh -> (Int, Int)
go IZX IZX = (0, 1)
go (n ::@ sh) (i ::@ ix) =
let (lidx, sz) = go sh ix
@@ -151,15 +153,15 @@ toLinearIdx = \sh i -> fst (go sh i)
let (lidx, sz) = go sh ix
in (sz * i + lidx, n * sz)
-enumShape :: IxX sh -> [IxX sh]
+enumShape :: IIxX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
where
- go :: IxX sh -> (IxX sh -> a) -> [a] -> [a]
+ go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a]
go IZX f = (f IZX :)
go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]]
go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]]
-shapeLshape :: IxX sh -> S.ShapeL
+shapeLshape :: IIxX sh -> S.ShapeL
shapeLshape IZX = []
shapeLshape (n ::@ sh) = n : shapeLshape sh
shapeLshape (n ::? sh) = n : shapeLshape sh
@@ -182,7 +184,7 @@ lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IxX sh -> Dict KnownINat (Rank sh)
+lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh)
lemKnownINatRank IZX = Dict
lemKnownINatRank (_ ::@ sh) | Dict <- lemKnownINatRank sh = Dict
lemKnownINatRank (_ ::? sh) | Dict <- lemKnownINatRank sh = Dict
@@ -206,16 +208,16 @@ lemAppKnownShapeX (() :$? ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
= Dict
-shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
+shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh
shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
- go :: StaticShapeX sh' -> [Int] -> IxX sh'
+ go :: StaticShapeX sh' -> [Int] -> IIxX sh'
go SZX [] = IZX
go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: forall sh a. Storable a => IxX sh -> VS.Vector a -> XArray sh a
+fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
@@ -230,26 +232,26 @@ scalar = XArray . S.scalar
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
-constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a
+constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a
constant sh x
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.constant (shapeLshape sh) x)
-generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
--- generateM :: (Monad m, Storable a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownINatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
-indexPartial :: Storable a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) IZX = XArray arr
indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx
indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx
-index :: forall sh a. Storable a => XArray sh a -> IxX sh -> a
+index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
| Refl <- lemAppNil @sh
= let XArray arr' = indexPartial xarr i :: XArray '[] a
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index ada8751..148acf5 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -2,7 +2,7 @@
module Data.Array.Nested (
-- * Ranked arrays
Ranked,
- IxR(..),
+ IxR(..), IIxR,
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, runScalar,
rconstant, rfromList, rfromList1, rtoList, rtoList1,
@@ -12,7 +12,7 @@ module Data.Array.Nested (
-- * Shaped arrays
Shaped,
- IxS(..),
+ IxS(..), IIxS,
KnownShape(..), SShape(..),
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, sunScalar,
@@ -23,7 +23,7 @@ module Data.Array.Nested (
-- * Mixed arrays
Mixed,
- IxX(..),
+ IxX(..), IIxX,
KnownShapeX(..), StaticShapeX(..),
mgenerate, mtranspose, mappend, mfromVector, munScalar,
mconstant, mfromList1, mtoList1, mslice,
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index fadf1a7..9a87389 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -47,7 +47,7 @@ import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
+import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat)
import qualified Data.Array.Mixed as X
import Data.INat
@@ -105,7 +105,7 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n)
go SZ = Refl
go (SS n) | Refl <- go n = Refl
-ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IxX (sh ++ sh') -> (IxX sh, IxX sh')
+ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh')
ixAppSplit _ SZX idx = (IZX, idx)
ixAppSplit p (_ :$@ ssh) (i ::@ idx) = first (i ::@) (ixAppSplit p ssh idx)
ixAppSplit p (_ :$? ssh) (i ::? idx) = first (i ::?) (ixAppSplit p ssh idx)
@@ -168,7 +168,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MV
data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
-- etc.
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IxX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a)
-- | Tree giving the shape of every array component.
@@ -179,9 +179,9 @@ type family ShapeTree a where
ShapeTree () = ()
ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- ShapeTree (Mixed sh a) = (IxX sh, ShapeTree a)
- ShapeTree (Ranked n a) = (IxR n, ShapeTree a)
- ShapeTree (Shaped sh a) = (IxS sh, ShapeTree a)
+ ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a)
+ ShapeTree (Ranked n a) = (IIxR n, ShapeTree a)
+ ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a)
-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
@@ -190,9 +190,9 @@ type family ShapeTree a where
class Elt a where
-- ====== PUBLIC METHODS ====== --
- mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
- mindex :: Mixed sh a -> IxX sh -> a
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
+ mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh
+ mindex :: Mixed sh a -> IIxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
-- | All arrays in the list, even subarrays inside @a@, must have the same
@@ -226,7 +226,7 @@ class Elt a where
-- Remember I said that this module needed better management of exports?
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IxX sh -> Mixed sh a
+ memptyArray :: IIxX sh -> Mixed sh a
mshapeTree :: a -> ShapeTree a
@@ -240,20 +240,20 @@ class Elt a where
-- | Create uninitialised vectors for this array type, given the shape of
-- this vector and an example for the contents.
- mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
+ mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a)
mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
+ mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-- | Given the shape of this array, an index and a value, write the value at
-- that index in the vectors.
- mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+ mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+ mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-- Arrays of scalars are basically just arrays of scalars.
@@ -295,7 +295,7 @@ instance Storable a => Elt (Primitive a) where
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
:: forall sh' sh s. KnownShapeX sh'
- => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))
VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
@@ -339,7 +339,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
-- TODO: this is quadratic in the nesting depth because it repeatedly
-- truncates the shape vector to one a little shorter. Fix with a
-- moverlongShape method, a prefix of which is mshape.
- mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
+ mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh
mshape (M_Nest arr)
| Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
= fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))
@@ -347,7 +347,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mindex (M_Nest arr) i = mindexPartial arr i
mindexPartial :: forall sh1 sh2.
- Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest arr) i
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
@@ -416,7 +416,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
@@ -428,7 +428,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where
-- | Check whether a given shape corresponds on the statically-known components of the shape.
-checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
+checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool
checkBounds IZX SZX = True
checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh'
checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
@@ -450,7 +450,7 @@ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
-mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
+mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a
mgenerate sh f
-- TODO: Do we need this checkBounds check elsewhere as well?
| not (checkBounds sh (knownShapeX @sh)) =
@@ -488,7 +488,7 @@ mappend = mlift2 go
=> Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append
-mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVector sh v
| not (checkBounds sh (knownShapeX @sh)) =
error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
@@ -504,7 +504,7 @@ mtoList1 = map munScalar . mtoList
munScalar :: Elt a => Mixed '[] a -> a
munScalar arr = mindex arr IZX
-mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> a -> Mixed sh (Primitive a)
+mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a)
mconstantP sh x
| not (checkBounds sh (knownShapeX @sh)) =
error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
@@ -514,7 +514,7 @@ mconstantP sh x
-- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'
-- is defined.
mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a))
- => IxX sh -> a -> Mixed sh a
+ => IIxX sh -> a -> Mixed sh a
mconstant sh x = coerce (mconstantP sh x)
mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a
@@ -594,7 +594,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
mindexPartial (M_Ranked arr) i
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
@@ -629,7 +629,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
= coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
mlift2 f arr1 arr2
- memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
+ memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a)
memptyArray i
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
@@ -656,7 +656,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
| Dict <- lemKnownReplicate (Proxy @n)
= MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
- mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
mvecsWrite sh idx (Ranked arr) vecs
| Dict <- lemKnownReplicate (Proxy @n)
= mvecsWrite sh idx arr
@@ -664,7 +664,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
vecs)
mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
- => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
+ => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
-> MixedVecs s (sh ++ sh') (Ranked n a)
-> ST s ()
mvecsWritePartial sh idx arr vecs
@@ -677,7 +677,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where
@(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
vecs)
- mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
mvecsFreeze sh vecs
| Dict <- lemKnownReplicate (Proxy @n)
= coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
@@ -723,7 +723,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
mindexPartial (M_Shaped arr) i
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
@@ -758,7 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
= coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
mlift2 f arr1 arr2
- memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a)
memptyArray i
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
@@ -784,7 +784,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
| Dict <- lemKnownMapJust (Proxy @sh)
= MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
- mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
mvecsWrite sh idx (Shaped arr) vecs
| Dict <- lemKnownMapJust (Proxy @sh)
= mvecsWrite sh idx arr
@@ -792,7 +792,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
vecs)
mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
+ => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
-> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
-> ST s ()
mvecsWritePartial sh idx arr vecs
@@ -805,7 +805,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
@(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
vecs)
- mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
mvecsFreeze sh vecs
| Dict <- lemKnownMapJust (Proxy @sh)
= coerce @(Mixed sh' (Mixed (MapJust sh) a))
@@ -850,46 +850,48 @@ deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
-- | An index into a rank-typed array.
-type IxR :: INat -> Type
-data IxR n where
- IZR :: IxR Z
- (:::) :: Int -> IxR n -> IxR (S n)
-deriving instance Show (IxR n)
-deriving instance Eq (IxR n)
-infixr 5 :::
-
-zeroIxR :: SINat n -> IxR n
+type IxR :: Type -> INat -> Type
+data IxR i n where
+ IZR :: IxR i Z
+ (:::) :: forall n i. i -> IxR i n -> IxR i (S n)
+deriving instance Show i => Show (IxR i n)
+deriving instance Eq i => Eq (IxR i n)
+infixr 3 :::
+
+type IIxR = IxR Int
+
+zeroIxR :: SINat n -> IIxR n
zeroIxR SZ = IZR
zeroIxR (SS n) = 0 ::: zeroIxR n
-ixCvtXR :: IxX sh -> IxR (X.Rank sh)
+ixCvtXR :: IIxX sh -> IIxR (X.Rank sh)
ixCvtXR IZX = IZR
ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx
ixCvtXR (n ::? idx) = n ::: ixCvtXR idx
-ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
+ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX IZR = IZX
ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
-knownIxR :: IxR n -> Dict KnownINat n
+knownIxR :: IIxR n -> Dict KnownINat n
knownIxR IZR = Dict
knownIxR (_ ::: idx) | Dict <- knownIxR idx = Dict
-shapeSizeR :: IxR n -> Int
+shapeSizeR :: IIxR n -> Int
shapeSizeR IZR = 1
shapeSizeR (n ::: sh) = n * shapeSizeR sh
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n
+rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
= ixCvtXR (mshape arr)
-rindex :: Elt a => Ranked n a -> IxR n -> a
+rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
@@ -897,7 +899,7 @@ rindexPartial (Ranked arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. Elt a => IxR n -> (IxR n -> a) -> Ranked n a
+rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a
rgenerate sh f
| Dict <- knownIxR sh
, Dict <- lemKnownReplicate (Proxy @n)
@@ -935,7 +937,7 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
rscalar :: Elt a => a -> Ranked I0 a
rscalar x = Ranked (mscalar x)
-rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVector sh v
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromVector (ixCvtRX sh) v)
@@ -957,13 +959,13 @@ rtoList1 = map runScalar . rtoList
runScalar :: Elt a => Ranked I0 a -> a
runScalar arr = rindex arr IZR
-rconstantP :: forall n a. (KnownINat n, Storable a) => IxR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a)
rconstantP sh x
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mconstantP (ixCvtRX sh) x)
rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a))
- => IxR n -> a -> Ranked n a
+ => IIxR n -> a -> Ranked n a
rconstant sh x = coerce (rconstantP sh x)
rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
@@ -1000,43 +1002,45 @@ deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped
-- (traditionally called \"@Fin@\"). Note that because the shape of a
-- shape-typed array is known statically, you can also retrieve the array shape
-- from a 'KnownShape' dictionary.
-type IxS :: [Nat] -> Type
-data IxS sh where
- IZS :: IxS '[]
- (::$) :: Int -> IxS sh -> IxS (n : sh)
-deriving instance Show (IxS n)
-deriving instance Eq (IxS n)
-infixr 5 ::$
-
-zeroIxS :: SShape sh -> IxS sh
+type IxS :: Type -> [Nat] -> Type
+data IxS i sh where
+ IZS :: IxS i '[]
+ (::$) :: forall n sh i. i -> IxS i sh -> IxS i (n : sh)
+deriving instance Show i => Show (IxS i n)
+deriving instance Eq i => Eq (IxS i n)
+infixr 3 ::$
+
+type IIxS = IxS Int
+
+zeroIxS :: SShape sh -> IIxS sh
zeroIxS ShNil = IZS
zeroIxS (ShCons _ sh) = 0 ::$ zeroIxS sh
-cvtSShapeIxS :: SShape sh -> IxS sh
+cvtSShapeIxS :: SShape sh -> IIxS sh
cvtSShapeIxS ShNil = IZS
cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh
-ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
+ixCvtXS :: SShape sh -> IIxX (MapJust sh) -> IIxS sh
ixCvtXS ShNil IZX = IZS
ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx
-ixCvtSX :: IxS sh -> IxX (MapJust sh)
+ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
ixCvtSX IZS = IZX
ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
-shapeSizeS :: IxS sh -> Int
+shapeSizeS :: IIxS sh -> Int
shapeSizeS IZS = 1
shapeSizeS (n ::$ sh) = n * shapeSizeS sh
-- | This does not touch the passed array, all information comes from 'KnownShape'.
-sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh
+sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh
sshape _ = cvtSShapeIxS (knownShape @sh)
-sindex :: Elt a => Shaped sh a -> IxS sh -> a
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
+sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
(rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
@@ -1044,7 +1048,7 @@ sindexPartial (Shaped arr) idx =
-- | __WARNING__: All values returned from the function must have equal shape.
-- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a
+sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a
sgenerate f
| Dict <- lemKnownMapJust (Proxy @sh)
= Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh)))