diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 | 
| commit | e4e23a33f77d250af1e9b6614cf249128ba1510a (patch) | |
| tree | 34bb40910003749becbaf8005a7b7ca62024fff2 /src/Data/Array | |
| parent | 7c9865354442326d55094087ad6a74b6e96341fb (diff) | |
Shape/index hygiene
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 238 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 239 | 
3 files changed, 252 insertions, 227 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 47027cb..94b7cdf 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -1,6 +1,10 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-}  {-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE ScopedTypeVariables #-} @@ -32,6 +36,9 @@ pattern GHC_SNat :: () => KnownNat n => SNat n  pattern GHC_SNat = SNat  {-# COMPLETE GHC_SNat #-} +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat +  -- | Type-level list append.  type family l1 ++ l2 where @@ -44,9 +51,6 @@ lemAppNil = unsafeCoerce Refl  lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)  lemAppAssoc _ _ _ = unsafeCoerce Refl --- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype --- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)? -  type IxX :: [Maybe Nat] -> Type -> Type  data IxX sh i where    ZIX :: IxX '[] i @@ -55,31 +59,48 @@ data IxX sh i where  deriving instance Show i => Show (IxX sh i)  deriving instance Eq i => Eq (IxX sh i)  deriving instance Ord i => Ord (IxX sh i) +deriving instance Functor (IxX sh) +deriving instance Foldable (IxX sh)  infixr 3 :.@  infixr 3 :.?  type IIxX sh = IxX sh Int --- | The part of a shape that is statically known. -type StaticShapeX :: [Maybe Nat] -> Type -data StaticShapeX sh where -  ZSX :: StaticShapeX '[] -  (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) -  (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) -deriving instance Show (StaticShapeX sh) +type ShX :: [Maybe Nat] -> Type -> Type +data ShX sh i where +  ZSX :: ShX '[] i +  (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i +  (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i +deriving instance Show i => Show (ShX sh i) +deriving instance Eq i => Eq (ShX sh i) +deriving instance Ord i => Ord (ShX sh i) +deriving instance Functor (ShX sh) +deriving instance Foldable (ShX sh)  infixr 3 :$@  infixr 3 :$? +type IShX sh = ShX sh Int + +-- | The part of a shape that is statically known. +type StaticShX :: [Maybe Nat] -> Type +data StaticShX sh where +  ZKSX :: StaticShX '[] +  (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh) +  (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh) +deriving instance Show (StaticShX sh) +infixr 3 :!$@ +infixr 3 :!$? +  -- | Evidence for the static part of a shape.  type KnownShapeX :: [Maybe Nat] -> Constraint  class KnownShapeX sh where -  knownShapeX :: StaticShapeX sh +  knownShapeX :: StaticShX sh  instance KnownShapeX '[] where -  knownShapeX = ZSX +  knownShapeX = ZKSX  instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where -  knownShapeX = natSing :$@ knownShapeX +  knownShapeX = natSing :!$@ knownShapeX  instance KnownShapeX sh => KnownShapeX (Nothing : sh) where -  knownShapeX = () :$? knownShapeX +  knownShapeX = () :!$? knownShapeX  type family Rank sh where    Rank '[] = Z @@ -89,138 +110,149 @@ type XArray :: [Maybe Nat] -> Type -> Type  newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)    deriving (Show) -zeroIxX :: StaticShapeX sh -> IIxX sh -zeroIxX ZSX = ZIX -zeroIxX (_ :$@ ssh) = 0 :.@ zeroIxX ssh -zeroIxX (_ :$? ssh) = 0 :.? zeroIxX ssh +zeroIxX :: StaticShX sh -> IIxX sh +zeroIxX ZKSX = ZIX +zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh +zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh + +zeroIxX' :: IShX sh -> IIxX sh +zeroIxX' ZSX = ZIX +zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh +zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh -zeroIxX' :: IIxX sh -> IIxX sh -zeroIxX' ZIX = ZIX -zeroIxX' (_ :.@ sh) = 0 :.@ zeroIxX' sh -zeroIxX' (_ :.? sh) = 0 :.? zeroIxX' sh +-- This is a weird operation, so it has a long name +completeShXzeros :: StaticShX sh -> IShX sh +completeShXzeros ZKSX = ZSX +completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh +completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh  ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')  ixAppend ZIX idx' = idx'  ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'  ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx' +shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') +shAppend ZSX sh' = sh' +shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh' +shAppend (n :$? sh) sh' = n :$? shAppend sh sh' +  ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'  ixDrop sh ZIX = sh  ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx  ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx -ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') -ssxAppend ZSX sh' = sh' -ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh' -ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh' +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKSX sh' = sh' +ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' +ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh' -shapeSize :: IIxX sh -> Int -shapeSize ZIX = 1 -shapeSize (n :.@ sh) = n * shapeSize sh -shapeSize (n :.? sh) = n * shapeSize sh +shapeSize :: IShX sh -> Int +shapeSize ZSX = 1 +shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh +shapeSize (n :$? sh) = n * shapeSize sh  -- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh) -ssxToShape' ZSX = Just ZIX -ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) :.@) <$> ssxToShape' sh -ssxToShape' (_ :$? _) = Nothing +ssxToShape' :: StaticShX sh -> Maybe (IShX sh) +ssxToShape' ZKSX = Just ZSX +ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh +ssxToShape' (_ :!$? _) = Nothing -fromLinearIdx :: IIxX sh -> Int -> IIxX sh +fromLinearIdx :: IShX sh -> Int -> IIxX sh  fromLinearIdx = \sh i -> case go sh i of    (idx, 0) -> idx    _ -> error $ "fromLinearIdx: out of range (" ++ show i ++                 " in array of shape " ++ show sh ++ ")"    where      -- returns (index in subarray, remaining index in enclosing array) -    go :: IIxX sh -> Int -> (IIxX sh, Int) -    go ZIX i = (ZIX, i) -    go (n :.@ sh) i = +    go :: IShX sh -> Int -> (IIxX sh, Int) +    go ZSX i = (ZIX, i) +    go (n :$@ sh) i =        let (idx, i') = go sh i -          (upi, locali) = i' `quotRem` n +          (upi, locali) = i' `quotRem` fromSNat' n        in (locali :.@ idx, upi) -    go (n :.? sh) i = +    go (n :$? sh) i =        let (idx, i') = go sh i            (upi, locali) = i' `quotRem` n        in (locali :.? idx, upi) -toLinearIdx :: IIxX sh -> IIxX sh -> Int +toLinearIdx :: IShX sh -> IIxX sh -> Int  toLinearIdx = \sh i -> fst (go sh i)    where      -- returns (index in subarray, size of subarray) -    go :: IIxX sh -> IIxX sh -> (Int, Int) -    go ZIX ZIX = (0, 1) -    go (n :.@ sh) (i :.@ ix) = +    go :: IShX sh -> IIxX sh -> (Int, Int) +    go ZSX ZIX = (0, 1) +    go (n :$@ sh) (i :.@ ix) =        let (lidx, sz) = go sh ix -      in (sz * i + lidx, n * sz) -    go (n :.? sh) (i :.? ix) = +      in (sz * i + lidx, fromSNat' n * sz) +    go (n :$? sh) (i :.? ix) =        let (lidx, sz) = go sh ix        in (sz * i + lidx, n * sz) -enumShape :: IIxX sh -> [IIxX sh] +enumShape :: IShX sh -> [IIxX sh]  enumShape = \sh -> go sh id []    where -    go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a] -    go ZIX f = (f ZIX :) -    go (n :.@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. n-1]] -    go (n :.? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] +    go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] +    go ZSX f = (f ZIX :) +    go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]] +    go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] -shapeLshape :: IIxX sh -> S.ShapeL -shapeLshape ZIX = [] -shapeLshape (n :.@ sh) = n : shapeLshape sh -shapeLshape (n :.? sh) = n : shapeLshape sh +shapeLshape :: IShX sh -> S.ShapeL +shapeLshape ZSX = [] +shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh +shapeLshape (n :$? sh) = n : shapeLshape sh -ssxLength :: StaticShapeX sh -> Int -ssxLength ZSX = 0 -ssxLength (_ :$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :$? ssh) = 1 + ssxLength ssh +ssxLength :: StaticShX sh -> Int +ssxLength ZKSX = 0 +ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh +ssxLength (_ :!$? ssh) = 1 + ssxLength ssh -ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] -ssxIotaFrom _ ZSX = [] -ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKSX = [] +ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh -lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 +lemRankApp :: StaticShX sh1 -> StaticShX sh2             -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)  lemRankApp _ _ = unsafeCoerce Refl  -- TODO improve this -lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 +lemRankAppComm :: StaticShX sh1 -> StaticShX sh2                 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))  lemRankAppComm _ _ = unsafeCoerce Refl  -- TODO improve this -lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZIX = Dict -lemKnownINatRank (_ :.@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :.? sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) +lemKnownINatRank ZSX = Dict +lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZSX = Dict -lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) +lemKnownINatRankSSX ZKSX = Dict +lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh -lemKnownShapeX ZSX = Dict -lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh +lemKnownShapeX ZKSX = Dict +lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict -lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX ZSX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh' +lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' +lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'    | Dict <- lemAppKnownShapeX ssh ssh'    = Dict -lemAppKnownShapeX (() :$? ssh) ssh' +lemAppKnownShapeX (() :!$? ssh) ssh'    | Dict <- lemAppKnownShapeX ssh ssh'    = Dict -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh  shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)    where -    go :: StaticShapeX sh' -> [Int] -> IIxX sh' -    go ZSX [] = ZIX -    go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) :.@ go ssh l -    go (() :$? ssh) (n : l) = n :.? go ssh l +    go :: StaticShX sh' -> [Int] -> IShX sh' +    go ZKSX [] = ZSX +    go (n :!$@ ssh) (_ : l) = n :$@ go ssh l +    go (() :!$? ssh) (n : l) = n :$? go ssh l      go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a  fromVector sh v    | Dict <- lemKnownINatRank sh    , Dict <- knownNatFromINat (Proxy @(Rank sh)) @@ -235,16 +267,16 @@ scalar = XArray . S.scalar  unScalar :: Storable a => XArray '[] a -> a  unScalar (XArray a) = S.unScalar a -constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a +constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a  constant sh x    | Dict <- lemKnownINatRank sh    , Dict <- knownNatFromINat (Proxy @(Rank sh))    = XArray (S.constant (shapeLshape sh) x) -generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a  generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)  -- generateM sh f | Dict <- lemKnownINatRank sh =  --   XArray . S.fromVector (shapeLshape sh)  --     <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) @@ -274,7 +306,7 @@ append (XArray a) (XArray b)  rerank :: forall sh sh1 sh2 a b.            (Storable a, Storable b) -       => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 +       => StaticShX sh -> StaticShX sh1 -> StaticShX sh2         -> (XArray sh1 a -> XArray sh2 b)         -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b  rerank ssh ssh1 ssh2 f (XArray arr) @@ -294,14 +326,14 @@ rerank ssh ssh1 ssh2 f (XArray arr)  rerankTop :: forall sh sh1 sh2 a b.               (Storable a, Storable b) -          => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh +          => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh            -> (XArray sh1 a -> XArray sh2 b)            -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b  rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh  rerank2 :: forall sh sh1 sh2 a b c.             (Storable a, Storable b, Storable c) -        => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 +        => StaticShX sh -> StaticShX sh1 -> StaticShX sh2          -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)          -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c  rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) @@ -327,7 +359,7 @@ transpose perm (XArray arr)    = XArray (S.transpose perm arr)  transpose2 :: forall sh1 sh2 a. -              StaticShapeX sh1 -> StaticShapeX sh2 +              StaticShX sh1 -> StaticShX sh2             -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a  transpose2 ssh1 ssh2 (XArray arr)    | Refl <- lemRankApp ssh1 ssh2 @@ -344,24 +376,24 @@ sumFull :: (Storable a, Num a) => XArray sh a -> a  sumFull (XArray arr) = S.sumA arr  sumInner :: forall sh sh' a. (Storable a, Num a) -         => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a +         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a  sumInner ssh ssh'    | Refl <- lemAppNil @sh -  = rerank ssh ssh' ZSX (scalar . sumFull) +  = rerank ssh ssh' ZKSX (scalar . sumFull)  sumOuter :: forall sh sh' a. (Storable a, Num a) -         => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a +         => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a  sumOuter ssh ssh'    | Refl <- lemAppNil @sh    = sumInner ssh' ssh . transpose2 ssh ssh'  fromList :: forall n sh a. Storable a -         => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +         => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a  fromList ssh l    | Dict <- lemKnownINatRankSSX ssh    , Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))    = case ssh of -      m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) -> +      m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->          error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++                  "does not match the type (" ++ show (natVal m) ++ ")"        _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) @@ -369,5 +401,13 @@ fromList ssh l  toList :: Storable a => XArray (n : sh) a -> [XArray sh a]  toList (XArray arr) = coerce (ORB.toList (S.unravel arr)) +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh +  | Dict <- lemKnownINatRank sh +  , Dict <- knownNatFromINat (Proxy @(Rank sh)) +  = XArray (S.constant (shapeLshape sh) +                       (error "Data.Array.Mixed.empty: shape was not empty")) +  slice :: [(Int, Int)] -> XArray sh a -> XArray sh a  slice ivs (XArray arr) = XArray (S.slice ivs arr) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 370e30a..145598e 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -28,7 +28,7 @@ module Data.Array.Nested (    -- * Mixed arrays    Mixed,    IxX(..), IIxX, -  KnownShapeX(..), StaticShapeX(..), +  KnownShapeX(..), StaticShX(..),    mgenerate, mtranspose, mappend, mfromVector, munScalar,    mconstant, mfromList1, mtoList1, mslice, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 49ed7cb..2f1e79e 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -8,12 +8,12 @@  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE RoleAnnotations #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeAbstractions #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} @@ -23,7 +23,7 @@  {-|  TODO: -* We should be more consistent in whether functions take a 'StaticShapeX' +* We should be more consistent in whether functions take a 'StaticShX'    argument or a 'KnownShapeX' constraint.  * Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point @@ -51,7 +51,7 @@ import qualified Data.Vector.Storable.Mutable as VSM  import Foreign.Storable (Storable)  import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)  import qualified Data.Array.Mixed as X  import Data.INat @@ -100,9 +100,9 @@ type family MapJust l where  lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)  lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))    where -    go :: SINat m -> StaticShapeX (Replicate m Nothing) -    go SZ = ZSX -    go (SS n) = () :$? go n +    go :: SINat m -> StaticShX (Replicate m Nothing) +    go SZ = ZKSX +    go (SS n) = () :!$? go n  lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n  lemRankReplicate _ = go (inatSing @n) @@ -119,10 +119,10 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n)      go SZ = Refl      go (SS n) | Refl <- go n = Refl -ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh') -ixAppSplit _ ZSX idx = (ZIX, idx) -ixAppSplit p (_ :$@ ssh) (i :.@ idx) = first (i :.@) (ixAppSplit p ssh idx) -ixAppSplit p (_ :$? ssh) (i :.? idx) = first (i :.?) (ixAppSplit p ssh idx) +shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') +shAppSplit _ ZKSX idx = (ZSX, idx) +shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx) +shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx)  -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -184,7 +184,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ())  -- no content, MV  data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)  -- etc. -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a) +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)  -- | Tree giving the shape of every array component. @@ -196,9 +196,9 @@ type family ShapeTree a where    ShapeTree () = ()    ShapeTree (a, b) = (ShapeTree a, ShapeTree b) -  ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a) -  ShapeTree (Ranked n a) = (IIxR n, ShapeTree a) -  ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a) +  ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) +  ShapeTree (Ranked n a) = (IShR n, ShapeTree a) +  ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)  -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or @@ -207,7 +207,7 @@ type family ShapeTree a where  class Elt a where    -- ====== PUBLIC METHODS ====== -- -  mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh +  mshape :: KnownShapeX sh => Mixed sh a -> IShX sh    mindex :: Mixed sh a -> IIxX sh -> a    mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a    mscalar :: a -> Mixed '[] a @@ -240,15 +240,12 @@ class Elt a where           -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a    -- ====== PRIVATE METHODS ====== -- -  -- Remember I said that this module needed better management of exports?    -- | Create an empty array. The given shape must have size zero; this may or may not be checked. -  memptyArray :: IIxX sh -> Mixed sh a +  memptyArray :: IShX sh -> Mixed sh a    mshapeTree :: a -> ShapeTree a -  mshapeTreeZero :: Proxy a -> ShapeTree a -    mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool    mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool @@ -257,20 +254,20 @@ class Elt a where    -- | Create uninitialised vectors for this array type, given the shape of    -- this vector and an example for the contents. -  mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a) +  mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)    mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)    -- | Given the shape of this array, an index and a value, write the value at    -- that index in the vectors. -  mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +  mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()    -- | Given the shape of this array, an index and a value, write the value at    -- that index in the vectors. -  mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +  mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()    -- | Given the shape of this array, finalise the vectors into 'XArray's. -  mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) +  mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)  -- Arrays of scalars are basically just arrays of scalars. @@ -299,9 +296,8 @@ instance Storable a => Elt (Primitive a) where      , Refl <- X.lemAppNil @sh3      = M_Primitive (f Proxy a b) -  memptyArray sh = M_Primitive (X.generate sh (error $ "memptyArray Int: shape was not empty (" ++ show sh ++ ")")) +  memptyArray sh = M_Primitive (X.empty sh)    mshapeTree _ = () -  mshapeTreeZero _ = ()    mshapeTreeEq _ () () = True    mshapeTreeEmpty _ () = False    mshowShapeTree _ () = "()" @@ -312,7 +308,7 @@ instance Storable a => Elt (Primitive a) where    -- TODO: this use of toVector is suboptimal    mvecsWritePartial      :: forall sh' sh s. KnownShapeX sh' -    => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () +    => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()    mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do      let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr)))      VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) @@ -338,7 +334,6 @@ instance (Elt a, Elt b) => Elt (a, b) where    memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)    mshapeTree (x, y) = (mshapeTree x, mshapeTree y) -  mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b))    mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'    mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2    mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" @@ -357,10 +352,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    -- TODO: this is quadratic in the nesting depth because it repeatedly    -- truncates the shape vector to one a little shorter. Fix with a    -- moverlongShape method, a prefix of which is mshape. -  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh +  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh    mshape (M_Nest arr)      | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') -    = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) +    = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr))    mindex (M_Nest arr) i = mindexPartial arr i @@ -410,12 +405,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where          , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT))          = f (Proxy @(sh' ++ shT)) -  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh')))) +  memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh'))))    mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) -  mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a)) -    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2    mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t @@ -424,32 +417,26 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where    mvecsUnsafeNew sh example      | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) -    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) +    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example))                                                   (mindex example (X.zeroIxX (knownShapeX @sh')))      where        sh' = mshape example -  mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) +  mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) -  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs +  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs    mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 -                    => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) +                    => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)                      -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)                      -> ST s ()    mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)      | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))      , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs - -  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs +    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs +  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs --- | Check whether a given shape corresponds on the statically-known components of the shape. -checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool -checkBounds ZIX ZSX = True -checkBounds (n :.@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' -checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh'  -- | Create an array given a size and a function that computes the element at a  -- given index. @@ -468,31 +455,25 @@ checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh'  -- the entire hierarchy (after distributing out tuples) must be a rectangular  -- array. The type of 'mgenerate' allows this requirement to be broken very  -- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f -  -- TODO: Do we need this checkBounds check elsewhere as well? -  | not (checkBounds sh (knownShapeX @sh)) = -      error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -  -- If the shape is empty, there is no first element, so we should not try to -  -- generate it. -  | X.shapeSize sh == 0 = memptyArray sh -  | otherwise = -      let firstidx = X.zeroIxX' sh -          firstelem = f (X.zeroIxX' sh) -          shapetree = mshapeTree firstelem -      in if mshapeTreeEmpty (Proxy @a) shapetree -           then memptyArray sh -           else runST $ do -                  vecs <- mvecsUnsafeNew sh firstelem -                  mvecsWrite sh firstidx firstelem vecs -                  -- TODO: This is likely fine if @a@ is big, but if @a@ is a -                  -- scalar this array copying inefficient. Should improve this. -                  forM_ (tail (X.enumShape sh)) $ \idx -> do -                    let val = f idx -                    when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ -                      error "Data.Array.Nested mgenerate: generated values do not have equal shapes" -                    mvecsWrite sh idx val vecs -                  mvecsFreeze sh vecs +mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case X.enumShape sh of +  [] -> memptyArray sh +  firstidx : restidxs -> +    let firstelem = f (X.zeroIxX' sh) +        shapetree = mshapeTree firstelem +    in if mshapeTreeEmpty (Proxy @a) shapetree +         then memptyArray sh +         else runST $ do +                vecs <- mvecsUnsafeNew sh firstelem +                mvecsWrite sh firstidx firstelem vecs +                -- TODO: This is likely fine if @a@ is big, but if @a@ is a +                -- scalar this array copying inefficient. Should improve this. +                forM_ restidxs $ \idx -> do +                  let val = f idx +                  when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ +                    error "Data.Array.Nested mgenerate: generated values do not have equal shapes" +                  mvecsWrite sh idx val vecs +                mvecsFreeze sh vecs  mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a  mtranspose perm = @@ -506,12 +487,8 @@ mappend = mlift2 go             => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b          go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append -mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVector sh v -  | not (checkBounds sh (knownShapeX @sh)) = -      error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -  | otherwise = -      M_Primitive (X.fromVector sh v) +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector sh v = M_Primitive (X.fromVector sh v)  mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a  mfromList1 = mfromList . fmap mscalar @@ -522,17 +499,13 @@ mtoList1 = map munScalar . mtoList  munScalar :: Elt a => Mixed '[] a -> a  munScalar arr = mindex arr ZIX -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a) -mconstantP sh x -  | not (checkBounds sh (knownShapeX @sh)) = -      error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -  | otherwise = -      M_Primitive (X.constant sh x) +mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) +mconstantP sh x = M_Primitive (X.constant sh x)  -- | This 'Coercible' constraint holds for the scalar types for which 'Mixed'  -- is defined.  mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) -          => IIxX sh -> a -> Mixed sh a +          => IShX sh -> a -> Mixed sh a  mconstant sh x = coerce (mconstantP sh x)  mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a @@ -648,7 +621,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where      = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $          mlift2 f arr1 arr2 -  memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a) +  memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)    memptyArray i      | Dict <- lemKnownReplicate (Proxy @n)      = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ @@ -657,9 +630,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where    mshapeTree (Ranked arr)      | Refl <- lemRankReplicate (Proxy @n)      , Dict <- lemKnownReplicate (Proxy @n) -    = first ixCvtXR (mshapeTree arr) - -  mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a)) +    = first shCvtXR (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -675,7 +646,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where      | Dict <- lemKnownReplicate (Proxy @n)      = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) -  mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () +  mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()    mvecsWrite sh idx (Ranked arr) vecs      | Dict <- lemKnownReplicate (Proxy @n)      = mvecsWrite sh idx arr @@ -683,7 +654,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where             vecs)    mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' -                    => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) +                    => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)                      -> MixedVecs s (sh ++ sh') (Ranked n a)                      -> ST s ()    mvecsWritePartial sh idx arr vecs @@ -696,7 +667,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where                  @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))             vecs) -  mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) +  mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))    mvecsFreeze sh vecs      | Dict <- lemKnownReplicate (Proxy @n)      = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @@ -712,6 +683,8 @@ data ShS sh where    ZSS :: ShS '[]    (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh)  deriving instance Show (ShS sh) +deriving instance Eq (ShS sh) +deriving instance Ord (ShS sh)  infixr 3 :$$  -- | A statically-known shape of a shape-typed array. @@ -726,9 +699,9 @@ sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict  lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)  lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))    where -    go :: ShS sh' -> StaticShapeX (MapJust sh') -    go ZSS = ZSX -    go (n :$$ sh) = n :$@ go sh +    go :: ShS sh' -> StaticShX (MapJust sh') +    go ZSS = ZKSX +    go (n :$$ sh) = n :!$@ go sh  lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2                    -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 @@ -777,7 +750,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where      = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $          mlift2 f arr1 arr2 -  memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a) +  memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)    memptyArray i      | Dict <- lemKnownMapJust (Proxy @sh)      = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ @@ -785,9 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where    mshapeTree (Shaped arr)      | Dict <- lemKnownMapJust (Proxy @sh) -    = first (ixCvtXS (knownShape @sh)) (mshapeTree arr) - -  mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a)) +    = first (shCvtXS (knownShape @sh)) (mshapeTree arr)    mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -803,7 +774,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where      | Dict <- lemKnownMapJust (Proxy @sh)      = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) -  mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () +  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()    mvecsWrite sh idx (Shaped arr) vecs      | Dict <- lemKnownMapJust (Proxy @sh)      = mvecsWrite sh idx arr @@ -811,7 +782,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where             vecs)    mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 -                    => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) +                    => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)                      -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)                      -> ST s ()    mvecsWritePartial sh idx arr vecs @@ -824,7 +795,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where                  @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))             vecs) -  mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) +  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))    mvecsFreeze sh vecs      | Dict <- lemKnownMapJust (Proxy @sh)      = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @@ -927,6 +898,8 @@ newtype ShR n i = ShR (ListR n i)    deriving (Show, Eq, Ord)    deriving newtype (Functor, Foldable) +type IShR n = ShR n Int +  pattern ZSR :: forall n i. () => n ~ Z => ShR n i  pattern ZSR = ShR ZR @@ -957,20 +930,29 @@ ixCvtXR ZIX = ZIR  ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx  ixCvtXR (n :.? idx) = n :.: ixCvtXR idx +shCvtXR :: IShX sh -> IShR (X.Rank sh) +shCvtXR ZSX = ZSR +shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx +shCvtXR (n :$? idx) = n :$: shCvtXR idx +  ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)  ixCvtRX ZIR = ZIX  ixCvtRX (n :.: idx) = n :.? ixCvtRX idx -shapeSizeR :: IIxR n -> Int -shapeSizeR ZIR = 1 -shapeSizeR (n :.: sh) = n * shapeSizeR sh +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: idx) = n :$? shCvtRX idx + +shapeSizeR :: IShR n -> Int +shapeSizeR ZSR = 1 +shapeSizeR (n :$: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n +rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n  rshape (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n)    , Refl <- lemRankReplicate (Proxy @n) -  = ixCvtXR (mshape arr) +  = shCvtXR (mshape arr)  rindex :: Elt a => Ranked n a -> IIxR n -> a  rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) @@ -983,12 +965,12 @@ rindexPartial (Ranked arr) idx =  -- | __WARNING__: All values returned from the function must have equal shape.  -- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a +rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a  rgenerate sh f -  | Dict <- knownIxR sh +  | Dict <- knownShR sh    , Dict <- lemKnownReplicate (Proxy @n)    , Refl <- lemRankReplicate (Proxy @n) -  = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) +  = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))  -- | See the documentation of 'mlift'.  rlift :: forall n1 n2 a. (KnownINat n2, Elt a) @@ -1005,7 +987,7 @@ rsumOuter1 (Ranked arr)    | Dict <- lemKnownReplicate (Proxy @n)    = Ranked      . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) -    . X.sumOuter (() :$? ZSX) (knownShapeX @(Replicate n Nothing)) +    . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))      . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)      $ arr @@ -1021,10 +1003,10 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend  rscalar :: Elt a => a -> Ranked I0 a  rscalar x = Ranked (mscalar x) -rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)  rfromVector sh v    | Dict <- lemKnownReplicate (Proxy @n) -  = Ranked (mfromVector (ixCvtRX sh) v) +  = Ranked (mfromVector (shCvtRX sh) v)  rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a  rfromList l @@ -1043,13 +1025,13 @@ rtoList1 = map runScalar . rtoList  runScalar :: Elt a => Ranked I0 a -> a  runScalar arr = rindex arr ZIR -rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)  rconstantP sh x    | Dict <- lemKnownReplicate (Proxy @n) -  = Ranked (mconstantP (ixCvtRX sh) x) +  = Ranked (mconstantP (shCvtRX sh) x)  rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) -          => IIxR n -> a -> Ranked n a +          => IShR n -> a -> Ranked n a  rconstant sh x = coerce (rconstantP sh x)  rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a @@ -1141,27 +1123,30 @@ zeroIxS :: ShS sh -> IIxS sh  zeroIxS ZSS = ZIS  zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh --- TODO: this function should not exist perhaps -cvtShSIxS :: ShS sh -> IIxS sh -cvtShSIxS ZSS = ZIS -cvtShSIxS (n :$$ sh) = fromIntegral (fromSNat n) :.$ cvtShSIxS sh -  ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh  ixCvtXS ZSS ZIX = ZIS  ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx +shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh +shCvtXS ZSS ZSX = ZSS +shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx +  ixCvtSX :: IIxS sh -> IIxX (MapJust sh)  ixCvtSX ZIS = ZIX  ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh -shapeSizeS :: IIxS sh -> Int -shapeSizeS ZIS = 1 -shapeSizeS (n :.$ sh) = n * shapeSizeS sh +shCvtSX :: ShS sh -> IShX (MapJust sh) +shCvtSX ZSS = ZSX +shCvtSX (n :$$ sh) = n :$@ shCvtSX sh + +shapeSizeS :: ShS sh -> Int +shapeSizeS ZSS = 1 +shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh  -- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh -sshape _ = cvtShSIxS (knownShape @sh) +sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh +sshape _ = knownShape @sh  sindex :: Elt a => Shaped sh a -> IIxS sh -> a  sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) @@ -1177,7 +1162,7 @@ sindexPartial (Shaped arr) idx =  sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a  sgenerate f    | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mgenerate (ixCvtSX (cvtShSIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) +  = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh)))  -- | See the documentation of 'mlift'.  slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) @@ -1194,7 +1179,7 @@ ssumOuter1 (Shaped arr)    | Dict <- lemKnownMapJust (Proxy @sh)    = Shaped      . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) -    . X.sumOuter (natSing @n :$@ ZSX) (knownShapeX @(MapJust sh)) +    . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh))      . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)      $ arr @@ -1213,7 +1198,7 @@ sscalar x = Shaped (mscalar x)  sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)  sfromVector v    | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mfromVector (ixCvtSX (cvtShSIxS (knownShape @sh))) v) +  = Shaped (mfromVector (shCvtSX (knownShape @sh)) v)  sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)            => NonEmpty (Shaped sh a) -> Shaped (n : sh) a @@ -1236,7 +1221,7 @@ sunScalar arr = sindex arr ZIS  sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)  sconstantP x    | Dict <- lemKnownMapJust (Proxy @sh) -  = Shaped (mconstantP (ixCvtSX (cvtShSIxS (knownShape @sh))) x) +  = Shaped (mconstantP (shCvtSX (knownShape @sh)) x)  sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a))            => a -> Shaped sh a | 
