diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 72 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Lemmas.hs | 14 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 170 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 584 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape/Internal.hs | 59 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 104 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 18 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 61 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 243 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 24 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 52 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 421 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Types.hs | 6 | ||||
| -rw-r--r-- | src/Data/Array/Strided/Orthotope.hs | 5 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 86 |
15 files changed, 1129 insertions, 790 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs index 8c88d23..408bf8a 100644 --- a/src/Data/Array/Nested/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -15,10 +15,10 @@ module Data.Array.Nested.Convert ( -- * Shape\/index\/list casting functions -- ** To ranked - ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShXAnyShape, shrFromShX, listrCast, ixrCast, shrCast, -- ** To shaped - ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsFromIxR, ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, ixsCast, -- ** To mixed ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, @@ -38,9 +38,11 @@ module Data.Array.Nested.Convert ( ) where import Control.Category +import Data.Coerce (coerce) import Data.Proxy import Data.Type.Equality import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed @@ -55,48 +57,39 @@ import Data.Array.Nested.Types -- * To ranked +-- TODO: change all those unsafeCoerces into coerces by defining shaped +-- and ranekd index types as newtypes of the mixed index type +-- and similarly for the sized lists or, preferably, by defining +-- all as newtypes over [], exploiting fusion and getting free toList. ixrFromIxS :: IxS sh i -> IxR (Rank sh) i -ixrFromIxS ZIS = ZIR -ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix +ixrFromIxS = unsafeCoerce -ixrFromIxX :: IxX sh i -> IxR (Rank sh) i -ixrFromIxX ZIX = ZIR -ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx +-- ixrFromIxX re-exported shrFromShS :: ShS sh -> IShR (Rank sh) shrFromShS ZSS = ZSR shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh --- shrFromShX re-exported --- shrFromShX2 re-exported +shrFromShXAnyShape :: IShX sh -> IShR (Rank sh) +shrFromShXAnyShape ZSX = ZSR +shrFromShXAnyShape (n :$% idx) = fromSMayNat' n :$: shrFromShXAnyShape idx + +shrFromShX :: IShX (Replicate n Nothing) -> IShR n +shrFromShX = coerce + -- listrCast re-exported -- ixrCast re-exported -- shrCast re-exported -- * To shaped --- TODO: these take a ShS because there are KnownNats inside IxS. - -ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i -ixsFromIxR ZSS ZIR = ZIS -ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR :: IxR (Rank sh) i -> IxS sh i +ixsFromIxR = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled --- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the --- following, but more efficient: --- --- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) -ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i -ixsFromIxR' ZSS ZIR = ZIS -ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx -ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" - --- TODO: this takes a ShS because there are KnownNats inside IxS. -ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i -ixsFromIxX ZSS ZIX = ZIS -ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx +-- ixsFromIxX re-exported -- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to --- the following, but more efficient: +-- the following, but less verbose: -- -- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i @@ -113,7 +106,8 @@ withShsFromShR (n :$: sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" --- shsFromShX re-exported +shsFromShX :: IShX (MapJust sh) -> ShS sh +shsFromShX = coerce -- | Produce an existential 'ShS' from an 'IShX'. If you already know that -- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. @@ -128,6 +122,7 @@ withShsFromShX (SUnknown n :$% sh) k = Just sn@SNat -> k (sn :$$ sh') Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" +-- If it ever matters for performance, this is unsafeCoercible. shsFromSSX :: StaticShX (MapJust sh) -> ShS sh shsFromSSX = shsFromShX Prelude.. shxFromSSX @@ -135,25 +130,14 @@ shsFromSSX = shsFromShX Prelude.. shxFromSSX -- * To mixed -ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i -ixxFromIxR ZIR = ZIX -ixxFromIxR (n :.: (idx :: IxR m i)) = - castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) - (n :.% ixxFromIxR idx) - -ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i -ixxFromIxS ZIS = ZIX -ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh +-- ixxFromIxR re-exported +-- ixxFromIxS re-exported shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i -shxFromShR ZSR = ZSX -shxFromShR (n :$: (idx :: ShR m i)) = - castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) (Proxy @m))) - (SUnknown n :$% shxFromShR idx) +shxFromShR = coerce shxFromShS :: ShS sh -> IShX (MapJust sh) -shxFromShS ZSS = ZSX -shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh +shxFromShS = coerce -- ixxCast re-exported -- shxCast re-exported diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs index e089479..fa5611b 100644 --- a/src/Data/Array/Nested/Lemmas.hs +++ b/src/Data/Array/Nested/Lemmas.hs @@ -56,6 +56,20 @@ lemReplicatePlusApp sn _ _ = go sn -} lemReplicatePlusApp _ _ _ = unsafeCoerceRefl +lemReplicateEmpty :: proxy n -> Replicate n (Nothing @Nat) :~: '[] -> n :~: 0 +lemReplicateEmpty _ Refl = unsafeCoerceRefl + +-- TODO: make less ad-hoc and rename these three: +lemReplicateCons :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> n1 :~: Rank sh + 1 +lemReplicateCons _ _ Refl = unsafeCoerceRefl + +lemReplicateCons2 :: proxy sh -> proxy' n1 -> Nothing : sh :~: Replicate n1 Nothing -> sh :~: Replicate (Rank sh) Nothing +lemReplicateCons2 _ _ Refl = unsafeCoerceRefl + +lemReplicateSucc2 :: forall n1 n proxy. + proxy n1 -> n + 1 :~: n1 -> Nothing @Nat : Replicate n Nothing :~: Replicate n1 Nothing +lemReplicateSucc2 _ _ = unsafeCoerceRefl + lemDropLenApp :: Rank l1 <= Rank l2 => Proxy l1 -> Proxy l2 -> Proxy rest -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 182943d..39f00fa 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -23,12 +23,14 @@ module Data.Array.Nested.Mixed where import Prelude hiding (mconcat) import Control.DeepSeq (NFData(..)) -import Control.Monad (forM_, when) +import Control.Monad (foldM_, forM_, when) import Control.Monad.ST +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS import Data.Array.RankedS qualified as S import Data.Bifunctor (bimap) import Data.Coerce -import Data.Foldable (toList) import Data.Int import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) @@ -39,6 +41,7 @@ import Data.Vector.Storable qualified as VS import Data.Vector.Storable.Mutable qualified as VSM import Foreign.C.Types (CInt) import Foreign.Storable (Storable) +import Foreign.Storable qualified as Storable import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits @@ -237,11 +240,13 @@ instance Elt a => NFData (Mixed sh a) where rnf = mrnf +{-# INLINE mliftNumElt1 #-} mliftNumElt1 :: (PrimElt a, PrimElt b) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) -> Mixed sh a -> Mixed sh b mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) +{-# INLINE mliftNumElt2 #-} mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) -> Mixed sh a -> Mixed sh b -> Mixed sh c @@ -310,7 +315,7 @@ class Elt a where -- | See 'mfromListOuter'. If the list does not have the given length, a -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable. - mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a + mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] @@ -355,6 +360,9 @@ class Elt a where -- | Tree giving the shape of every array component. type ShapeTree a + -- | Produces an internal representation of a tree of shapes of (potentially) + -- nested arrays. If the argument is an array, it requires that the array + -- is not empty (it's guaranteed to crash early otherwise). mshapeTree :: a -> ShapeTree a mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool @@ -367,17 +375,21 @@ class Elt a where -- this mixed array. marrayStrides :: Mixed sh a -> Bag [Int] - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWriteLinear :: Int -> a -> MixedVecs s sh a -> ST s () - -- | Given the shape of this array, an index and a value, write the value at + -- | Given a linear index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartialLinear :: Proxy sh -> Int -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + -- | 'mvecsFreeze' but without copying the mutable vectors before freezing + -- them. If the mutable vectors are mutated after this function, referential + -- transparency is broken. + mvecsUnsafeFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- | Element types for which we have evidence of the (static part of the) shape -- in a type class constraint. Compare the instance contexts of the instances @@ -393,11 +405,18 @@ class Elt a => KnownElt a where -- this vector and an example for the contents. mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + -- | Create initialised vectors for this array type, given the shape of + -- this vector and the chosen element. + mvecsReplicate :: IShX sh -> a -> ST s (MixedVecs s sh a) + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where + -- Somehow, INLINE here can increase allocation with GHC 9.14.1. + -- Maybe that happens in void instances such as @Primitive ()@. + {-# INLINEABLE mshape #-} mshape (M_Primitive sh _) = sh {-# INLINEABLE mindex #-} mindex (M_Primitive _ a) i = Primitive (X.index a i) @@ -405,10 +424,11 @@ instance Storable a => Elt (Primitive a) where mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuterSN sn l@(arr1 :| _) = - let sh = SKnown sn :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) + let sh = mshape arr1 + in M_Primitive (SKnown sn :$% sh) (X.fromListOuterSN sn sh ((\(M_Primitive _ a) -> a) <$> l)) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) @@ -419,6 +439,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a = M_Primitive (X.shape ssh2 result) result + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) @@ -430,6 +451,7 @@ instance Storable a => Elt (Primitive a) where , let result = f ZKX a b = M_Primitive (X.shape ssh3 result) result + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -464,18 +486,22 @@ instance Storable a => Elt (Primitive a) where mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + mvecsWriteLinear i (Primitive x) (MV_Primitive v) = VSM.write v i x - -- TODO: this use of toVector is suboptimal - mvecsWritePartial + -- TODO: this use of toVectorListT is suboptimal + mvecsWritePartialLinear :: forall sh' sh s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + Proxy sh -> Int -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartialLinear _ i (M_Primitive sh' arr@(XArray (ORS.A (ORG.A sht t)))) (MV_Primitive v) = do let arrsh = X.shape (ssxFromShX sh') arr - offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) - VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + offset = i * shxSize arrsh + f off el = do + VS.copy (VSM.slice off (VS.length el) v) el + return $! off + VS.length el + foldM_ f offset (OI.toVectorListT sht t) mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + mvecsUnsafeFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.unsafeFreeze v -- [PRIMITIVE ELEMENT TYPES LIST] deriving via Primitive Bool instance Elt Bool @@ -492,6 +518,7 @@ deriving via Primitive () instance Elt () instance Storable a => KnownElt (Primitive a) where memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsReplicate sh (Primitive a) = MV_Primitive <$> VSM.replicate (shxSize sh) a mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 -- [PRIMITIVE ELEMENT TYPES LIST] @@ -508,16 +535,22 @@ deriving via Primitive () instance KnownElt () -- Arrays of pairs are pairs of arrays. instance (Elt a, Elt b) => Elt (a, b) where + {-# INLINEABLE mshape #-} mshape (M_Tup2 a _) = mshape a + {-# INLINEABLE mindex #-} mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + {-# INLINEABLE mindexPartial #-} mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) mfromListOuterSN sn l = M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l)) (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l)) mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + {-# INLINE mlift #-} mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + {-# INLINE mlift2 #-} mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + {-# INLINE mliftL #-} mliftL ssh2 f = let unzipT2l [] = ([], []) unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) @@ -542,17 +575,19 @@ instance (Elt a, Elt b) => Elt (a, b) where mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b + mvecsWriteLinear i (x, y) (MV_Tup2 a b) = do + mvecsWriteLinear i x a + mvecsWriteLinear i y b + mvecsWritePartialLinear proxy i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartialLinear proxy i x a + mvecsWritePartialLinear proxy i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + mvecsUnsafeFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsUnsafeFreeze sh a <*> mvecsUnsafeFreeze sh b instance (KnownElt a, KnownElt b) => KnownElt (a, b) where memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsReplicate sh (x, y) = MV_Tup2 <$> mvecsReplicate sh x <*> mvecsReplicate sh y mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) -- Arrays of arrays are just arrays, but with more dimensions. @@ -560,13 +595,16 @@ instance Elt a => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. + {-# INLINEABLE mshape #-} mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + = shxTakeSh (Proxy @sh') sh (mshape arr) + {-# INLINEABLE mindex #-} mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) = mindexPartial arr + {-# INLINEABLE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i @@ -581,16 +619,17 @@ instance Elt a => Elt (Mixed sh' a) where mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) mlift ssh2 f (M_Nest sh1 arr) = let result = mlift (ssxAppend ssh2 ssh') f' arr - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -598,16 +637,17 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 - (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + sh3 = shxTakeSSX (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -616,16 +656,17 @@ instance Elt a => Elt (Mixed sh' a) where , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) = f (ssxAppend ssh' sshT) + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) - (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + sh2 = shxTakeSSX (Proxy @sh') ssh2 (mshape (NE.head result)) in fmap (M_Nest sh2) result where - ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + ssh' = ssxFromShX (shxDropSh @sh1 @sh' sh1 (mshape arr1)) f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) f' sshT @@ -658,12 +699,13 @@ instance Elt a => Elt (Mixed sh' a) where mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) mconcat l@(M_Nest sh1 _ :| _) = let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result + in M_Nest (shxTakeSh (Proxy @sh') sh1 (mshape result)) result mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + -- This requires that @arr@ is not empty. mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) @@ -676,17 +718,20 @@ instance Elt a => Elt (Mixed sh' a) where marrayStrides (M_Nest _ arr) = marrayStrides arr - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + mvecsWriteLinear :: forall s sh. Int -> Mixed sh' a -> MixedVecs s sh (Mixed sh' a) -> ST s () + mvecsWriteLinear idx val (MV_Nest _ vecs) = mvecsWritePartialLinear (Proxy @sh) idx val vecs - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartialLinear proxy idx (M_Nest _ arr) (MV_Nest _ vecs) | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + = mvecsWritePartialLinear proxy idx arr vecs mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + mvecsUnsafeFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsUnsafeFreeze (shxAppend sh sh') vecs instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) @@ -697,9 +742,28 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where where sh' = mshape example + mvecsReplicate sh example = do + vecs <- mvecsUnsafeNew sh example + forM_ [0 .. shxSize sh - 1] $ \idx -> mvecsWriteLinear idx example vecs + -- this is a slow case, but the alternative, mvecsUnsafeNew with manual + -- writing in a loop, leads to every case being as slow + return vecs + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWrite #-} +mvecsWrite :: Elt a => IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () +mvecsWrite sh idx val vecs = mvecsWriteLinear (ixxToLinear sh idx) val vecs + +-- | Given the shape of this array, an index and a value, write the value at +-- that index in the vectors. +{-# INLINE mvecsWritePartial #-} +mvecsWritePartial :: forall sh sh' s a. Elt a => IShX sh -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () +mvecsWritePartial sh idx val vecs = mvecsWritePartialLinear (Proxy @sh) (ixxToLinear sh idx) val vecs + -- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) @@ -746,7 +810,7 @@ mgenerate sh f = case shxEnum sh of when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ error "Data.Array.Nested mgenerate: generated values do not have equal shapes" mvecsWrite sh idx val vecs - mvecsFreeze sh vecs + mvecsUnsafeFreeze sh vecs -- | An optimized special case of 'mgenerate', where the function results -- are of a primitive type and so there's not need to check that all shapes @@ -759,19 +823,23 @@ mgeneratePrim sh f = let g i = f (ixxFromLinear sh i) in mfromVector sh $ VS.generate (shxSize sh) g +{-# INLINEABLE msumOuter1PrimP #-} msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) +{-# INLINEABLE msumOuter1Prim #-} msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive +{-# INLINEABLE msumAllPrimP #-} msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +{-# INLINEABLE msumAllPrim #-} msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a msumAllPrim arr = msumAllPrimP (toPrimitive arr) @@ -782,7 +850,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 ssh = ssxFromShX sh - snm :: SMayNat () SNat (AddMaybe n m) + snm :: SMayNat () (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () @@ -792,15 +860,19 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b f ssh' = X.append (ssxAppend ssh ssh') +{-# INLINEABLE mfromVectorP #-} mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) +{-# INLINEABLE mfromVector #-} mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) +{-# INLINEABLE mtoVectorP #-} mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a mtoVectorP (M_Primitive _ v) = X.toVector v +{-# INLINEABLE mtoVector #-} mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) @@ -856,7 +928,7 @@ mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l) mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a mfromList1Prim l = let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l + xarr = X.fromList1 l in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a @@ -865,11 +937,15 @@ mfromList1PrimN n l = Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l) Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")" -mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a +mfromList1PrimSN :: forall n a. PrimElt a => SNat n -> [a] -> Mixed '[Just n] a mfromList1PrimSN sn l = - let ssh = SKnown sn :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + let ssh = SKnown sn :$% ZSX + in fromPrimitive $ M_Primitive ssh + $ if Storable.sizeOf (undefined :: a) > 0 + then X.fromList1SN sn l + else case l of -- don't force the list if all elements are the same + a0 : _ -> X.replicateScal ssh a0 + [] -> X.fromList1SN sn l mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a mfromListPrimLinear sh l = @@ -886,7 +962,7 @@ munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) -mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr +mnest ssh arr = M_Nest (shxTakeSSX (Proxy @sh') ssh (mshape arr)) arr munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a munNest (M_Nest _ arr) = arr @@ -999,6 +1075,7 @@ mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) +{-# INLINEABLE mdot1Inner #-} mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) @@ -1014,6 +1091,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'mdot1Inner' if applicable. +{-# INLINEABLE mdot #-} mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a mdot a b = munScalar $ @@ -1032,11 +1110,13 @@ mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP +{-# INLINE mliftPrim #-} mliftPrim :: (PrimElt a, PrimElt b) => (a -> b) -> Mixed sh a -> Mixed sh b mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) +{-# INLINE mliftPrim2 #-} mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) => (a -> b -> c) -> Mixed sh a -> Mixed sh b -> Mixed sh c diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index c999853..abcf3f8 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -1,9 +1,10 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} @@ -31,13 +32,11 @@ import Control.DeepSeq (NFData(..)) import Data.Bifunctor (first) import Data.Coerce import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) +import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) -import GHC.Generics (Generic) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits @@ -45,7 +44,6 @@ import GHC.TypeLits import GHC.TypeLits.Orphans () #endif -import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Types @@ -56,129 +54,107 @@ type family Rank sh where Rank (_ : sh) = Rank sh + 1 --- * Mixed lists +-- * Mixed lists to be used IxX and shaped and ranked lists and indexes type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +type ListX :: [Maybe Nat] -> Type -> Type +data ListX sh i where + ZX :: ListX '[] i + (::%) :: forall n sh {i}. i -> ListX sh i -> ListX (n : sh) i +deriving instance Eq i => Eq (ListX sh i) +deriving instance Ord i => Ord (ListX sh i) infixr 3 ::% #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListX sh f) +deriving instance Show i => Show (ListX sh i) #else -instance (forall n. Show (f n)) => Show (ListX sh f) where +instance Show i => Show (ListX sh i) where showsPrec _ = listxShow shows #endif -instance (forall n. NFData (f n)) => NFData (ListX sh f) where +instance NFData i => NFData (ListX sh i) where rnf ZX = () rnf (x ::% l) = rnf x `seq` rnf l -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +data UnconsListXRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh i) i listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) listxUncons ZX = Nothing --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqType ZX ZX = Just Refl -listxEqType (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , Just Refl <- listxEqType sh sh' - = Just Refl -listxEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqual ZX ZX = Just Refl -listxEqual (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listxEqual sh sh' - = Just Refl -listxEqual _ _ = Nothing - -{-# INLINE listxFmap #-} -listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -listxFmap _ ZX = ZX -listxFmap f (x ::% xs) = f x ::% listxFmap f xs +instance Functor (ListX l) where + {-# INLINE fmap #-} + fmap _ ZX = ZX + fmap f (x ::% xs) = f x ::% fmap f xs -{-# INLINE listxFoldMap #-} -listxFoldMap :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFoldMap _ ZX = mempty -listxFoldMap f (x ::% xs) = f x <> listxFoldMap f xs +instance Foldable (ListX l) where + {-# INLINE foldMap #-} + foldMap _ ZX = mempty + foldMap f (x ::% xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZX = z + foldr f z (x ::% xs) = f x (foldr f z xs) + toList = listxToList + null ZX = False + null _ = True -listxLength :: ListX sh f -> Int -listxLength = getSum . listxFoldMap (\_ -> Sum 1) +listxLength :: ListX sh i -> Int +listxLength = length -listxRank :: ListX sh f -> SNat (Rank sh) +listxRank :: ListX sh i -> SNat (Rank sh) listxRank ZX = SNat listxRank (_ ::% l) | SNat <- listxRank l = SNat {-# INLINE listxShow #-} -listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS listxShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListX sh' f -> ShowS + go :: String -> ListX sh' i -> ShowS go _ ZX = id go prefix (x ::% xs) = showString prefix . f x . go "," xs -listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) +listxFromList :: StaticShX sh -> [i] -> ListX sh i listxFromList topssh topl = go topssh topl where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go :: StaticShX sh' -> [i] -> ListX sh' i go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is + go (_ :!% sh) (i : is) = i ::% go sh is go _ _ = error $ "listxFromList: Mismatched list length (type says " ++ show (ssxLength topssh) ++ ", list has length " ++ show (length topl) ++ ")" {-# INLINEABLE listxToList #-} -listxToList :: ListX sh (Const i) -> [i] +listxToList :: ListX sh i -> [i] listxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListX sh (Const i) -> is + let go :: ListX sh i -> is go ZX = nil - go (Const i ::% is) = i `cons` go is + go (i ::% is) = i `cons` go is in go list) -listxHead :: ListX (mn ': sh) f -> f mn +listxHead :: ListX (mn ': sh) i -> i listxHead (i ::% _) = i listxTail :: ListX (n : sh) i -> ListX sh i listxTail (_ ::% sh) = sh -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend :: ListX sh i -> ListX sh' i -> ListX (sh ++ sh') i listxAppend ZX idx' = idx' listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' -listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop :: forall i j sh sh'. ListX sh j -> ListX (sh ++ sh') i -> ListX sh' i listxDrop ZX long = long listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' -listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f +listxInit :: forall i n sh. ListX (n : sh) i -> ListX (Init (n : sh)) i listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh listxInit (_ ::% ZX) = ZX -listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) +listxLast :: forall i n sh. ListX (n : sh) i -> i listxLast (_ ::% sh@(_ ::% _)) = listxLast sh listxLast (x ::% ZX) = x -listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) -listxZip ZX ZX = ZX -listxZip (i ::% irest) (j ::% jrest) = Pair i j ::% listxZip irest jrest - {-# INLINE listxZipWith #-} -listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g - -> ListX sh h +listxZipWith :: (i -> j -> k) -> ListX sh i -> ListX sh j -> ListX sh k listxZipWith _ ZX ZX = ZX listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js @@ -188,8 +164,8 @@ listxZipWith f (i ::% is) (j ::% js) = f i j ::% listxZipWith f is js -- | An index into a mixed-typed array. type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) +newtype IxX sh i = IxX (ListX sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i pattern ZIX = IxX ZX @@ -198,8 +174,8 @@ pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) i)) + where i :.% IxX shl = IxX (i ::% shl) infixr 3 :.% {-# COMPLETE ZIX, (:.%) #-} @@ -212,25 +188,9 @@ type IIxX sh = IxX sh Int deriving instance Show i => Show (IxX sh i) #else instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = listxShow (shows . getConst) l + showsPrec _ (IxX l) = listxShow shows l #endif -instance Functor (IxX sh) where - {-# INLINE fmap #-} - fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) - -instance Foldable (IxX sh) where - {-# INLINE foldMap #-} - foldMap f (IxX l) = listxFoldMap (f . getConst) l - {-# INLINE foldr #-} - foldr _ z ZIX = z - foldr f z (x :.% xs) = f x (foldr f z xs) - toList = ixxToList - null ZIX = False - null _ = True - -instance NFData i => NFData (IxX sh i) - ixxLength :: IxX sh i -> Int ixxLength (IxX l) = listxLength l @@ -245,30 +205,30 @@ ixxZero' :: IShX sh -> IIxX sh ixxZero' ZSX = ZIX ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh +{-# INLINEABLE ixxFromList #-} ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i ixxFromList = coerce (listxFromList @_ @i) -{-# INLINEABLE ixxToList #-} -ixxToList :: forall sh i. IxX sh i -> [i] -ixxToList = coerce (listxToList @_ @i) +ixxToList :: IxX sh i -> [i] +ixxToList = Foldable.toList ixxHead :: IxX (n : sh) i -> i -ixxHead (IxX list) = getConst (listxHead list) +ixxHead (IxX list) = listxHead list ixxTail :: IxX (n : sh) i -> IxX sh i ixxTail (IxX list) = IxX (listxTail list) ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixxAppend = coerce (listxAppend @_ @(Const i)) +ixxAppend = coerce (listxAppend @_ @i) ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i -ixxDrop = coerce (listxDrop @(Const i) @(Const i)) +ixxDrop = coerce (listxDrop @i @i) ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i -ixxInit = coerce (listxInit @(Const i)) +ixxInit = coerce (listxInit @i) ixxLast :: forall n sh i. IxX (n : sh) i -> i -ixxLast = coerce (listxLast @(Const i)) +ixxLast = coerce (listxLast @i) ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i ixxCast ZKX ZIX = ZIX @@ -284,43 +244,96 @@ ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixxToLinear #-} +ixxToLinear :: Num i => IShX sh -> IxX sh i -> i +ixxToLinear = \sh i -> go sh i 0 where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) + go :: Num i => IShX sh -> IxX sh i -> i -> i + go ZSX ZIX !a = a + go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) +{-# INLINEABLE ixxFromLinear #-} +ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i +ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times + let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + in fromLin0 sh suffixes + where + -- Unfold first iteration of fromLin to do the range check. + -- Don't inline this function at first to allow GHC to inline the outer + -- function and realise that 'suffixes' is shared. But then later inline it + -- anyway, to avoid the function call. Removing the pragma makes GHC + -- somehow unable to recognise that 'suffixes' can be shared in a loop. + {-# NOINLINE [0] fromLin0 #-} + fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin0 sh suffixes i = + if i < 0 then outrange sh i else + case (sh, suffixes) of + (ZSX, _) | i > 0 -> outrange sh i + | otherwise -> ZIX + ((fromSMayNat' -> n) :$% sh', suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= n then outrange sh i else + fromIntegral q :.% fromLin sh' suffs r + _ -> error "impossible" --- * Mixed shapes + fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin ZSX _ !_ = ZIX + fromLin (_ :$% sh') (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in fromIntegral q :.% fromLin sh' suffs r + fromLin _ _ _ = error "impossible" -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + {-# NOINLINE outrange #-} + outrange :: IShX sh -> Int -> a + outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = shxEnum' + +{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site +shxEnum' :: Num i => IShX sh -> [IxX sh i] +shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]] + where + suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + + fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i + fromLin ZSX _ _ = ZIX + fromLin (_ :$% sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' + in fromIntegral (I# q#) :.% fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" + + +-- * Mixed shape-like lists to be used for ShX and StaticShX + +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) + +instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where rnf (SUnknown i) = rnf i rnf (SKnown x) = rnf x -instance TestEquality f => TestEquality (SMayNat i f) where +instance TestEquality (SMayNat i) where testEquality SUnknown{} SUnknown{} = Just Refl testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing {-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r fromSMayNat f _ (SUnknown i) = f i fromSMayNat _ g (SKnown s) = g s -fromSMayNat' :: SMayNat Int SNat n -> Int +{-# INLINE fromSMayNat' #-} +fromSMayNat' :: SMayNat Int n -> Int fromSMayNat' = fromSMayNat id fromSNat' type family AddMaybe n m where @@ -328,27 +341,162 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) --- | This is a newtype over 'ListX'. +type role ListH nominal representational +type ListH :: [Maybe Nat] -> Type -> Type +data ListH sh i where + ZH :: ListH '[] i + ConsUnknown :: forall sh i. i -> ListH sh i -> ListH (Nothing : sh) i +-- TODO: bring this UNPACK back when GHC no longer crashes: +-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ListH sh i -> ListH (Just n : sh) i + ConsKnown :: forall n sh i. SNat n -> ListH sh i -> ListH (Just n : sh) i +deriving instance Ord i => Ord (ListH sh i) + +-- A manually defined instance and this INLINEABLE is needed to specialize +-- mdot1Inner (otherwise GHC warns specialization breaks down here). +instance Eq i => Eq (ListH sh i) where + {-# INLINEABLE (==) #-} + ZH == ZH = True + ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2 + ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2 + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListH sh i) +#else +instance Show i => Show (ListH sh i) where + showsPrec _ = listhShow shows +#endif + +instance NFData i => NFData (ListH sh i) where + rnf ZH = () + rnf (x `ConsUnknown` l) = rnf x `seq` rnf l + rnf (SNat `ConsKnown` l) = rnf l + +data UnconsListHRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n) +listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1) +listhUncons (i `ConsUnknown` shl') = Just (UnconsListHRes shl' (SUnknown i)) +listhUncons (i `ConsKnown` shl') = Just (UnconsListHRes shl' (SKnown i)) +listhUncons ZH = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') +listhEqType ZH ZH = Just Refl +listhEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh') + | Just Refl <- listhEqType sh sh' + = Just Refl +listhEqType (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m + , Just Refl <- listhEqType sh sh' + = Just Refl +listhEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') +listhEqual ZH ZH = Just Refl +listhEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh') + | n == m + , Just Refl <- listhEqual sh sh' + = Just Refl +listhEqual (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m + , Just Refl <- listhEqual sh sh' + = Just Refl +listhEqual _ _ = Nothing + +{-# INLINE listhFmap #-} +listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j +listhFmap _ ZH = ZH +listhFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of + SUnknown y -> y `ConsUnknown` listhFmap f xs +listhFmap f (x `ConsKnown` xs) = case f (SKnown x) of + SKnown y -> y `ConsKnown` listhFmap f xs + +{-# INLINE listhFoldMap #-} +listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m +listhFoldMap _ ZH = mempty +listhFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> listhFoldMap f xs +listhFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> listhFoldMap f xs + +listhLength :: ListH sh i -> Int +listhLength = getSum . listhFoldMap (\_ -> Sum 1) + +listhRank :: ListH sh i -> SNat (Rank sh) +listhRank ZH = SNat +listhRank (_ `ConsUnknown` l) | SNat <- listhRank l = SNat +listhRank (_ `ConsKnown` l) | SNat <- listhRank l = SNat + +{-# INLINE listhShow #-} +listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS +listhShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListH sh' i -> ShowS + go _ ZH = id + go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs + go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs + +listhHead :: ListH (mn ': sh) i -> SMayNat i mn +listhHead (i `ConsUnknown` _) = SUnknown i +listhHead (i `ConsKnown` _) = SKnown i + +listhTail :: ListH (n : sh) i -> ListH sh i +listhTail (_ `ConsUnknown` sh) = sh +listhTail (_ `ConsKnown` sh) = sh + +listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i +listhAppend ZH idx' = idx' +listhAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` listhAppend idx idx' +listhAppend (i `ConsKnown` idx) idx' = i `ConsKnown` listhAppend idx idx' + +listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i +listhDrop ZH long = long +listhDrop (_ `ConsUnknown` short) long = case long of + _ `ConsUnknown` long' -> listhDrop short long' +listhDrop (_ `ConsKnown` short) long = case long of + _ `ConsKnown` long' -> listhDrop short long' + +listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i +listhInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` listhInit sh +listhInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` listhInit sh +listhInit (_ `ConsUnknown` ZH) = ZH +listhInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` listhInit sh +listhInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` listhInit sh +listhInit (_ `ConsKnown` ZH) = ZH + +listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh)) +listhLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = listhLast sh +listhLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = listhLast sh +listhLast (x `ConsUnknown` ZH) = SUnknown x +listhLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = listhLast sh +listhLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = listhLast sh +listhLast (x `ConsKnown` ZH) = SKnown x + +-- * Mixed shapes + +-- | This is a newtype over 'ListH'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) +newtype ShX sh i = ShX (ListH sh i) + deriving (Eq, Ord, NFData) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX +pattern ZSX = ShX ZH pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) + => SMayNat i n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes (ShX -> shl) i)) + where i :$% ShX shl = case i of; SUnknown x -> ShX (x `ConsUnknown` shl); SKnown x -> ShX (x `ConsKnown` shl) infixr 3 :$% {-# COMPLETE ZSX, (:$%) #-} @@ -359,17 +507,12 @@ type IShX sh = ShX sh Int deriving instance Show i => Show (ShX sh i) #else instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (ShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l #endif instance Functor (ShX sh) where {-# INLINE fmap #-} - fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) - -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + fmap f (ShX l) = ShX (listhFmap (fromSMayNat (SUnknown . f) SKnown) l) -- | This checks only whether the types are equal; unknown dimensions might -- still differ. This corresponds to 'testEquality', except on the penultimate @@ -401,38 +544,40 @@ shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') shxEqual _ _ = Nothing shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l +shxLength (ShX l) = listhLength l shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listxRank l +shxRank (ShX l) = listhRank l -- | The number of elements in an array described by this shape. shxSize :: IShX sh -> Int shxSize ZSX = 1 shxSize (n :$% sh) = fromSMayNat' n * shxSize sh +-- We don't report the size of the list in case of errors in order not to retain the list. +{-# INLINEABLE shxFromList #-} shxFromList :: StaticShX sh -> [Int] -> IShX sh -shxFromList topssh topl = go topssh topl +shxFromList (StaticShX topssh) topl = ShX $ go topssh topl where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKX [] = ZSX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn :$% go sh is - | otherwise = error $ "shxFromList: Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is - go _ _ = error $ "shxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" + go :: ListH sh' () -> [Int] -> ListH sh' Int + go ZH [] = ZH + go ZH _ = error $ "shxFromList: List too long (type says " ++ show (listhLength topssh) ++ ")" + go (ConsKnown sn sh) (i : is) + | i == fromSNat' sn = ConsKnown sn (go sh is) + | otherwise = error $ "shxFromList: Value does not match typing" + go (ConsUnknown () sh) (i : is) = ConsUnknown i (go sh is) + go _ _ = error $ "shxFromList: List too short (type says " ++ show (listhLength topssh) ++ ")" {-# INLINEABLE shxToList #-} shxToList :: IShX sh -> [Int] -shxToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: IShX sh -> is - go ZSX = nil - go (smn :$% sh) = fromSMayNat' smn `cons` go sh - in go list) +shxToList (ShX l) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListH sh Int -> is + go ZH = nil + go (ConsUnknown i rest) = i `cons` go rest + go (ConsKnown sn rest) = fromSNat' sn `cons` go rest + in go l) +-- If it ever matters for performance, this is unsafeCoercible. shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i shxFromSSX ZKX = ZSX shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) @@ -447,35 +592,40 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) +shxAppend = coerce (listhAppend @_ @i) -shxHead :: ShX (n : sh) i -> SMayNat i SNat n -shxHead (ShX list) = listxHead list +shxHead :: ShX (n : sh) i -> SMayNat i n +shxHead (ShX list) = listhHead list shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listxTail list) +shxTail (ShX list) = ShX (listhTail list) + +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh + +shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSh _ ZSX _ = ZSX +shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) +shxDropSSX = coerce (listhDrop @i @()) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) +shxDropIx ZIX long = long +shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long' shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shxDropSh = coerce (listhDrop @i @i) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) +shxInit = coerce (listhInit @i) -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) - -shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i -shxTakeSSX _ ZKX _ = ZSX -shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast = coerce (listhLast @i) {-# INLINE shxZipWith #-} -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) +shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js @@ -490,22 +640,6 @@ shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX shxSplitApp _ ZKX idx = (ZSX, idx) shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) -shxEnum :: IShX sh -> [IIxX sh] -shxEnum = shxEnum' - -{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site -shxEnum' :: Num i => IShX sh -> [IxX sh i] -shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]] - where - suffixes = drop 1 (scanr (*) 1 (shxToList sh)) - - fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i - fromLin ZSX _ _ = ZIX - fromLin (_ :$% sh') (I# suff# : suffs) i# = - let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' - in fromIntegral (I# q#) :.% fromLin sh' suffs r# - fromLin _ _ _ = error "impossible" - shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') shxCast ZKX ZSX = Just ZSX shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh @@ -523,20 +657,24 @@ shxCast' ssh sh = case shxCast ssh sh of -- * Static mixed shapes --- | The part of a shape that is statically known. (A newtype over 'ListX'.) +-- | The part of a shape that is statically known. (A newtype over 'ListH'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) +newtype StaticShX sh = StaticShX (ListH sh ()) + deriving (NFData) + +instance Eq (StaticShX sh) where _ == _ = True +instance Ord (StaticShX sh) where compare _ _ = EQ pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX +pattern ZKX = StaticShX ZH pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) + => SMayNat () n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl) + infixr 3 :!% {-# COMPLETE ZKX, (:!%) #-} @@ -545,51 +683,50 @@ infixr 3 :!% deriving instance Show (StaticShX sh) #else instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (StaticShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l #endif -instance NFData (StaticShX sh) where - rnf (StaticShX ZX) = () - rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) - rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) - instance TestEquality StaticShX where - testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2 ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listxLength l +ssxLength (StaticShX l) = listhLength l ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listxRank l +ssxRank (StaticShX l) = listhRank l -- | @ssxEqType = 'testEquality'@. Provided for consistency. ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') ssxEqType = testEquality ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' +ssxAppend = coerce (listhAppend @_ @()) -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n -ssxHead (StaticShX list) = listxHead list +ssxHead :: StaticShX (n : sh) -> SMayNat () n +ssxHead (StaticShX list) = listhHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh +ssxTail (StaticShX list) = StaticShX (listhTail list) -ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) +ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh +ssxTakeIx _ (IxX ZX) _ = ZKX +ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short' ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropIx (IxX ZX) long = long +ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long' ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) +ssxDropSh = coerce (listhDrop @() @i) + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listhDrop @() @()) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) +ssxInit = coerce (listhInit @()) -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) +ssxLast = coerce (listhLast @()) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX @@ -599,7 +736,7 @@ ssxReplicate (SS (n :: SNat n')) ssxIotaFrom :: StaticShX sh -> Int -> [Int] ssxIotaFrom ZKX _ = [] -ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1) +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1) ssxFromShX :: ShX sh i -> StaticShX sh ssxFromShX ZSX = ZKX @@ -632,18 +769,18 @@ type family Flatten' acc sh where Flatten' acc (Just n : sh) = Flatten' (acc * n) sh -- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh) ssxFlatten = go (SNat @1) where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh) go acc ZKX = SKnown acc go _ (SUnknown () :!% _) = SUnknown () go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten :: IShX sh -> SMayNat Int (Flatten sh) shxFlatten = go (SNat @1) where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh) go acc ZSX = SKnown acc go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh @@ -655,8 +792,8 @@ shxFlatten = go (SNat @1) -- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i +instance KnownShX sh => IsList (ListX sh i) where + type Item (ListX sh i) = i fromList = listxFromList (knownShX @sh) toList = listxToList @@ -667,12 +804,7 @@ instance KnownShX sh => IsList (IxX sh i) where toList = Foldable.toList -- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int +instance KnownShX sh => IsList (IShX sh) where + type Item (IShX sh) = Int fromList = shxFromList (knownShX @sh) toList = shxToList - --- This needs to be at the bottom of the file to not split the file into --- pieces; some of the shape/index stuff refers to StaticShX. -$(ixFromLinearStub "ixxFromLinear" [t| IShX |] [t| IxX |] [p| ZSX |] (\a b -> [p| (fromSMayNat' -> $a) :$% $b |]) [| ZIX |] [| (:.%) |] [| shxToList |]) -{-# INLINEABLE ixxFromLinear #-} diff --git a/src/Data/Array/Nested/Mixed/Shape/Internal.hs b/src/Data/Array/Nested/Mixed/Shape/Internal.hs deleted file mode 100644 index 2a86ac1..0000000 --- a/src/Data/Array/Nested/Mixed/Shape/Internal.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -module Data.Array.Nested.Mixed.Shape.Internal where - -import Language.Haskell.TH - - --- | A TH stub function to avoid having to write the same code three times for --- the three kinds of shapes. -ixFromLinearStub :: String - -> TypeQ -> TypeQ - -> PatQ -> (PatQ -> PatQ -> PatQ) - -> ExpQ -> ExpQ - -> ExpQ - -> DecsQ -ixFromLinearStub fname' ishty ixty zshC consshC ixz ixcons shtolist = do - let fname = mkName fname' - typesig <- [t| forall i sh. Num i => $ishty sh -> Int -> $ixty sh i |] - - locals <- [d| - -- Unfold first iteration of fromLin to do the range check. - -- Don't inline this function at first to allow GHC to inline the outer - -- function and realise that 'suffixes' is shared. But then later inline it - -- anyway, to avoid the function call. Removing the pragma makes GHC - -- somehow unable to recognise that 'suffixes' can be shared in a loop. - {-# NOINLINE [0] fromLin0 #-} - fromLin0 :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i - fromLin0 sh suffixes i = - if i < 0 then outrange sh i else - case (sh, suffixes) of - ($zshC, _) | i > 0 -> outrange sh i - | otherwise -> $ixz - ($(consshC (varP (mkName "n")) (varP (mkName "sh'"))), suff : suffs) -> - let (q, r) = i `quotRem` suff - in if q >= n then outrange sh i else - $ixcons (fromIntegral q) (fromLin sh' suffs r) - _ -> error "impossible" - - fromLin :: Num i => $ishty sh -> [Int] -> Int -> $ixty sh i - fromLin $zshC _ !_ = $ixz - fromLin ($(consshC wildP (varP (mkName "sh'")))) (suff : suffs) i = - let (q, r) = i `quotRem` suff -- suff == shrSize sh' - in $ixcons (fromIntegral q) (fromLin sh' suffs r) - fromLin _ _ _ = error "impossible" - - {-# NOINLINE outrange #-} - outrange :: $ishty sh -> Int -> a - outrange sh i = error $ fname' ++ ": out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" |] - - body <- [| - \sh -> -- give this function arity 1 so that 'suffixes' is shared when - -- it's called many times - let suffixes = drop 1 (scanr (*) 1 ($shtolist sh)) - in fromLin0 sh suffixes |] - - return [SigD fname typesig - ,FunD fname [Clause [] (NormalB body) locals]] diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 065c9fd..ecdb06d 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -18,7 +18,6 @@ module Data.Array.Nested.Permutation where import Data.Coerce (coerce) -import Data.Functor.Const import Data.List (sort) import Data.Maybe (fromMaybe) import Data.Proxy @@ -172,52 +171,95 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs -listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f -listxTakeLen PNil _ = ZX -listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh -listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" - -listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f -listxDropLen PNil sh = sh -listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh -listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" +listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i +listhTakeLen PNil _ = ZH +listhTakeLen (_ `PCons` is) (n `ConsUnknown` sh) = n `ConsUnknown` listhTakeLen is sh +listhTakeLen (_ `PCons` is) (n `ConsKnown` sh) = n `ConsKnown` listhTakeLen is sh +listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape" -listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f -listxPermute PNil _ = ZX -listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = - listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh +listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i +listhDropLen PNil sh = sh +listhDropLen (_ `PCons` is) (_ `ConsUnknown` sh) = listhDropLen is sh +listhDropLen (_ `PCons` is) (_ `ConsKnown` sh) = listhDropLen is sh +listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape" -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) -listxIndex _ _ SZ (n ::% _) = n -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listxIndex p pT i sh -listxIndex _ _ _ ZX = error "Index into empty shape" +listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i +listhPermute PNil _ = ZH +listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) = + case listhIndex i sh of + SUnknown x -> x `ConsUnknown` listhPermute is sh + SKnown x -> x `ConsKnown` listhPermute is sh -listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f -listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) +listhIndex :: forall i k sh. SNat k -> ListH sh i -> SMayNat i (Index k sh) +listhIndex SZ (n `ConsUnknown` _) = SUnknown n +listhIndex SZ (n `ConsKnown` _) = SKnown n +listhIndex (SS (i :: SNat k')) ((_ :: i) `ConsUnknown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @Nothing) (Proxy @sh') + = listhIndex i sh +listhIndex (SS (i :: SNat k')) ((_ :: SNat n) `ConsKnown` (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @(Just n)) (Proxy @sh') + = listhIndex i sh +listhIndex _ ZH = error "Index into empty shape" -ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i -ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) +listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i +listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh) ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) +ssxTakeLen = coerce (listhTakeLen @()) ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) +ssxDropLen = coerce (listhDropLen @()) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) +ssxPermute = coerce (listhPermute @()) -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i) +ssxIndex :: SNat k -> StaticShX sh -> SMayNat () (Index k sh) +ssxIndex k = coerce (listhIndex @() k) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) +ssxPermutePrefix = coerce (listhPermutePrefix @()) + +shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) +shxTakeLen = coerce (listhTakeLen @Int) + +shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh) +shxDropLen = coerce (listhDropLen @Int) + +shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) +shxPermute = coerce (listhPermute @Int) + +shxIndex :: forall k sh i. SNat k -> ShX sh i -> SMayNat i (Index k sh) +shxIndex k = coerce (listhIndex @i k) shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) +shxPermutePrefix = coerce (listhPermutePrefix @Int) + +listxTakeLen :: forall i is sh. Perm is -> ListX sh i -> ListX (TakeLen is sh) i +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxDropLen :: forall i is sh. Perm is -> ListX sh i -> ListX (DropLen is sh) i +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxPermute :: forall i is sh. Perm is -> ListX sh i -> ListX (Permute is sh) i +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex i sh ::% listxPermute is sh + +listxIndex :: forall j i sh. SNat i -> ListX sh j -> j +listxIndex SZ (n ::% _) = n +listxIndex (SS i) (_ ::% sh) = listxIndex i sh +listxIndex _ ZX = error "Index into empty shape" + +listxPermutePrefix :: forall i is sh. Perm is -> ListX sh i -> ListX (PermutePrefix is sh) i +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @i) -- * Operations on permutations diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index d687983..b448685 100644 --- a/src/Data/Array/Nested/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -79,6 +79,7 @@ rgeneratePrim sh f = in rfromVector sh $ VS.generate (shrSize sh) g -- | See the documentation of 'mlift'. +{-# INLINE rlift #-} rlift :: forall n1 n2 a. Elt a => SNat n2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) @@ -86,12 +87,14 @@ rlift :: forall n1 n2 a. Elt a rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) -- | See the documentation of 'mlift2'. +{-# INLINE rlift2 #-} rlift2 :: forall n1 n2 n3 a. Elt a => SNat n3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) +{-# INLINE rsumOuter1PrimP #-} rsumOuter1PrimP :: forall n a. (Storable a, NumElt a) => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) @@ -99,13 +102,16 @@ rsumOuter1PrimP (Ranked arr) | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) = Ranked (msumOuter1PrimP arr) +{-# INLINEABLE rsumOuter1Prim #-} rsumOuter1Prim :: forall n a. (NumElt a, PrimElt a) => Ranked (n + 1) a -> Ranked n a rsumOuter1Prim = rfromPrimitive . rsumOuter1PrimP . rtoPrimitive +{-# INLINE rsumAllPrimP #-} rsumAllPrimP :: (Storable a, NumElt a) => Ranked n (Primitive a) -> a rsumAllPrimP (Ranked arr) = msumAllPrimP arr +{-# INLINE rsumAllPrim #-} rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a rsumAllPrim (Ranked arr) = msumAllPrim arr @@ -137,17 +143,21 @@ rappend arr1 arr2 rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) +{-# INLINEABLE rfromVectorP #-} rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v | Dict <- lemKnownReplicate (shrRank sh) = Ranked (mfromVectorP (shxFromShR sh) v) +{-# INLINEABLE rfromVector #-} rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = rfromPrimitive (rfromVectorP sh v) +{-# INLINEABLE rtoVectorP #-} rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a rtoVectorP = coerce mtoVectorP +{-# INLINEABLE rtoVector #-} rtoVector :: PrimElt a => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector @@ -220,7 +230,7 @@ rfromOrthotope sn arr rtoOrthotope :: forall a n. PrimElt a => Ranked n a -> S.Array n a rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shrFromShX2 @n sh) + | Refl <- lemRankReplicate (shrRank $ shrFromShX @n sh) = arr runScalar :: Elt a => Ranked 0 a -> a @@ -333,6 +343,7 @@ rmaxIndexPrim rarr@(Ranked arr) | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) = ixrFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE rdot1Inner #-} rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a rdot1Inner arr1 arr2 | SNat <- rrank arr1 @@ -341,14 +352,15 @@ rdot1Inner arr1 arr2 -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'rdot1Inner' if applicable. +{-# INLINE rdot #-} rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a rdot = coerce mdot rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr) +rtoXArrayPrimP (Ranked arr) = first shrFromShX (mtoXArrayPrimP arr) rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr) +rtoXArrayPrim (Ranked arr) = first shrFromShX (mtoXArrayPrim arr) rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs index 11a8ffb..beedbcf 100644 --- a/src/Data/Array/Nested/Ranked/Base.hs +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -26,16 +26,11 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy -import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -#ifndef OXAR_DEFAULT_SHOW_INSTANCES -import Data.Foldable (toList) -#endif - import Data.Array.Nested.Lemmas import Data.Array.Nested.Mixed import Data.Array.Nested.Mixed.Shape @@ -65,7 +60,7 @@ deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) #ifndef OXAR_DEFAULT_SHOW_INSTANCES instance (Show a, Elt a) => Show (Ranked n a) where showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) + let sh = show (shrToList (rshape arr)) in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr #endif @@ -87,9 +82,12 @@ newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. instance Elt a => Elt (Ranked n a) where + {-# INLINE mshape #-} mshape (M_Ranked arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Ranked arr) i = Ranked (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) mindexPartial (M_Ranked arr) i = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ @@ -104,6 +102,7 @@ instance Elt a => Elt (Ranked n a) where mtoListOuter (M_Ranked arr) = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -112,6 +111,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -120,6 +120,7 @@ instance Elt a => Elt (Ranked n a) where coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -139,7 +140,7 @@ instance Elt a => Elt (Ranked n a) where type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr) + mshapeTree (Ranked arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -149,18 +150,19 @@ instance Elt a => Elt (Ranked n a) where marrayStrides (M_Ranked arr) = marrayStrides arr - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh s. Int -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWriteLinear idx (Ranked arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh sh' s. + Proxy sh -> Int -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh' (Ranked n a)) @(Mixed sh' (Mixed (Replicate n Nothing) a)) arr) @@ -176,6 +178,14 @@ instance Elt a => Elt (Ranked n a) where (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) vecs) + mvecsUnsafeFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) @@ -188,6 +198,10 @@ instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownReplicate (SNat @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) @@ -249,20 +263,9 @@ ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a ratan2Array = liftRanked2 matan2Array +{-# INLINE rshape #-} rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shrFromShX2 (mshape arr) +rshape (Ranked arr) = coerce (mshape arr) rrank :: Elt a => Ranked n a -> SNat n rrank = shrRank . rshape - --- Needed already here, but re-exported in Data.Array.Nested.Convert. -shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) -shrFromShX ZSX = ZSR -shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx - --- Needed already here, but re-exported in Data.Array.Nested.Convert. --- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. -shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n -shrFromShX2 sh - | Refl <- lemRankReplicate (Proxy @n) - = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs index 6d61bd5..6d47ade 100644 --- a/src/Data/Array/Nested/Ranked/Shape.hs +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -1,8 +1,5 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -36,15 +33,16 @@ import Data.Foldable qualified as Foldable import Data.Kind (Type) import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, build) -import GHC.Generics (Generic) +import GHC.Exts (build) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits import GHC.TypeNats qualified as TN +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Lemmas -import Data.Array.Nested.Mixed.Shape.Internal +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation import Data.Array.Nested.Types @@ -183,7 +181,12 @@ listrZipWith f (i ::: irest) (j ::: jrest) = listrZipWith _ _ _ = error "listrZipWith: impossible pattern needlessly required" -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) +listrSplitAt SZ sh = (ZR, sh) +listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) +listrSplitAt SS{} ZR = error "m' + 1 <= 0" + +listrPermutePrefix :: forall i n. PermR -> ListR n i -> ListR n i listrPermutePrefix = \perm sh -> TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> case listrRank sh of { shlen@SNat -> @@ -195,11 +198,6 @@ listrPermutePrefix = \perm sh -> ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" } where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i applyPermRFull _ ZR _ = ZR applyPermRFull sm@SNat (i ::: perm) l = @@ -216,8 +214,7 @@ listrPermutePrefix = \perm sh -> type role IxR nominal representational type IxR :: Nat -> Type -> Type newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIR :: forall n i. () => n ~ 0 => IxR n i pattern ZIR = IxR ZR @@ -243,8 +240,6 @@ instance Show i => Show (IxR n i) where showsPrec _ (IxR l) = listrShow shows l #endif -instance NFData i => NFData (IxR sh i) - ixrLength :: IxR sh i -> Int ixrLength (IxR l) = listrLength l @@ -255,12 +250,12 @@ ixrZero :: SNat n -> IIxR n ixrZero SZ = ZIR ixrZero (SS n) = 0 :.: ixrZero n +{-# INLINEABLE ixrFromList #-} ixrFromList :: forall n i. SNat n -> [i] -> IxR n i ixrFromList = coerce (listrFromList @_ @i) -{-# INLINEABLE ixrToList #-} -ixrToList :: forall n i. IxR n i -> [i] -ixrToList = coerce (listrToList @_ @i) +ixrToList :: IxR n i -> [i] +ixrToList = Foldable.toList ixrHead :: IxR (n + 1) i -> i ixrHead (IxR list) = listrHead list @@ -288,27 +283,69 @@ ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix :: forall n i. PermR -> IxR n i -> IxR n i ixrPermutePrefix = coerce (listrPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixrToLinear #-} +ixrToLinear :: Num i => IShR m -> IxR m i -> i +ixrToLinear (ShR sh) ix = ixxToLinear sh (ixxFromIxR ix) + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +{-# INLINEABLE ixrFromLinear #-} +ixrFromLinear :: forall i m. Num i => IShR m -> Int -> IxR m i +ixrFromLinear (ShR sh) i + | Refl <- lemRankReplicate (Proxy @m) + = ixrFromIxX $ ixxFromLinear sh i + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +shrEnum :: IShR n -> [IIxR n] +shrEnum = shrEnum' + +{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site +shrEnum' :: forall i n. Num i => IShR n -> [IxR n i] +shrEnum' (ShR sh) + | Refl <- lemRankReplicate (Proxy @n) + = (unsafeCoerce :: [IxX (Replicate n Nothing) i] -> [IxR n i]) $ shxEnum' sh + -- TODO: switch to coerce once newtypes overhauled -- * Ranked shapes type role ShR nominal representational type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) +newtype ShR n i = ShR (ShX (Replicate n Nothing) i) + deriving (Eq, Ord, NFData, Functor) pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR +pattern ZSR <- ShR (matchZSR @n -> Just Refl) + where ZSR = ShR ZSX + +matchZSR :: forall n i. ShX (Replicate n Nothing) i -> Maybe (n :~: 0) +matchZSR ZSX | Refl <- lemReplicateEmpty (Proxy @n) Refl = Just Refl +matchZSR _ = Nothing pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) +pattern i :$: shl <- (shrUncons -> Just (UnconsShRRes shl i)) + where i :$: ShR shl | Refl <- lemReplicateSucc2 (Proxy @n1) Refl + = ShR (SUnknown i :$% shl) + +data UnconsShRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsShRRes (ShR n i) i +shrUncons :: forall n1 i. ShR n1 i -> Maybe (UnconsShRRes i n1) +shrUncons (ShR (SUnknown x :$% (sh' :: ShX sh' i))) + | Refl <- lemReplicateCons (Proxy @sh') (Proxy @n1) Refl + , Refl <- lemReplicateCons2 (Proxy @sh') (Proxy @n1) Refl + = Just (UnconsShRRes (ShR sh') x) +shrUncons (ShR _) = Nothing + infixr 3 :$: {-# COMPLETE ZSR, (:$:) #-} @@ -319,85 +356,140 @@ type IShR n = ShR n Int deriving instance Show i => Show (ShR n i) #else instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l + showsPrec d (ShR l) = showsPrec d l #endif -instance NFData i => NFData (ShR sh i) - -- | This checks only whether the ranks are equal, not whether the actual -- values are. shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' +shrEqRank ZSR ZSR = Just Refl +shrEqRank (_ :$: sh) (_ :$: sh') + | Just Refl <- shrEqRank sh sh' + = Just Refl +shrEqRank _ _ = Nothing -- | This compares the shapes for value equality. shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' +shrEqual ZSR ZSR = Just Refl +shrEqual (i :$: sh) (i' :$: sh') + | Just Refl <- shrEqual sh sh' + , i == i' + = Just Refl +shrEqual _ _ = Nothing shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l +shrLength (ShR l) = shxLength l -- | This function can also be used to conjure up a 'KnownNat' dictionary; -- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern -- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh +shrRank :: forall n i. ShR n i -> SNat n +shrRank (ShR sh) | Refl <- lemRankReplicate (Proxy @n) = shxRank sh -- | The number of elements in an array described by this shape. shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh +shrSize (ShR sh) = shxSize sh -shrFromList :: forall n i. SNat n -> [i] -> ShR n i -shrFromList = coerce (listrFromList @_ @i) +-- This is equivalent to but faster than @coerce (shxFromList (ssxReplicate snat))@. +-- We don't report the size of the list in case of errors in order not to retain the list. +{-# INLINEABLE shrFromList #-} +shrFromList :: SNat n -> [Int] -> IShR n +shrFromList snat topl = ShR $ ShX $ go snat topl + where + go :: SNat n -> [Int] -> ListH (Replicate n Nothing) Int + go SZ [] = ZH + go SZ _ = error $ "shrFromList: List too long (type says " ++ show (fromSNat' snat) ++ ")" + go (SS sn :: SNat n1) (i : is) | Refl <- lemReplicateSucc2 (Proxy @n1) Refl = ConsUnknown i (go sn is) + go _ _ = error $ "shrFromList: List too short (type says " ++ show (fromSNat' snat) ++ ")" +-- This is equivalent to but faster than @coerce shxToList@. {-# INLINEABLE shrToList #-} -shrToList :: forall n i. ShR n i -> [i] -shrToList = coerce (listrToList @_ @i) +shrToList :: IShR n -> [Int] +shrToList (ShR (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListH sh Int -> is + go ZH = nil + go (ConsUnknown i rest) = i `cons` go rest + go ConsKnown{} = error "shrToList: impossible case" + in go l) -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list +shrHead :: forall n i. ShR (n + 1) i -> i +shrHead (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxHead @Nothing @(Replicate n Nothing) sh of + SUnknown i -> i -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) +shrTail :: forall n i. ShR (n + 1) i -> ShR n i +shrTail + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = coerce (shxTail @_ @_ @i) -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) +shrInit :: forall n i. ShR (n + 1) i -> ShR n i +shrInit + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = -- TODO: change this and all other unsafeCoerceRefl to lemmas: + gcastWith (unsafeCoerceRefl + :: Init (Replicate (n + 1) (Nothing @Nat)) :~: Replicate n Nothing) $ + coerce (shxInit @_ @_ @i) -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list +shrLast :: forall n i. ShR (n + 1) i -> i +shrLast (ShR sh) + | Refl <- lemReplicateSucc @(Nothing @Nat) (Proxy @n) + = case shxLast sh of + SUnknown i -> i + SKnown{} -> error "shrLast: impossible SKnown" -- | Performs a runtime check that the lengths are identical. shrCast :: SNat n' -> ShR n i -> ShR n' i -shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) +shrCast SZ ZSR = ZSR +shrCast (SS n) (i :$: sh) = i :$: shrCast n sh +shrCast _ _ = error "shrCast: ranks don't match" shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 +shrAppend = + -- lemReplicatePlusApp requires an SNat + gcastWith (unsafeCoerceRefl + :: Replicate n (Nothing @Nat) ++ Replicate m Nothing :~: Replicate (n + m) Nothing) $ + coerce (shxAppend @_ @_ @i) {-# INLINE shrZipWith #-} shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 +shrZipWith _ ZSR ZSR = ZSR +shrZipWith f (i :$: irest) (j :$: jrest) = + f i j :$: shrZipWith f irest jrest +shrZipWith _ _ _ = + error "shrZipWith: impossible pattern needlessly required" -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) +shrSplitAt :: m <= n' => SNat m -> ShR n' i -> (ShR m i, ShR (n' - m) i) +shrSplitAt SZ sh = (ZSR, sh) +shrSplitAt (SS m) (n :$: sh) = (\(pre, post) -> (n :$: pre, post)) (shrSplitAt m sh) +shrSplitAt SS{} ZSR = error "m' + 1 <= 0" -shrEnum :: IShR sh -> [IIxR sh] -shrEnum = shrEnum' +shrIndex :: forall k sh i. SNat k -> ShR sh i -> i +shrIndex k (ShR sh) = case shxIndex @_ @_ @i k sh of + SUnknown i -> i + SKnown{} -> error "shrIndex: impossible SKnown" -{-# INLINABLE shrEnum' #-} -- ensure this can be specialised at use site -shrEnum' :: Num i => IShR sh -> [IxR sh i] -shrEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shrSize sh - 1]] +-- Copy-pasted from listrPermutePrefix, probably unavoidably. +shrPermutePrefix :: forall i n. PermR -> ShR n i -> ShR n i +shrPermutePrefix = \perm sh -> + TN.withSomeSNat (fromIntegral (length perm)) $ \permlen@SNat -> + case shrRank sh of { shlen@SNat -> + let sperm = shrFromList permlen perm in + case cmpNat permlen shlen of + LTI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = shrSplitAt permlen sh in shrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + } where - suffixes = drop 1 (scanr (*) 1 (shrToList sh)) - - fromLin :: Num i => IShR sh -> [Int] -> Int# -> IxR sh i - fromLin ZSR _ _ = ZIR - fromLin (_ :$: sh') (I# suff# : suffs) i# = - let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' - in fromIntegral (I# q#) :.: fromLin sh' suffs r# - fromLin _ _ _ = error "impossible" + applyPermRFull :: SNat m -> ShR k Int -> ShR m i -> ShR k i + applyPermRFull _ ZSR _ = ZSR + applyPermRFull sm@SNat (i :$: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> shrIndex si l :$: applyPermRFull sm perm l + EQI -> shrIndex si l :$: applyPermRFull sm perm l + GTI -> error "shrPermutePrefix: Index in permutation out of range" -- | Untyped: length is checked at runtime. @@ -413,18 +505,15 @@ instance KnownNat n => IsList (IxR n i) where toList = Foldable.toList -- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +instance KnownNat n => IsList (IShR n) where + type Item (IShR n) = Int + fromList = shrFromList (SNat @n) + toList = shrToList -- * Internal helper functions listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i listrCastWithName _ SZ ZR = ZR -listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name (SS n) (i ::: l) = i ::: listrCastWithName name n l listrCastWithName name _ _ = error $ name ++ ": ranks don't match" - -$(ixFromLinearStub "ixrFromLinear" [t| IShR |] [t| IxR |] [p| ZSR |] (\a b -> [p| $a :$: $b |]) [| ZIR |] [| (:.:) |] [| shrToList |]) -{-# INLINEABLE ixrFromLinear #-} diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 99ad590..36ef24a 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -56,7 +56,7 @@ ssize = shsSize . sshape sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh shsTakeIx _ _ ZIS = ZSS shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx @@ -70,7 +70,7 @@ sindexPartial sarr@(Shaped arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX)) -- | See 'mgeneratePrim'. {-# INLINE sgeneratePrim #-} @@ -81,6 +81,7 @@ sgeneratePrim sh f = in sfromVector sh $ VS.generate (shsSize sh) g -- | See the documentation of 'mlift'. +{-# INLINE slift #-} slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -88,23 +89,28 @@ slift :: forall sh1 sh2 a. Elt a slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. +{-# INLINE slift2 #-} slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) +{-# INLINE ssumOuter1PrimP #-} ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr) +{-# INLINEABLE ssumOuter1Prim #-} ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive +{-# INLINE ssumAllPrimP #-} ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a ssumAllPrimP (Shaped arr) = msumAllPrimP arr +{-# INLINE ssumAllPrim #-} ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a ssumAllPrim (Shaped arr) = msumAllPrim arr @@ -124,15 +130,19 @@ sappend = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) +{-# INLINEABLE sfromVectorP #-} sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) +{-# INLINEABLE sfromVector #-} sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) +{-# INLINEABLE stoVectorP #-} stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP +{-# INLINEABLE stoVector #-} stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector @@ -246,21 +256,20 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = - case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff - n@SNat -> sreshape (n :$$ ZSS) arr +sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a siota sn = Shaped (miota sn) -- | Throws if the array is empty. sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) +sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr) -- | Throws if the array is empty. smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE sdot1Inner #-} sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) @@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) _ -> error "unreachable" +{-# INLINE sdot #-} -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs index 98f1241..4b119c4 100644 --- a/src/Data/Array/Nested/Shaped/Base.hs +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -26,7 +26,6 @@ import Data.Coerce (coerce) import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty) import Data.Proxy -import Data.Type.Equality import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) @@ -80,9 +79,12 @@ deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) instance Elt a => Elt (Shaped sh a) where + {-# INLINE mshape #-} mshape (M_Shaped arr) = mshape arr + {-# INLINE mindex #-} mindex (M_Shaped arr) i = Shaped (mindex arr i) + {-# INLINE mindexPartial #-} mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) mindexPartial (M_Shaped arr) i = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ @@ -97,6 +99,7 @@ instance Elt a => Elt (Shaped sh a) where mtoListOuter (M_Shaped arr) = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + {-# INLINE mlift #-} mlift :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) @@ -105,6 +108,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ mlift ssh2 f arr + {-# INLINE mlift2 #-} mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) @@ -113,6 +117,7 @@ instance Elt a => Elt (Shaped sh a) where coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 ssh3 f arr1 arr2 + {-# INLINE mliftL #-} mliftL :: forall sh1 sh2. StaticShX sh2 -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) @@ -132,7 +137,7 @@ instance Elt a => Elt (Shaped sh a) where type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + mshapeTree (Shaped arr) = first coerce (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -142,18 +147,19 @@ instance Elt a => Elt (Shaped sh a) where marrayStrides (M_Shaped arr) = marrayStrides arr - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr + mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWriteLinear idx (Shaped arr) vecs = + mvecsWriteLinear idx arr (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx (coerce @(Mixed sh2 (Shaped sh a)) @(Mixed sh2 (Mixed (MapJust sh) a)) arr) @@ -169,6 +175,14 @@ instance Elt a => Elt (Shaped sh a) where (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) vecs) + mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) @@ -181,6 +195,10 @@ instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsUnsafeNew idx arr + mvecsReplicate idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsReplicate idx arr + mvecsNewEmpty _ | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) @@ -242,14 +260,6 @@ satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped s satan2Array = liftShaped2 matan2Array +{-# INLINE sshape #-} sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shsFromShX (mshape arr) - --- Needed already here, but re-exported in Data.Array.Nested.Convert. -shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh -shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = - castWith (subst1 (sym (lemMapJustCons Refl))) $ - n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) -shsFromShX (SUnknown _ :$% _) = error "impossible" +sshape (Shaped arr) = coerce (mshape arr) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 0d90e91..c5e3202 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,10 +1,8 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} @@ -32,173 +30,157 @@ import Control.DeepSeq (NFData(..)) import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product qualified as Fun import Data.Kind (Constraint, Type) -import Data.Monoid (Sum(..)) -import Data.Proxy import Data.Type.Equality -import GHC.Exts (Int(..), Int#, quotRemInt#, withDict, build) -import GHC.Generics (Generic) +import GHC.Exts (build, withDict) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested.Mixed.Shape -import Data.Array.Nested.Mixed.Shape.Internal import Data.Array.Nested.Permutation import Data.Array.Nested.Types -- * Shaped lists --- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be --- removed in a future release. type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where - ZS :: ListS '[] f - -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity - (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +type ListS :: [Nat] -> Type -> Type +data ListS sh i where + ZS :: ListS '[] i + (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i +deriving instance Eq i => Eq (ListS sh i) +deriving instance Ord i => Ord (ListS sh i) + infixr 3 ::$ #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListS sh f) +deriving instance Show i => Show (ListS sh i) #else -instance (forall n. Show (f n)) => Show (ListS sh f) where +instance Show i => Show (ListS sh i) where showsPrec _ = listsShow shows #endif -instance (forall m. NFData (f m)) => NFData (ListS n f) where +instance NFData i => NFData (ListS n i) where rnf ZS = () rnf (x ::$ l) = rnf x `seq` rnf l -data UnconsListSRes f sh1 = - forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +data UnconsListSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListSRes (ListS sh i) i +listsUncons :: ListS sh1 i -> Maybe (UnconsListSRes i sh1) listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) listsUncons ZS = Nothing --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqType ZS ZS = Just Refl -listsEqType (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , Just Refl <- listsEqType sh sh' - = Just Refl -listsEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') -listsEqual ZS ZS = Just Refl -listsEqual (n ::$ sh) (m ::$ sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listsEqual sh sh' - = Just Refl -listsEqual _ _ = Nothing - -{-# INLINE listsFmap #-} -listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -listsFmap _ ZS = ZS -listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs - -{-# INLINE listsFoldMap #-} -listsFoldMap :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -listsFoldMap _ ZS = mempty -listsFoldMap f (x ::$ xs) = f x <> listsFoldMap f xs - -listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow :: forall sh i. (i -> ShowS) -> ListS sh i -> ShowS listsShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListS sh' f -> ShowS + go :: String -> ListS sh' i -> ShowS go _ ZS = id go prefix (x ::$ xs) = showString prefix . f x . go "," xs -listsLength :: ListS sh f -> Int -listsLength = getSum . listsFoldMap (\_ -> Sum 1) +instance Functor (ListS l) where + {-# INLINE fmap #-} + fmap _ ZS = ZS + fmap f (x ::$ xs) = f x ::$ fmap f xs + +instance Foldable (ListS l) where + {-# INLINE foldMap #-} + foldMap _ ZS = mempty + foldMap f (x ::$ xs) = f x <> foldMap f xs + {-# INLINE foldr #-} + foldr _ z ZS = z + foldr f z (x ::$ xs) = f x (foldr f z xs) + toList = listsToList + null ZS = False + null _ = True + +listsLength :: ListS sh i -> Int +listsLength = length -listsRank :: ListS sh f -> SNat (Rank sh) +listsRank :: ListS sh i -> SNat (Rank sh) listsRank ZS = SNat listsRank (_ ::$ sh) = snatSucc (listsRank sh) -listsFromList :: ShS sh -> [i] -> ListS sh (Const i) +listsFromList :: ShS sh -> [i] -> ListS sh i listsFromList topsh topl = go topsh topl where - go :: ShS sh' -> [i] -> ListS sh' (Const i) + go :: ShS sh' -> [i] -> ListS sh' i go ZSS [] = ZS - go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go (_ :$$ sh) (i : is) = i ::$ go sh is go _ _ = error $ "listsFromList: Mismatched list length (type says " ++ show (shsLength topsh) ++ ", list has length " ++ show (length topl) ++ ")" +{-# INLINEABLE listsFromListS #-} +listsFromListS :: ListS sh i0 -> [i] -> ListS sh i +listsFromListS topl0 topl = go topl0 topl + where + go :: ListS sh i0 -> [i] -> ListS sh i + go ZS [] = ZS + go (_ ::$ l0) (i : is) = i ::$ go l0 is + go _ _ = error $ "listsFromListS: Mismatched list length (the model says " + ++ show (listsLength topl0) ++ ", list has length " + ++ show (length topl) ++ ")" + {-# INLINEABLE listsToList #-} -listsToList :: ListS sh (Const i) -> [i] +listsToList :: ListS sh i -> [i] listsToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListS sh (Const i) -> is + let go :: ListS sh i -> is go ZS = nil - go (Const i ::$ is) = i `cons` go is + go (i ::$ is) = i `cons` go is in go list) -listsHead :: ListS (n : sh) f -> f n +listsHead :: ListS (n : sh) i -> i listsHead (i ::$ _) = i -listsTail :: ListS (n : sh) f -> ListS sh f +listsTail :: ListS (n : sh) i -> ListS sh i listsTail (_ ::$ sh) = sh -listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit :: ListS (n : sh) i -> ListS (Init (n : sh)) i listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh listsInit (_ ::$ ZS) = ZS -listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast :: ListS (n : sh) i -> i listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh listsLast (n ::$ ZS) = n -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend :: ListS sh i -> ListS sh' i -> ListS (sh ++ sh') i listsAppend ZS idx' = idx' listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' -listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) +listsZip :: ListS sh i -> ListS sh j -> ListS sh (i, j) listsZip ZS ZS = ZS -listsZip (i ::$ is) (j ::$ js) = Fun.Pair i j ::$ listsZip is js +listsZip (i ::$ is) (j ::$ js) = (i, j) ::$ listsZip is js {-# INLINE listsZipWith #-} -listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g - -> ListS sh h +listsZipWith :: (i -> j -> k) -> ListS sh i -> ListS sh j -> ListS sh k listsZipWith _ ZS ZS = ZS listsZipWith f (i ::$ is) (j ::$ js) = f i j ::$ listsZipWith f is js -listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (TakeLen is sh) i listsTakeLenPerm PNil _ = ZS listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm :: forall i is sh. Perm is -> ListS sh i -> ListS (DropLen is sh) i listsDropLenPerm PNil sh = sh listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute :: forall i is sh. Perm is -> ListS sh i -> ListS (Permute is sh) i listsPermute PNil _ = ZS listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = - case listsIndex (Proxy @is') (Proxy @sh) i sh of - (item, SNat) -> item ::$ listsPermute is sh + case listsIndex i sh of + item -> item ::$ listsPermute is sh --- TODO: remove this SNat when the KnownNat constaint in ListS is removed -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) -listsIndex _ _ SZ (n ::$ _) = (n, SNat) -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = listsIndex p pT i sh -listsIndex _ _ _ ZS = error "Index into empty shape" +-- TODO: try to remove this SNat now that the KnownNat constraint in ListS is removed +listsIndex :: forall j i sh. SNat i -> ListS sh j -> j +listsIndex SZ (n ::$ _) = n +listsIndex (SS i) (_ ::$ sh) = listsIndex i sh +listsIndex _ ZS = error "Index into empty shape" -listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +listsPermutePrefix :: forall i is sh. Perm is -> ListS sh i -> ListS (PermutePrefix is sh) i listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) -- * Shaped indices @@ -206,8 +188,8 @@ listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm pe -- | An index into a shape-typed array. type role IxS nominal representational type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) - deriving (Eq, Ord, Generic) +newtype IxS sh i = IxS (ListS sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS @@ -216,10 +198,10 @@ pattern ZIS = IxS ZS -- removed in a future release. pattern (:.$) :: forall {sh1} {i}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) - where i :.$ IxS shl = IxS (Const i ::$ shl) +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) i)) + where i :.$ IxS shl = IxS (i ::$ shl) infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} @@ -232,25 +214,9 @@ type IIxS sh = IxS sh Int deriving instance Show i => Show (IxS sh i) #else instance Show i => Show (IxS sh i) where - showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l + showsPrec _ (IxS l) = listsShow (\i -> shows i) l #endif -instance Functor (IxS sh) where - {-# INLINE fmap #-} - fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) - -instance Foldable (IxS sh) where - {-# INLINE foldMap #-} - foldMap f (IxS l) = listsFoldMap (f . getConst) l - {-# INLINE foldr #-} - foldr _ z ZIS = z - foldr f z (x :.$ xs) = f x (foldr f z xs) - toList = ixsToList - null ZIS = False - null _ = True - -instance NFData i => NFData (IxS sh i) - ixsLength :: IxS sh i -> Int ixsLength (IxS l) = listsLength l @@ -260,16 +226,19 @@ ixsRank (IxS l) = listsRank l ixsFromList :: forall sh i. ShS sh -> [i] -> IxS sh i ixsFromList = coerce (listsFromList @_ @i) -{-# INLINEABLE ixsToList #-} -ixsToList :: forall sh i. IxS sh i -> [i] -ixsToList = coerce (listsToList @_ @i) +{-# INLINEABLE ixsFromIxS #-} +ixsFromIxS :: forall sh i0 i. IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS = coerce (listsFromListS @_ @i0 @i) + +ixsToList :: IxS sh i -> [i] +ixsToList = Foldable.toList ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh ixsHead :: IxS (n : sh) i -> i -ixsHead (IxS list) = getConst (listsHead list) +ixsHead (IxS list) = listsHead list ixsTail :: IxS (n : sh) i -> IxS sh i ixsTail (IxS list) = IxS (listsTail list) @@ -278,16 +247,14 @@ ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i -ixsLast (IxS list) = getConst (listsLast list) +ixsLast (IxS list) = listsLast list --- TODO: this takes a ShS because there are KnownNats inside IxS. -ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i -ixsCast ZSS ZIS = ZIS -ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx -ixsCast _ _ = error "ixsCast: ranks don't match" +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i -ixsAppend = coerce (listsAppend @_ @(Const i)) +ixsAppend = coerce (listsAppend @_ @i) ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) ixsZip ZIS ZIS = ZIS @@ -299,8 +266,31 @@ ixsZipWith _ ZIS ZIS = ZIS ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +ixsPermutePrefix = coerce (listsPermutePrefix @i) +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +{-# INLINEABLE ixsFromLinear #-} +ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i +ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i + +ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i +ixsFromIxX = unsafeCoerce -- TODO: switch to coerce once newtypes overhauled + +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' (ShS sh) = (unsafeCoerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh + -- TODO: switch to coerce once newtypes overhauled -- * Shaped shapes @@ -310,21 +300,34 @@ ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) -- can also retrieve the array shape from a 'KnownShS' dictionary. type role ShS nominal type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) - deriving (Generic) +newtype ShS sh = ShS (ShX (MapJust sh) Int) + deriving (NFData) instance Eq (ShS sh) where _ == _ = True instance Ord (ShS sh) where compare _ _ = EQ pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS +pattern ZSS <- ShS (matchZSX -> Just Refl) + where ZSS = ShS ZSX + +matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZSX _ = Nothing pattern (:$$) :: forall {sh1}. - forall n sh. (KnownNat n, n : sh ~ sh1) + forall n sh. (n : sh ~ sh1) => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) - where i :$$ ShS shl = ShS (i ::$ shl) +pattern i :$$ shl <- (shsUncons -> Just (UnconsShSRes i shl)) + where i :$$ ShS shl = ShS (SKnown i :$% shl) + +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) + | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing infixr 3 :$$ @@ -334,15 +337,13 @@ infixr 3 :$$ deriving instance Show (ShS sh) #else instance Show (ShS sh) where - showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + showsPrec d (ShS shx) = showsPrec d shx #endif -instance NFData (ShS sh) where - rnf (ShS ZS) = () - rnf (ShS (SNat ::$ l)) = rnf (ShS l) - instance TestEquality ShS where - testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl -- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are -- equal if and only if values are equal.) @@ -350,64 +351,106 @@ shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') shsEqual = testEquality shsLength :: ShS sh -> Int -shsLength (ShS l) = listsLength l +shsLength (ShS shx) = shxLength shx -shsRank :: ShS sh -> SNat (Rank sh) -shsRank (ShS l) = listsRank l +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Rank (MapJust sh) :~: Rank sh) $ + shxRank shx shsSize :: ShS sh -> Int -shsSize ZSS = 1 -shsSize (n :$$ sh) = fromSNat' n * shsSize sh +shsSize (ShS sh) = shxSize sh -- | This is a partial @const@ that fails when the second argument --- doesn't match the first. +-- doesn't match the first. We don't report the size of the list +-- in case of errors in order not to retain the list. +{-# INLINEABLE shsFromList #-} shsFromList :: ShS sh -> [Int] -> ShS sh -shsFromList topsh topl = go topsh topl `seq` topsh +shsFromList sh0@(ShS (ShX topsh)) topl = go topsh topl `seq` sh0 where - go :: ShS sh' -> [Int] -> () - go ZSS [] = () - go (sn :$$ sh) (i : is) + go :: ListH sh' Int -> [Int] -> () + go ZH [] = () + go ZH _ = error $ "shsFromList: List too long (type says " ++ show (listhLength topsh) ++ ")" + go (ConsKnown sn sh) (i : is) | i == fromSNat' sn = go sh is - | otherwise = error $ "shsFromList: Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go _ _ = error $ "shsFromList: Mismatched list length (type says " - ++ show (shsLength topsh) ++ ", list has length " - ++ show (length topl) ++ ")" + | otherwise = error $ "shsFromList: Value does not match typing" + go ConsUnknown{} _ = error "shsFromList: impossible case" + go _ _ = error $ "shsFromList: List too short (type says " ++ show (listhLength topsh) ++ ")" +-- This is equivalent to but faster than @coerce shxToList@. {-# INLINEABLE shsToList #-} shsToList :: ShS sh -> [Int] -shsToList topsh = build (\(cons :: Int -> is -> is) (nil :: is) -> - let go :: ShS sh -> is - go ZSS = nil - go (sn :$$ sh) = fromSNat' sn `cons` go sh - in go topsh) +shsToList (ShS (ShX l)) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListH sh Int -> is + go ZH = nil + go ConsUnknown{} = error "shsToList: impossible case" + go (ConsKnown snat rest) = fromSNat' snat `cons` go rest + in go l) shsHead :: ShS (n : sh) -> SNat n -shsHead (ShS list) = listsHead list +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat -shsTail :: ShS (n : sh) -> ShS sh -shsTail (ShS list) = ShS (listsTail list) +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) -shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) -shsInit (ShS list) = ShS (listsInit list) +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @_ @_ @Int) -shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) -shsLast (ShS list) = listsLast list +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') -shsAppend = coerce (listsAppend @_ @SNat) +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @_ @Int) + +shsTakeLen :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce shxTakeLen -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLenPerm @SNat) +shsDropLen :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLen = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce shxDropLen -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce shxPermute -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) -shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) +shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) +shsIndex i (ShS sh) = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex @_ @_ @Int i sh of + SKnown SNat -> SNat shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -shsPermutePrefix = coerce (listsPermutePrefix @SNat) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm sh)) (shsDropLen perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) type family Product sh where Product '[] = 1 @@ -435,37 +478,10 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict --- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. --- This function may be removed in a future release. -shsFromListS :: ListS sh f -> ShS sh -shsFromListS ZS = ZSS -shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l - --- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This --- function may be removed in a future release. -shsFromIxS :: IxS sh i -> ShS sh -shsFromIxS (IxS l) = shsFromListS l - -shsEnum :: ShS sh -> [IIxS sh] -shsEnum = shsEnum' - -{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site -shsEnum' :: Num i => ShS sh -> [IxS sh i] -shsEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shsSize sh - 1]] - where - suffixes = drop 1 (scanr (*) 1 (shsToList sh)) - - fromLin :: Num i => ShS sh -> [Int] -> Int# -> IxS sh i - fromLin ZSS _ _ = ZIS - fromLin (_ :$$ sh') (I# suff# : suffs) i# = - let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shsSize sh' - in fromIntegral (I# q#) :.$ fromLin sh' suffs r# - fromLin _ _ _ = error "impossible" - -- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where - type Item (ListS sh (Const i)) = i +instance KnownShS sh => IsList (ListS sh i) where + type Item (ListS sh i) = i fromList = listsFromList (knownShS @sh) toList = listsToList @@ -480,6 +496,3 @@ instance KnownShS sh => IsList (ShS sh) where type Item (ShS sh) = Int fromList = shsFromList (knownShS @sh) toList = shsToList - -$(ixFromLinearStub "ixsFromLinear" [t| ShS |] [t| IxS |] [p| ZSS |] (\a b -> [p| (fromSNat' -> $a) :$$ $b |]) [| ZIS |] [| (:.$) |] [| shsToList |]) -{-# INLINEABLE ixsFromLinear #-} diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs index a43ae0c..8bb5b85 100644 --- a/src/Data/Array/Nested/Types.hs +++ b/src/Data/Array/Nested/Types.hs @@ -46,7 +46,6 @@ import GHC.TypeLits import GHC.TypeNats qualified as TN import Unsafe.Coerce qualified - -- Reasoning helpers subst1 :: forall f a b. a :~: b -> f a :~: f b @@ -59,8 +58,9 @@ subst2 Refl = Refl data Dict c a where Dict :: c a => Dict c a +{-# INLINE fromSNat' #-} fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat +fromSNat' = fromEnum . TN.fromSNat sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) sameNat' n@SNat m@SNat = sameNat n m @@ -110,7 +110,7 @@ type family Replicate n a where Replicate n a = a : Replicate (n - 1) a lemReplicateSucc :: forall a n proxy. - proxy n -> (a : Replicate n a) :~: Replicate (n + 1) a + proxy n -> a : Replicate n a :~: Replicate (n + 1) a lemReplicateSucc _ = unsafeCoerceRefl type family MapJust l = r | r -> l where diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs index 5c38d14..e2cd17c 100644 --- a/src/Data/Array/Strided/Orthotope.hs +++ b/src/Data/Array/Strided/Orthotope.hs @@ -24,14 +24,19 @@ fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset ve toO :: AS.Array n a -> RS.Array n a toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) +{-# INLINE liftO1 #-} liftO1 :: (AS.Array n a -> AS.Array n' b) -> RS.Array n a -> RS.Array n' b liftO1 f = toO . f . fromO +{-# INLINE liftO2 #-} liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c liftO2 f x y = toO (f (fromO x) (fromO y)) +-- We don't inline this lifting function, because its code is not just +-- a wrapper, being relatively long and expensive. +{-# INLINEABLE liftVEltwise1 #-} liftVEltwise1 :: (Storable a, Storable b) => SNat n -> (VS.Vector a -> VS.Vector b) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 1445ce6..4f5bb08 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -17,7 +17,7 @@ module Data.Array.XArray where import Control.DeepSeq (NFData) -import Control.Monad (foldM) +import Control.Monad (foldM_, foldM) import Control.Monad.ST import Data.Array.Internal qualified as OI import Data.Array.Internal.RankedG qualified as ORG @@ -26,7 +26,7 @@ import Data.Array.RankedS qualified as S import Data.Coerce import Data.Foldable (toList) import Data.Kind -import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Type.Equality import Data.Type.Ord @@ -62,6 +62,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l go _ _ = error "Invalid shapeL" +{-# INLINEABLE fromVector #-} fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownNatRank sh @@ -87,7 +88,7 @@ cast ssh1 sh2 ssh' (XArray arr) | Refl <- lemRankApp ssh1 ssh' , Refl <- lemRankApp (ssxFromShX sh2) ssh' = let arrsh :: IShX sh1 - (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + arrsh = shxTakeSSX (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) in if shxToList arrsh == shxToList sh2 then XArray arr else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" @@ -184,7 +185,7 @@ rerank :: forall sh sh1 sh2 a b. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f xarr@(XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -211,7 +212,7 @@ rerank2 :: forall sh sh1 sh2 a b c. -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) - = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + = let sh = shxTakeSSX (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) in if 0 `elem` shxToList sh then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) else case () of @@ -274,14 +275,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' arr | Refl <- lemAppNil @sh - = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh' = shxDropSSX @sh @sh' ssh (shape (ssxAppend ssh ssh') arr) sh'F = shxFlatten sh' :$% ZSX ssh'F = ssxFromShX sh'F go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = listxRank (let StaticShX l = ssh in l) + , let sn = ssxRank ssh = XArray (liftO1 (numEltSum1Inner sn) arr') in go $ @@ -294,7 +295,7 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' arr | Refl <- lemAppNil @sh - = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + = let sh = shxTakeSSX (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) shF = shxFlatten sh :$% ZSX in sumInner ssh' (ssxFromShX shF) $ transpose2 (ssxFromShX shF) ssh' $ @@ -305,50 +306,48 @@ sumOuter ssh ssh' arr -- the list's spine must be fully materialised to compute its length before -- constructing the array. The list can't be empty (not enough information -- in the given shape to guess the shape of the empty array, in general). -fromListOuter :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromListOuter ssh l - | Dict <- lemKnownNatRankSSX (ssxTail ssh) - , let l' = coerce @[XArray sh a] @[S.Array (Rank sh) a] l - = case ssh of - _ :!% ZKX -> - fromList1 ssh (map S.unScalar l') - SKnown m :!% _ -> - let n = fromSNat' m - in XArray (ravelOuterN n l') - _ -> - let n = length l - in XArray (ravelOuterN n l') +{-# INLINE fromListOuterSN #-} +fromListOuterSN :: forall n sh a. Storable a + => SNat n -> IShX sh -> NonEmpty (XArray sh a) -> XArray (Just n : sh) a +fromListOuterSN m sh l + | Dict <- lemKnownNatRank sh + , let l' = coerce @(NonEmpty (XArray sh a)) @(NonEmpty (S.Array (Rank sh) a)) l + = case sh of + ZSX -> fromList1SN m (map S.unScalar (toList l')) + _ -> XArray (ravelOuterN (fromSNat' m) l') -- | This checks that the list has the given length and that all shapes in the -- list are equal. The list must be non-empty, and is streamed. +{-# INLINEABLE ravelOuterN #-} ravelOuterN :: (KnownNat k, Storable a) - => Int -> [S.Array k a] -> S.Array (1 + k) a + => Int -> NonEmpty (S.Array k a) -> S.Array (1 + k) a ravelOuterN 0 _ = error "ravelOuterN: N == 0" -ravelOuterN _ [] = error "ravelOuterN: empty list" -ravelOuterN k as@(a0 : _) = runST $ do +ravelOuterN k as@(a0 :| _) = runST $ do let sh0 = S.shapeL a0 len = product sh0 vecSize = k * len vec <- VSM.unsafeNew vecSize - let f !n a = + let f !n (ORS.A (ORG.A sht t)) = if | n >= k -> error $ "ravelOuterN: list too long " ++ show (n, k) -- if we do this check just once at the end, we may -- crash instead of producing an accurate error message - | S.shapeL a == sh0 -> do - VS.unsafeCopy (VSM.slice (n * len) len vec) (S.toVector a) - return $! n + 1 + | sht == sh0 -> do + let g off el = do + VS.unsafeCopy (VSM.slice off (VS.length el) vec) el + return $! off + VS.length el + foldM_ g (n * len) (OI.toVectorListT sht t) + return $! n + 1 | otherwise -> - error $ "ravelOuterN: unequal shapes " ++ show (S.shapeL a, sh0) + error $ "ravelOuterN: unequal shapes " ++ show (sht, sh0) nFinal <- foldM f 0 as if nFinal == k then S.fromVector (k : sh0) <$> VS.unsafeFreeze vec else error $ "ravelOuterN: list too short " ++ show (nFinal, k) toListOuter :: forall a n sh. Storable a => XArray (n : sh) a -> [XArray sh a] -toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = - case S.shapeL arr of +toListOuter (XArray arr@(ORS.A (ORG.A shArr t))) = + case shArr of [] -> error "impossible" 0 : _ -> [] -- using orthotope's functions here would entail using rerank, which is slow, so we don't @@ -358,15 +357,20 @@ toListOuter (XArray arr@(ORS.A (ORG.A _ t))) = -- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, -- the list's spine must be fully materialised to compute its length before -- constructing the array. -fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a -fromList1 ssh l = - case ssh of - SKnown m :!% _ -> - let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed - in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) - _ -> - let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself - in XArray (S.fromVector [n] (VS.fromListN n l)) +{-# INLINE fromList1 #-} +fromList1 :: Storable a => [a] -> XArray '[Nothing] a +fromList1 l = + let n = length l -- avoid S.fromList because it takes a length _and_ does another length check itself + in XArray (S.fromVector [n] (VS.fromListN n l)) + +-- | If @n@ is an 'SKnown' dimension, the list is streamed. If @n@ is unknown, +-- the list's spine must be fully materialised to compute its length before +-- constructing the array. +{-# INLINE fromList1SN #-} +fromList1SN :: Storable a => SNat n -> [a] -> XArray '[Just n] a +fromList1SN m l = + let n = fromSNat' m -- do length check and vector construction simultaneously so that l can be streamed + in XArray (S.fromVector [n] (VGC.fromListNChecked n l)) toList1 :: Storable a => XArray '[n] a -> [a] toList1 (XArray arr) = S.toList arr |
