aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs260
1 files changed, 183 insertions, 77 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 54f8fe6..182943d 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -7,6 +7,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -91,6 +92,9 @@ import Data.Bag
-- Unfortunately, the setup of the library requires us to list these primitive
-- element types multiple times; to aid in extending the list, all these lists
-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+--
+-- NOTE: if you add PRIMITIVE types, be sure to also add NumElt and IntElt
+-- instances for them!
-- | Wrapper type used as a tag to attach instances on. The instances on arrays
@@ -118,6 +122,8 @@ instance PrimElt Bool
instance PrimElt Int
instance PrimElt Int64
instance PrimElt Int32
+instance PrimElt Int16
+instance PrimElt Int8
instance PrimElt CInt
instance PrimElt Float
instance PrimElt Double
@@ -154,6 +160,8 @@ newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq
newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int16 = M_Int16 (Mixed sh (Primitive Int16)) deriving (Eq, Ord, Generic ANDSHOW)
+newtype instance Mixed sh Int8 = M_Int8 (Mixed sh (Primitive Int8)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW)
newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW)
@@ -190,6 +198,8 @@ newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool)
newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh Int16 = MV_Int16 (VS.MVector s Int16)
+newtype instance MixedVecs s sh Int8 = MV_Int8 (VS.MVector s Int8)
newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
@@ -247,15 +257,15 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
abs = mliftNumElt1 (liftO1 . numEltAbs)
signum = mliftNumElt1 (liftO1 . numEltSignum)
-- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
- fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
+ fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicatePrim"
instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicatePrim"
recip = mliftNumElt1 (liftO1 . floatEltRecip)
(/) = mliftNumElt2 (liftO2 . floatEltDiv)
instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicatePrim"
exp = mliftNumElt1 (liftO1 . floatEltExp)
log = mliftNumElt1 (liftO1 . floatEltLog)
sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
@@ -298,15 +308,9 @@ class Elt a where
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
mscalar :: a -> Mixed '[] a
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+ -- | See 'mfromListOuter'. If the list does not have the given length, a
+ -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable.
+ mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a
mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
@@ -355,7 +359,7 @@ class Elt a where
mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+ mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool
mshowShapeTree :: Proxy a -> ShapeTree a -> String
@@ -380,9 +384,7 @@ class Elt a where
-- of this class with those of 'Elt': some instances have an additional
-- "known-shape" constraint.
--
--- This class is (currently) only required for 'mgenerate',
--- 'Data.Array.Nested.Ranked.rgenerate' and
--- 'Data.Array.Nested.Shaped.sgenerate'.
+-- This class is (currently) only required for `memptyArray` and 'mgenerate'.
class Elt a => KnownElt a where
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArrayUnsafe :: IShX sh -> Mixed sh a
@@ -397,11 +399,13 @@ class Elt a => KnownElt a where
-- Arrays of scalars are basically just arrays of scalars.
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
+ {-# INLINEABLE mindex #-}
mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ {-# INLINEABLE mindexPartial #-}
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
+ mfromListOuterSN sn l@(arr1 :| _) =
+ let sh = SKnown sn :$% mshape arr1
in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
@@ -440,7 +444,7 @@ instance Storable a => Elt (Primitive a) where
=> StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)
mtranspose perm (M_Primitive sh arr) =
@@ -457,7 +461,7 @@ instance Storable a => Elt (Primitive a) where
type ShapeTree (Primitive a) = ()
mshapeTree _ = ()
mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
+ mshapeTreeIsEmpty _ () = False
mshowShapeTree _ () = "()"
marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr)
mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
@@ -478,6 +482,8 @@ deriving via Primitive Bool instance Elt Bool
deriving via Primitive Int instance Elt Int
deriving via Primitive Int64 instance Elt Int64
deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive Int16 instance Elt Int16
+deriving via Primitive Int8 instance Elt Int8
deriving via Primitive CInt instance Elt CInt
deriving via Primitive Double instance Elt Double
deriving via Primitive Float instance Elt Float
@@ -493,6 +499,8 @@ deriving via Primitive Bool instance KnownElt Bool
deriving via Primitive Int instance KnownElt Int
deriving via Primitive Int64 instance KnownElt Int64
deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive Int16 instance KnownElt Int16
+deriving via Primitive Int8 instance KnownElt Int8
deriving via Primitive CInt instance KnownElt CInt
deriving via Primitive Double instance KnownElt Double
deriving via Primitive Float instance KnownElt Float
@@ -504,9 +512,9 @@ instance (Elt a, Elt b) => Elt (a, b) where
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mfromListOuterSN sn l =
+ M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l))
mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
@@ -531,7 +539,7 @@ instance (Elt a, Elt b) => Elt (a, b) where
type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2
mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
@@ -557,20 +565,19 @@ instance Elt a => Elt (Mixed sh' a) where
= fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
+ mindex (M_Nest _ arr) = mindexPartial arr
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mscalar = M_Nest ZSX
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+ mfromListOuterSN sn l@(arr :| _) =
+ M_Nest (SKnown sn :$% mshape arr)
+ (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l))
mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
@@ -632,14 +639,14 @@ instance Elt a => Elt (Mixed sh' a) where
| Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
, Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
= let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- sh2 = shxCast' sh1 ssh2
+ sh2 = shxCast' ssh2 sh1
in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr)
mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
=> Perm is -> Mixed sh (Mixed sh' a)
-> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ | let sh' = shxDropSh @sh @sh' sh (mshape arr)
, Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')
, Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
, Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
@@ -662,7 +669,8 @@ instance Elt a => Elt (Mixed sh' a) where
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ -- the array is empty if either there are no subarrays, or the subarrays themselves are empty
+ mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
@@ -692,7 +700,8 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
+-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere
+memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a
memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh)
mrank :: Elt a => Mixed sh a -> SNat (Rank sh)
@@ -719,19 +728,19 @@ msize = shxSize . mshape
-- the entire hierarchy (after distributing out tuples) must be a rectangular
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
+--
+-- If your element type @a@ is a scalar, use the faster 'mgeneratePrim'.
mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
mgenerate sh f = case shxEnum sh of
[] -> memptyArrayUnsafe sh
firstidx : restidxs ->
let firstelem = f (ixxZero' sh)
shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
+ in if mshapeTreeIsEmpty (Proxy @a) shapetree
then memptyArrayUnsafe sh
else runST $ do
vecs <- mvecsUnsafeNew sh firstelem
mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
forM_ restidxs $ \idx -> do
let val = f idx
when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
@@ -739,18 +748,32 @@ mgenerate sh f = case shxEnum sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
+-- | An optimized special case of 'mgenerate', where the function results
+-- are of a primitive type and so there's not need to check that all shapes
+-- are equal. This is also generalized to an arbitrary @Num@ index type
+-- compared to @mgenerate@.
+{-# INLINE mgeneratePrim #-}
+mgeneratePrim :: forall sh a i. (PrimElt a, Num i)
+ => IShX sh -> (IxX sh i -> a) -> Mixed sh a
+mgeneratePrim sh f =
+ let g i = f (ixxFromLinear sh i)
+ in mfromVector sh $ VS.generate (shxSize sh) g
+
+msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1PrimP (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive
+
+msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a
+msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
+msumAllPrim arr = msumAllPrimP (toPrimitive arr)
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
@@ -781,23 +804,76 @@ mtoVectorP (M_Primitive _ v) = X.toVector v
mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
mtoVector arr = mtoVectorP (toPrimitive arr)
+-- | All arrays in the list, even subarrays inside @a@, must have the same
+-- shape; if they do not, a runtime error will be thrown. See the
+-- documentation of 'mgenerate' for more information about this restriction.
+--
+-- Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to
+-- stream the list.
+--
+-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'.
+mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuter l = mfromListOuterN (length l) l
+
+-- | See 'mfromListOuter'. If the list does not have the given length, a
+-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable.
+mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+mfromListOuterN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l)
+ Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")"
+
+-- | Because the length of the 'NonEmpty' list is unknown, its spine must be
+-- materialised in memory in order to compute its length. If its length is
+-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the
+-- list.
+--
+-- If the elements are scalars, 'mfromList1Prim' is faster.
mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+mfromList1 = mfromListOuter . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a
+mfromList1N n = mfromListOuterN n . fmap mscalar
+
+-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error
+-- is thrown if the list length does not match the given length.
+mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a
+mfromList1SN sn = mfromListOuterSN sn . fmap mscalar
-- This forall is there so that a simple type application can constrain the
-- shape, in case the user wants to use OverloadedLists for the shape.
+-- | If the elements are scalars, 'mfromListPrimLinear' is faster.
mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a
-mfromListLinear sh l = mreshape sh (mfromList1 l)
+mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l)
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
+-- | Because the length of the list is unknown, its spine must be materialised
+-- in memory in order to compute its length. If its length is already known,
+-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list.
+mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromList1Prim l =
let ssh = SUnknown () :!% ZKX
xarr = X.fromList1 ssh l
in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a
+mfromList1PrimN n l =
+ withSomeSNat (fromIntegral n) $ \case
+ Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l)
+ Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")"
+
+mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a
+mfromList1PrimSN sn l =
+ let ssh = SKnown sn :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a
mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l)
in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
mtoList :: Elt a => Mixed '[n] a -> [a]
@@ -824,24 +900,54 @@ mzip a b
munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b)
munzip (M_Tup2 a b) = (a, b)
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
+mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b))
+mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) =
+ let sh1 = shxDropSh sh shsh1
+ in M_Nest sh $
+ M_Primitive (shxAppend sh sh2)
+ (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
--- | See the caveats at @X.rerank@.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero),
+-- then there is no way to deduce the full shape of the output array (more
+-- precisely, the @sh2@ part): that could only come from calling @f@, and there
+-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we
+-- choose to fill the shape with zeros wherever we cannot deduce what it should
+-- be.
+--
+-- For example, if:
+--
+-- @
+-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21]
+-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int)
+-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float
+-- @
+--
+-- then:
+--
+-- @
+-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float)
+-- @
+--
+-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the
+-- @0@ in this shape: we don't know if @f@ intended to return an array with
+-- shape 0 here (it probably didn't), but there is no better number to put here
+-- absent a subarray of the input to pass to @f@.
+--
+-- In this particular case the fact that @sh@ is empty was evident from the
+-- type-level information, but the same situation occurs when @sh@ consists of
+-- @Nothing@s, and some of those happen to be zero at runtime.
+mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b)
+mrerankPrim sh2 f (M_Nest sh arr) =
+ let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr))
+ in M_Nest sh' (fromPrimitive arr')
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
@@ -853,21 +959,21 @@ mreplicate sh arr =
Refl -> X.replicate sh (ssxAppend ssh' sshT))
arr
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x)
-mreplicateScal :: forall sh a. PrimElt a
+mreplicatePrim :: forall sh a. PrimElt a
=> IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x)
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
+msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
+
+msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+msliceSN i n arr =
let _ :$% sh = mshape arr
in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
-
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr