From 06625c89089044b064bbc6cf36ea4e83199c19a4 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 17 May 2024 13:32:32 +0200 Subject: WIP rewrite to singletons only --- src/Data/Array/Mixed.hs | 78 +++++++------- src/Data/Array/Nested/Internal.hs | 216 +++++++++++++++++++++----------------- 2 files changed, 155 insertions(+), 139 deletions(-) (limited to 'src') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index ce18431..69c44ab 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -127,17 +127,6 @@ deriving instance Show (StaticShX sh) infixr 3 :!$@ infixr 3 :!$? --- | Evidence for the static part of a shape. -type KnownShapeX :: [Maybe Nat] -> Constraint -class KnownShapeX sh where - knownShapeX :: StaticShX sh -instance KnownShapeX '[] where - knownShapeX = ZKSX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing :!$@ knownShapeX -instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :!$? knownShapeX - type family Rank sh where Rank '[] = 0 Rank (_ : sh) = 1 + Rank sh @@ -162,6 +151,7 @@ completeShXzeros ZKSX = ZSX completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh +-- TODO: generalise all these things to arbitrary @i@ ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend ZIX idx' = idx' ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' @@ -177,6 +167,15 @@ ixDrop sh ZIX = sh ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx +shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh' +shDropIx sh ZIX = sh +shDropIx (_ :$@ sh) (_ :.@ idx) = shDropIx sh idx +shDropIx (_ :$? sh) (_ :.? idx) = shDropIx sh idx + +shTail :: IShX (n : sh) -> IShX sh +shTail (_ :$@ sh) = sh +shTail (_ :$? sh) = sh + ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKSX sh' = sh' ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' @@ -279,22 +278,22 @@ lemKnownNatRankSSX ZKSX = Dict lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict -lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh -lemKnownShapeX ZKSX = Dict -lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict - -lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict -lemAppKnownShapeX (() :!$? ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict - -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh -shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) +-- lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh +-- lemKnownShapeX ZKSX = Dict +-- lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +-- lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict + +-- lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +-- lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' +-- lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' +-- | Dict <- lemAppKnownShapeX ssh ssh' +-- = Dict +-- lemAppKnownShapeX (() :!$? ssh) ssh' +-- | Dict <- lemAppKnownShapeX ssh ssh' +-- = Dict + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) where go :: StaticShX sh' -> [Int] -> IShX sh' go ZKSX [] = ZSX @@ -345,10 +344,10 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -append :: forall n m sh a. (KnownShapeX sh, Storable a) - => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) +append :: forall n m sh a. Storable a + => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX ssh = XArray (S.append a b) rerank :: forall sh sh1 sh2 a b. @@ -429,10 +428,6 @@ foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m foldHList _ HNil = mempty foldHList f (x `HCons` l) = f x <> foldHList f l -class KnownNatList l where makeNatList :: HList SNat l -instance KnownNatList '[] where makeNatList = HNil -instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList - type family TakeLen ref l where TakeLen '[] l = '[] TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs @@ -485,15 +480,16 @@ ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" -- | The list argument gives indices into the original dimension list. -transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) - => HList SNat is +transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) + => StaticShX sh + -> HList SNat is -> XArray sh a -> XArray (PermutePrefix is sh) a -transpose perm (XArray arr) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh)) +transpose ssh perm (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm - , Refl <- lemRankDropLen (knownShapeX @sh) perm + , Refl <- lemRankDropLen ssh perm = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] in XArray (S.transpose perm' arr) diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 7bd6565..d2883a7 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -28,6 +28,9 @@ TODO: * We should be more consistent in whether functions take a 'StaticShX' argument or a 'KnownShapeX' constraint. +* Allow downtyping certain dimensions, and write conversions between Mixed, + Ranked and Shaped + * Mikolaj wants these: About your wishlist of operations: these are already there @@ -78,7 +81,7 @@ import Data.Bifunctor (first) import Data.Coerce (coerce, Coercible) import Data.Foldable (toList) import Data.Kind -import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Storable as VS @@ -86,7 +89,7 @@ import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate) import qualified Data.Array.Mixed as X @@ -143,12 +146,12 @@ knownNatSucc :: KnownNat n => Dict KnownNat (n + 1) knownNatSucc = Dict -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) - where - go :: SNat m -> StaticShX (Replicate m Nothing) - go SZ = ZKSX - go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n +-- lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +-- lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) +-- where +-- go :: SNat m -> StaticShX (Replicate m Nothing) +-- go SZ = ZKSX +-- go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate _ = go (natSing @n) @@ -160,12 +163,9 @@ lemRankReplicate _ = go (natSing @n) , Refl <- go n = Refl -lemRankMapJust :: forall sh. KnownShape sh => Proxy sh -> X.Rank (MapJust sh) :~: X.Rank sh -lemRankMapJust _ = go (knownShape @sh) - where - go :: forall sh'. ShS sh' -> X.Rank (MapJust sh') :~: X.Rank sh' - go ZSS = Refl - go (_ :$$ sh') | Refl <- go sh' = Refl +lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a @@ -197,11 +197,11 @@ class PrimElt a where fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce + default fromPrimitive :: Coercible (Mixed' sh a) (Mixed' sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive (Mixed sh m) = Mixed sh (coerce m) - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce + default toPrimitive :: Coercible (Mixed' sh (Primitive a)) (Mixed' sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive (Mixed sh m) = Mixed sh (coerce m) -- [PRIMITIVE ELEMENT TYPES LIST] instance PrimElt Int @@ -218,31 +218,37 @@ instance PrimElt () -- -- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type -- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a +data Mixed sh a = Mixed (IShX sh) (Mixed' sh a) +deriving instance Show (Mixed' sh a) => Show (Mixed sh a) + +unMixed :: Mixed sh a -> Mixed' sh a +unMixed (Mixed _ arr) = arr + +type Mixed' :: [Maybe Nat] -> Type -> Type +data family Mixed' sh a -- NOTE: When opening up the Mixed abstraction, you might see dimension sizes -- that you're not supposed to see. In particular, you might see (nonempty) -- sizes of the elements of an empty array, which is information that should -- ostensibly not exist; the full array is still empty. -newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) +newtype instance Mixed' sh (Primitive a) = M_Primitive (XArray sh a) deriving (Show) -- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (XArray sh Int) +newtype instance Mixed' sh Int = M_Int (XArray sh Int) deriving (Show) -newtype instance Mixed sh Double = M_Double (XArray sh Double) +newtype instance Mixed' sh Double = M_Double (XArray sh Double) deriving (Show) -newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) +newtype instance Mixed' sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) deriving (Show) -- etc. -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) -deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) +data instance Mixed' sh (a, b) = M_Tup2 !(Mixed' sh a) !(Mixed' sh b) +deriving instance (Show (Mixed' sh a), Show (Mixed' sh b)) => Show (Mixed' sh (a, b)) -- etc. -newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) -deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) +newtype instance Mixed' sh1 (Mixed sh2 a) = M_Nest (Mixed' (sh1 ++ sh2) a) +deriving instance Show (Mixed' (sh1 ++ sh2) a) => Show (Mixed' sh1 (Mixed sh2 a)) -- | Internal helper data family mirroring 'Mixed' that consists of mutable @@ -273,7 +279,7 @@ type family ShapeTree a where ShapeTree () = () ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) + ShapeTree (Mixed' sh a) = (IShX sh, ShapeTree a) ShapeTree (Ranked n a) = (IShR n, ShapeTree a) ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) @@ -284,7 +290,7 @@ type family ShapeTree a where class Elt a where -- ====== PUBLIC METHODS ====== -- - mshape :: KnownShapeX sh => Mixed sh a -> IShX sh + mshape :: Mixed sh a -> IShX sh mindex :: Mixed sh a -> IIxX sh -> a mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a @@ -298,7 +304,7 @@ class Elt a where -- -- If you want a single-dimensional array from your list, map 'mscalar' -- first. - mfromList1 :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a + mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoList1 :: Mixed (n : sh) a -> [Mixed sh a] @@ -307,13 +313,15 @@ class Elt a where -- full 'XArray' and as such you can distinguish different empty arrays by -- the "shapes" of their elements. This information is meaningless, so you -- should not use it. - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a -- ====== PRIVATE METHODS ====== -- @@ -341,7 +349,7 @@ class Elt a where -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) @@ -349,31 +357,37 @@ class Elt a where -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where - mshape (M_Primitive a) = X.shape a - mindex (M_Primitive a) i = Primitive (X.index a i) - mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive (X.scalar x) - mfromList1 l = M_Primitive (X.fromList1 knownShapeX (coerce (toList l))) - mtoList1 (M_Primitive arr) = coerce (X.toList1 arr) + mshape (Mixed sh _) = sh + mindex (Mixed _ (M_Primitive a)) i = Primitive (X.index a i) + mindexPartial (Mixed sh (M_Primitive a)) i = Mixed (X.shDropIx sh i) (M_Primitive (X.indexPartial a i)) + mscalar (Primitive x) = Mixed ZSX (M_Primitive (X.scalar x)) + mfromList1 sn l@(arr1 :| _) = + let sh = sn :$@ mshape arr1 + in Mixed sh (M_Primitive (X.fromList1 (X.staticShapeFrom sh) (map (coerce . unMixed) (toList l)))) + mtoList1 (Mixed sh (M_Primitive arr)) = map (Mixed (X.shTail sh) . coerce) (X.toList1 arr) mlift :: forall sh1 sh2. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) + StaticShX sh2 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift f (M_Primitive a) + mlift ssh2 f (Mixed _ (M_Primitive a)) | Refl <- X.lemAppNil @sh1 , Refl <- X.lemAppNil @sh2 - = M_Primitive (f Proxy a) + , let result = f ZKSX a + = Mixed (X.shape ssh2 result) (M_Primitive result) mlift2 :: forall sh1 sh2 sh3. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) + StaticShX sh3 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 f (M_Primitive a) (M_Primitive b) + mlift2 ssh3 f (Mixed _ (M_Primitive a)) (Mixed _ (M_Primitive b)) | Refl <- X.lemAppNil @sh1 , Refl <- X.lemAppNil @sh2 , Refl <- X.lemAppNil @sh3 - = M_Primitive (f Proxy a b) + , let result = f ZKSX a b + = Mixed (X.shape ssh3 result) (M_Primitive result) - memptyArray sh = M_Primitive (X.empty sh) + memptyArray sh = Mixed sh (M_Primitive (X.empty sh)) mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False @@ -384,15 +398,20 @@ instance Storable a => Elt (Primitive a) where -- TODO: this use of toVector is suboptimal mvecsWritePartial - :: forall sh' sh s. KnownShapeX sh' - => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do - let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr))) - VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) + :: forall sh' sh s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartial sh i (Mixed sh' (M_Primitive arr)) (MV_Primitive v) = do + let arrsh = X.shape (X.staticShapeFrom sh') arr + offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh)) + VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr) - mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v + mvecsFreeze sh (MV_Primitive v) = Mixed sh . M_Primitive . X.fromVector sh <$> VS.freeze v -- [PRIMITIVE ELEMENT TYPES LIST] + + + +TODO -- should rewrite methods of Elt class to take ' in their name, and work on Mixed' instead of Mixed (and take explicit StaticShapeX). Then wrap all of the public functions to work on Mixed. Then don't export the contents of Elt from Nested.hs, and export the wrappers instead. This also makes the haddocks more consistent. deriving via Primitive Int instance Elt Int deriving via Primitive Double instance Elt Double deriving via Primitive () instance Elt () @@ -425,11 +444,11 @@ instance (Elt a, Elt b) => Elt (a, b) where mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b -- Arrays of arrays are just arrays, but with more dimensions. -instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where +instance Elt a => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh + mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest arr) | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) @@ -444,37 +463,36 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mscalar = M_Nest - mfromList1 :: forall n sh. KnownShapeX (n : sh) - => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a) + mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Just n : sh) (Mixed sh' a) mfromList1 l | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh')) = M_Nest (mfromList1 (coerce l)) mtoList1 (M_Nest arr) = coerce (mtoList1 arr) - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) + mlift :: forall sh1 sh2. + (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift f (M_Nest arr) | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) = M_Nest (mlift f' arr) where - f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' _ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) = f (Proxy @(sh' ++ shT)) - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + mlift2 :: forall sh1 sh2 sh3. + (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 f (M_Nest arr1) (M_Nest arr2) | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh')) = M_Nest (mlift2 f' arr1 arr2) where - f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' _ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) @@ -503,8 +521,8 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) @@ -532,7 +550,7 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate :: forall sh a. Elt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a mgenerate sh f = case X.enumShape sh of [] -> memptyArray sh firstidx : restidxs -> @@ -552,24 +570,25 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a +mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) + => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a mtranspose perm | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh))) = mlift $ \(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @(X.PermutePrefix is sh)) (knownShapeX @sh') (X.transpose perm) -mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a) +mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a mappend = mlift2 go - where go :: forall sh' b. (KnownShapeX sh', Storable b) - => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b + where go :: forall sh' b. Storable b + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append -mfromVectorP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive (X.fromVector sh v) -mfromVector :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector :: forall sh a. (Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a @@ -578,8 +597,8 @@ mtoVectorP (M_Primitive v) = X.toVector v mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (coerce toPrimitive arr) -mfromList :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a -mfromList = mfromList1 . fmap mscalar +mfromList :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a +mfromList sn = mfromList1 sn . fmap mscalar mtoList :: Elt a => Mixed '[n] a -> [a] mtoList = map munScalar . mtoList1 @@ -587,24 +606,23 @@ mtoList = map munScalar . mtoList1 munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) +mconstantP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) mconstantP sh x = M_Primitive (X.constant sh x) -mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) +mconstant :: forall sh a. (Storable a, PrimElt a) => IShX sh -> a -> Mixed sh a mconstant sh x = fromPrimitive (mconstantP sh x) -mslice :: (KnownShapeX sh, Elt a) => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n = withKnownNat n $ mlift $ \_ -> X.slice i n -msliceU :: (KnownShapeX sh, Elt a) => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a msliceU i n = mlift $ \_ -> X.sliceU i n -mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a +mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a mrev1 = mlift $ \_ -> X.rev1 -mreshape :: forall sh sh' a. (KnownShapeX sh, KnownShapeX sh', Elt a) - => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a mreshape sh' = mlift $ \(_ :: Proxy shIn) -> X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh' @@ -620,18 +638,18 @@ mfromXArrayPrimP = M_Primitive mfromXArrayPrim :: PrimElt a => XArray sh a -> Mixed sh a mfromXArrayPrim = fromPrimitive . mfromXArrayPrimP -mliftPrim :: (KnownShapeX sh, Storable a) +mliftPrim :: Storable a => (a -> a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) -mliftPrim2 :: (KnownShapeX sh, Storable a) +mliftPrim2 :: Storable a => (a -> a -> a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) -instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where +instance (Storable a, Num a) => Num (Mixed sh (Primitive a)) where (+) = mliftPrim2 (+) (-) = mliftPrim2 (-) (*) = mliftPrim2 (*) @@ -645,8 +663,8 @@ instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) whe \Unknown components in shape, use explicit mconstant" -- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) -deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) +deriving via Mixed sh (Primitive Int) instance Num (Mixed sh Int) +deriving via Mixed sh (Primitive Double) instance Num (Mixed sh Double) -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -676,15 +694,16 @@ newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) -- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) -deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a)) +newtype instance Mixed' sh (Ranked n a) = M_Ranked (Mixed' sh (Mixed (Replicate n Nothing) a)) +deriving instance Show (Mixed' sh (Mixed (Replicate n Nothing) a)) => Show (Mixed' sh (Ranked n a)) +newtype instance Mixed' sh (Shaped sh' a) = M_Shaped (Mixed' sh (Mixed (MapJust sh' ) a)) +deriving instance Show (Mixed' sh (Mixed (MapJust sh' ) a)) => Show (Mixed' sh (Shaped sh' a)) newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) +{- -- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. @@ -782,7 +801,7 @@ instance (Elt a, KnownNat n) => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - +-} -- | The shape of a shape-typed array given as a list of 'SNat' values. data ShS sh where @@ -793,11 +812,7 @@ deriving instance Eq (ShS sh) deriving instance Ord (ShS sh) infixr 3 :$$ --- | A statically-known shape of a shape-typed array. -class KnownShape sh where knownShape :: ShS sh -instance KnownShape '[] where knownShape = ZSS -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = natSing :$$ knownShape - +{- sshapeKnown :: ShS sh -> Dict KnownShape sh sshapeKnown ZSS = Dict sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict @@ -942,6 +957,7 @@ instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where -- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int) deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double) +-} type role ListR nominal representational type ListR :: Nat -> Type -> Type @@ -1021,6 +1037,7 @@ unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1) unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i) unconsShR (ShR ZR) = Nothing +{- knownShR :: ShR n i -> Dict KnownNat n knownShR (ShR sh) = knownListR sh @@ -1215,6 +1232,7 @@ instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) whe -- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) +-} type role ListS nominal representational type ListS :: [Nat] -> Type -> Type @@ -1272,6 +1290,7 @@ unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1) unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i) unconsShS ZSS = Nothing +{- zeroIxS :: ShS sh -> IIxS sh zeroIxS ZSS = ZIS zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh @@ -1465,3 +1484,4 @@ sfromXArrayPrimP = Shaped . mfromXArrayPrimP sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a sfromXArrayPrim = Shaped . mfromXArrayPrim +-} -- cgit v1.2.3-70-g09d2