diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-04-21 15:45:47 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-04-21 15:45:47 +0200 |
commit | 3a82a91be0f1b18f071cdb35526b2b2d0b8e093f (patch) | |
tree | 6a65376a17ef4051446952541bab57149c66fca7 /src | |
parent | 181c02b60445310105be126fd6cc2ee9f5dc2c8a (diff) |
Make index types useful for horde-ad by parameterising Int
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Mixed.hs | 68 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 6 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 144 |
3 files changed, 112 insertions, 106 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 164f832..8d20583 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -45,15 +45,17 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl -type IxX :: [Maybe Nat] -> Type -data IxX sh where - IZX :: IxX '[] - (::@) :: Int -> IxX sh -> IxX (Just n : sh) - (::?) :: Int -> IxX sh -> IxX (Nothing : sh) -deriving instance Show (IxX sh) -deriving instance Eq (IxX sh) -infixr 5 ::@ -infixr 5 ::? +type IxX :: Type -> [Maybe Nat] -> Type +data IxX i sh where + IZX :: IxX i '[] + (::@) :: forall n sh i. i -> IxX i sh -> IxX i (Just n : sh) + (::?) :: forall sh i. i -> IxX i sh -> IxX i (Nothing : sh) +deriving instance Show i => Show (IxX i sh) +deriving instance Eq i => Eq (IxX i sh) +infixr 3 ::@ +infixr 3 ::? + +type IIxX = IxX Int -- | The part of a shape that is statically known. type StaticShapeX :: [Maybe Nat] -> Type @@ -62,8 +64,8 @@ data StaticShapeX sh where (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) deriving instance Show (StaticShapeX sh) -infixr 5 :$@ -infixr 5 :$? +infixr 3 :$@ +infixr 3 :$? -- | Evidence for the static part of a shape. type KnownShapeX :: [Maybe Nat] -> Constraint @@ -84,22 +86,22 @@ type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) deriving (Show) -zeroIxX :: StaticShapeX sh -> IxX sh +zeroIxX :: StaticShapeX sh -> IIxX sh zeroIxX SZX = IZX zeroIxX (_ :$@ ssh) = 0 ::@ zeroIxX ssh zeroIxX (_ :$? ssh) = 0 ::? zeroIxX ssh -zeroIxX' :: IxX sh -> IxX sh +zeroIxX' :: IIxX sh -> IIxX sh zeroIxX' IZX = IZX zeroIxX' (_ ::@ sh) = 0 ::@ zeroIxX' sh zeroIxX' (_ ::? sh) = 0 ::? zeroIxX' sh -ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh') +ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend IZX idx' = idx' ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx' ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx' -ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh' +ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' ixDrop sh IZX = sh ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx @@ -109,25 +111,25 @@ ssxAppend SZX sh' = sh' ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh' ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh' -shapeSize :: IxX sh -> Int +shapeSize :: IIxX sh -> Int shapeSize IZX = 1 shapeSize (n ::@ sh) = n * shapeSize sh shapeSize (n ::? sh) = n * shapeSize sh -- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh) +ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh) ssxToShape' SZX = Just IZX ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh ssxToShape' (_ :$? _) = Nothing -fromLinearIdx :: IxX sh -> Int -> IxX sh +fromLinearIdx :: IIxX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")" where -- returns (index in subarray, remaining index in enclosing array) - go :: IxX sh -> Int -> (IxX sh, Int) + go :: IIxX sh -> Int -> (IIxX sh, Int) go IZX i = (IZX, i) go (n ::@ sh) i = let (idx, i') = go sh i @@ -138,11 +140,11 @@ fromLinearIdx = \sh i -> case go sh i of (upi, locali) = i' `quotRem` n in (locali ::? idx, upi) -toLinearIdx :: IxX sh -> IxX sh -> Int +toLinearIdx :: IIxX sh -> IIxX sh -> Int toLinearIdx = \sh i -> fst (go sh i) where -- returns (index in subarray, size of subarray) - go :: IxX sh -> IxX sh -> (Int, Int) + go :: IIxX sh -> IIxX sh -> (Int, Int) go IZX IZX = (0, 1) go (n ::@ sh) (i ::@ ix) = let (lidx, sz) = go sh ix @@ -151,15 +153,15 @@ toLinearIdx = \sh i -> fst (go sh i) let (lidx, sz) = go sh ix in (sz * i + lidx, n * sz) -enumShape :: IxX sh -> [IxX sh] +enumShape :: IIxX sh -> [IIxX sh] enumShape = \sh -> go sh id [] where - go :: IxX sh -> (IxX sh -> a) -> [a] -> [a] + go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a] go IZX f = (f IZX :) go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]] go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]] -shapeLshape :: IxX sh -> S.ShapeL +shapeLshape :: IIxX sh -> S.ShapeL shapeLshape IZX = [] shapeLshape (n ::@ sh) = n : shapeLshape sh shapeLshape (n ::? sh) = n : shapeLshape sh @@ -182,7 +184,7 @@ lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownINatRank :: IxX sh -> Dict KnownINat (Rank sh) +lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh) lemKnownINatRank IZX = Dict lemKnownINatRank (_ ::@ sh) | Dict <- lemKnownINatRank sh = Dict lemKnownINatRank (_ ::? sh) | Dict <- lemKnownINatRank sh = Dict @@ -206,16 +208,16 @@ lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) where - go :: StaticShapeX sh' -> [Int] -> IxX sh' + go :: StaticShapeX sh' -> [Int] -> IIxX sh' go SZX [] = IZX go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. Storable a => IxX sh -> VS.Vector a -> XArray sh a +fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) @@ -230,26 +232,26 @@ scalar = XArray . S.scalar unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a +constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a constant sh x | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.constant (shapeLshape sh) x) -generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, Storable a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownINatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) -indexPartial :: Storable a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) IZX = XArray arr indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx -index :: forall sh a. Storable a => XArray sh a -> IxX sh -> a +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a index xarr i | Refl <- lemAppNil @sh = let XArray arr' = indexPartial xarr i :: XArray '[] a diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index ada8751..148acf5 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -2,7 +2,7 @@ module Data.Array.Nested ( -- * Ranked arrays Ranked, - IxR(..), + IxR(..), IIxR, rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, runScalar, rconstant, rfromList, rfromList1, rtoList, rtoList1, @@ -12,7 +12,7 @@ module Data.Array.Nested ( -- * Shaped arrays Shaped, - IxS(..), + IxS(..), IIxS, KnownShape(..), SShape(..), sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, sunScalar, @@ -23,7 +23,7 @@ module Data.Array.Nested ( -- * Mixed arrays Mixed, - IxX(..), + IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), mgenerate, mtranspose, mappend, mfromVector, munScalar, mconstant, mfromList1, mtoList1, mslice, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index fadf1a7..9a87389 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -47,7 +47,7 @@ import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) +import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) import qualified Data.Array.Mixed as X import Data.INat @@ -105,7 +105,7 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n) go SZ = Refl go (SS n) | Refl <- go n = Refl -ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IxX (sh ++ sh') -> (IxX sh, IxX sh') +ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh') ixAppSplit _ SZX idx = (IZX, idx) ixAppSplit p (_ :$@ ssh) (i ::@ idx) = first (i ::@) (ixAppSplit p ssh idx) ixAppSplit p (_ :$? ssh) (i ::? idx) = first (i ::?) (ixAppSplit p ssh idx) @@ -168,7 +168,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MV data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) -- etc. -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IxX sh2) !(MixedVecs s (sh1 ++ sh2) a) +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a) -- | Tree giving the shape of every array component. @@ -179,9 +179,9 @@ type family ShapeTree a where ShapeTree () = () ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed sh a) = (IxX sh, ShapeTree a) - ShapeTree (Ranked n a) = (IxR n, ShapeTree a) - ShapeTree (Shaped sh a) = (IxS sh, ShapeTree a) + ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a) + ShapeTree (Ranked n a) = (IIxR n, ShapeTree a) + ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a) -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or @@ -190,9 +190,9 @@ type family ShapeTree a where class Elt a where -- ====== PUBLIC METHODS ====== -- - mshape :: KnownShapeX sh => Mixed sh a -> IxX sh - mindex :: Mixed sh a -> IxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a + mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh + mindex :: Mixed sh a -> IIxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a -- | All arrays in the list, even subarrays inside @a@, must have the same @@ -226,7 +226,7 @@ class Elt a where -- Remember I said that this module needed better management of exports? -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IxX sh -> Mixed sh a + memptyArray :: IIxX sh -> Mixed sh a mshapeTree :: a -> ShapeTree a @@ -240,20 +240,20 @@ class Elt a where -- | Create uninitialised vectors for this array type, given the shape of -- this vector and an example for the contents. - mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) + mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a) mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- Arrays of scalars are basically just arrays of scalars. @@ -295,7 +295,7 @@ instance Storable a => Elt (Primitive a) where -- TODO: this use of toVector is suboptimal mvecsWritePartial :: forall sh' sh s. KnownShapeX sh' - => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr))) VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) @@ -339,7 +339,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh + mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh mshape (M_Nest arr) | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) @@ -347,7 +347,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mindex (M_Nest arr) i = mindexPartial arr i mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest arr) i | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) @@ -416,7 +416,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) + => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) @@ -428,7 +428,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where -- | Check whether a given shape corresponds on the statically-known components of the shape. -checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool +checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool checkBounds IZX SZX = True checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' @@ -450,7 +450,7 @@ checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a +mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a mgenerate sh f -- TODO: Do we need this checkBounds check elsewhere as well? | not (checkBounds sh (knownShapeX @sh)) = @@ -488,7 +488,7 @@ mappend = mlift2 go => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append -mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVector sh v | not (checkBounds sh (knownShapeX @sh)) = error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) @@ -504,7 +504,7 @@ mtoList1 = map munScalar . mtoList munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr IZX -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IxX sh -> a -> Mixed sh (Primitive a) +mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a) mconstantP sh x | not (checkBounds sh (knownShapeX @sh)) = error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) @@ -514,7 +514,7 @@ mconstantP sh x -- | This 'Coercible' constraint holds for the scalar types for which 'Mixed' -- is defined. mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) - => IxX sh -> a -> Mixed sh a + => IIxX sh -> a -> Mixed sh a mconstant sh x = coerce (mconstantP sh x) mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a @@ -594,7 +594,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a) + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) mindexPartial (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ @@ -629,7 +629,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 f arr1 arr2 - memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) + memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a) memptyArray i | Dict <- lemKnownReplicate (Proxy @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ @@ -656,7 +656,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where | Dict <- lemKnownReplicate (Proxy @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs | Dict <- lemKnownReplicate (Proxy @n) = mvecsWrite sh idx arr @@ -664,7 +664,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where vecs) mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a) + => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) -> MixedVecs s (sh ++ sh') (Ranked n a) -> ST s () mvecsWritePartial sh idx arr vecs @@ -677,7 +677,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) vecs) - mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) mvecsFreeze sh vecs | Dict <- lemKnownReplicate (Proxy @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @@ -723,7 +723,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) mindexPartial (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ @@ -758,7 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 f arr1 arr2 - memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a) + memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a) memptyArray i | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ @@ -784,7 +784,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs | Dict <- lemKnownMapJust (Proxy @sh) = mvecsWrite sh idx arr @@ -792,7 +792,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where vecs) mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a) + => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) -> ST s () mvecsWritePartial sh idx arr vecs @@ -805,7 +805,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) vecs) - mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) mvecsFreeze sh vecs | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @@ -850,46 +850,48 @@ deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) -- | An index into a rank-typed array. -type IxR :: INat -> Type -data IxR n where - IZR :: IxR Z - (:::) :: Int -> IxR n -> IxR (S n) -deriving instance Show (IxR n) -deriving instance Eq (IxR n) -infixr 5 ::: - -zeroIxR :: SINat n -> IxR n +type IxR :: Type -> INat -> Type +data IxR i n where + IZR :: IxR i Z + (:::) :: forall n i. i -> IxR i n -> IxR i (S n) +deriving instance Show i => Show (IxR i n) +deriving instance Eq i => Eq (IxR i n) +infixr 3 ::: + +type IIxR = IxR Int + +zeroIxR :: SINat n -> IIxR n zeroIxR SZ = IZR zeroIxR (SS n) = 0 ::: zeroIxR n -ixCvtXR :: IxX sh -> IxR (X.Rank sh) +ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) ixCvtXR IZX = IZR ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx ixCvtXR (n ::? idx) = n ::: ixCvtXR idx -ixCvtRX :: IxR n -> IxX (Replicate n Nothing) +ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX IZR = IZX ixCvtRX (n ::: idx) = n ::? ixCvtRX idx -knownIxR :: IxR n -> Dict KnownINat n +knownIxR :: IIxR n -> Dict KnownINat n knownIxR IZR = Dict knownIxR (_ ::: idx) | Dict <- knownIxR idx = Dict -shapeSizeR :: IxR n -> Int +shapeSizeR :: IIxR n -> Int shapeSizeR IZR = 1 shapeSizeR (n ::: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IxR n +rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n rshape (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) = ixCvtXR (mshape arr) -rindex :: Elt a => Ranked n a -> IxR n -> a +rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) -rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IxR n -> Ranked m a +rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) @@ -897,7 +899,7 @@ rindexPartial (Ranked arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IxR n -> (IxR n -> a) -> Ranked n a +rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f | Dict <- knownIxR sh , Dict <- lemKnownReplicate (Proxy @n) @@ -935,7 +937,7 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend rscalar :: Elt a => a -> Ranked I0 a rscalar x = Ranked (mscalar x) -rfromVector :: forall n a. (KnownINat n, Storable a) => IxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a) rfromVector sh v | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mfromVector (ixCvtRX sh) v) @@ -957,13 +959,13 @@ rtoList1 = map runScalar . rtoList runScalar :: Elt a => Ranked I0 a -> a runScalar arr = rindex arr IZR -rconstantP :: forall n a. (KnownINat n, Storable a) => IxR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a) rconstantP sh x | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mconstantP (ixCvtRX sh) x) rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) - => IxR n -> a -> Ranked n a + => IIxR n -> a -> Ranked n a rconstant sh x = coerce (rconstantP sh x) rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a @@ -1000,43 +1002,45 @@ deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped -- (traditionally called \"@Fin@\"). Note that because the shape of a -- shape-typed array is known statically, you can also retrieve the array shape -- from a 'KnownShape' dictionary. -type IxS :: [Nat] -> Type -data IxS sh where - IZS :: IxS '[] - (::$) :: Int -> IxS sh -> IxS (n : sh) -deriving instance Show (IxS n) -deriving instance Eq (IxS n) -infixr 5 ::$ - -zeroIxS :: SShape sh -> IxS sh +type IxS :: Type -> [Nat] -> Type +data IxS i sh where + IZS :: IxS i '[] + (::$) :: forall n sh i. i -> IxS i sh -> IxS i (n : sh) +deriving instance Show i => Show (IxS i n) +deriving instance Eq i => Eq (IxS i n) +infixr 3 ::$ + +type IIxS = IxS Int + +zeroIxS :: SShape sh -> IIxS sh zeroIxS ShNil = IZS zeroIxS (ShCons _ sh) = 0 ::$ zeroIxS sh -cvtSShapeIxS :: SShape sh -> IxS sh +cvtSShapeIxS :: SShape sh -> IIxS sh cvtSShapeIxS ShNil = IZS cvtSShapeIxS (ShCons n sh) = fromIntegral (fromSNat n) ::$ cvtSShapeIxS sh -ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh +ixCvtXS :: SShape sh -> IIxX (MapJust sh) -> IIxS sh ixCvtXS ShNil IZX = IZS ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx -ixCvtSX :: IxS sh -> IxX (MapJust sh) +ixCvtSX :: IIxS sh -> IIxX (MapJust sh) ixCvtSX IZS = IZX ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh -shapeSizeS :: IxS sh -> Int +shapeSizeS :: IIxS sh -> Int shapeSizeS IZS = 1 shapeSizeS (n ::$ sh) = n * shapeSizeS sh -- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh +sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh sshape _ = cvtSShapeIxS (knownShape @sh) -sindex :: Elt a => Shaped sh a -> IxS sh -> a +sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) -sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a +sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial (Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) @@ -1044,7 +1048,7 @@ sindexPartial (Shaped arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IxS sh -> a) -> Shaped sh a +sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mgenerate (ixCvtSX (cvtSShapeIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) |