aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs105
1 files changed, 76 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 6b152f7..a2787b8 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -307,15 +308,9 @@ class Elt a where
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+ -- | See 'mfromListOuter'. If the list does not have the given length, a
+ -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable.
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -407,8 +402,8 @@ instance Storable a => Elt (Primitive a) where
mindex (M_Primitive _ a) i = Primitive (X.index a i)
mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
+ mfromListOuterSN sn l@(arr1 :| _) =
+ let sh = SKnown sn :$% mshape arr1
in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
@@ -515,9 +510,9 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mfromListOuterSN sn l =
+ M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
@@ -578,10 +573,9 @@ instance Elt a => Elt (Mixed sh' a) where
mscalar = M_Nest ZSX
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+ mfromListOuterSN sn l@(arr :| _) =
+ M_Nest (SKnown sn :$% mshape arr)
+ (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l))
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
@@ -793,23 +787,76 @@ mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to
+-- stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'.
+mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuter l = mfromListOuterN (length l) l
+
+-- | See 'mfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable.
+mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuterN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l)
+ Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")"
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the
+-- list.
+--
+-- If the elements are scalars, 'mfromList1Prim' is faster.
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+mfromList1 = mfromListOuter . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a
+mfromList1N n = mfromListOuterN n . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a
+mfromList1SN sn = mfromListOuterSN sn . fmap mscalar
-- This forall is there so that a simple type application can constrain the
-- shape, in case the user wants to use OverloadedLists for the shape.
+-- | If the elements are scalars, 'mfromListPrimLinear' is faster.
mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list.
+mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
xarr = X.fromList1 ssh l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
+mfromList1PrimN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l)
+ Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")"
+
+mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
+mfromList1PrimSN sn l =
+ let ssh = SKnown sn :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
mtoList :: Elt a => Mixed '[n] a -> [a]
@@ -872,14 +919,14 @@ mreplicateScal :: forall sh a. PrimElt a
=> IShX sh -> a -> Mixed sh a
mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
+msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
+
+msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+msliceSN i n arr =
let _ :$% sh = mshape arr
in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
-
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr