diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 260 |
1 files changed, 183 insertions, 77 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54f8fe6..182943d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -7,6 +7,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -91,6 +92,9 @@ import Data.Bag -- Unfortunately, the setup of the library requires us to list these primitive -- element types multiple times; to aid in extending the list, all these lists -- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. +-- +-- NOTE: if you add PRIMITIVE types, be sure to also add NumElt and IntElt +-- instances for them! -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -118,6 +122,8 @@ instance PrimElt Bool instance PrimElt Int instance PrimElt Int64 instance PrimElt Int32 +instance PrimElt Int16 +instance PrimElt Int8 instance PrimElt CInt instance PrimElt Float instance PrimElt Double @@ -154,6 +160,8 @@ newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int16 = M_Int16 (Mixed sh (Primitive Int16)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int8 = M_Int8 (Mixed sh (Primitive Int8)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW) newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW) @@ -190,6 +198,8 @@ newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh Int16 = MV_Int16 (VS.MVector s Int16) +newtype instance MixedVecs s sh Int8 = MV_Int8 (VS.MVector s Int8) newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) @@ -247,15 +257,15 @@ instance (NumElt a, PrimElt a) => Num (Mixed sh a) where abs = mliftNumElt1 (liftO1 . numEltAbs) signum = mliftNumElt1 (liftO1 . numEltSignum) -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS - fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" + fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicatePrim" instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where - fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicatePrim" recip = mliftNumElt1 (liftO1 . floatEltRecip) (/) = mliftNumElt2 (liftO2 . floatEltDiv) instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicatePrim" exp = mliftNumElt1 (liftO1 . floatEltExp) log = mliftNumElt1 (liftO1 . floatEltLog) sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) @@ -298,15 +308,9 @@ class Elt a where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + -- | See 'mfromListOuter'. If the list does not have the given length, a + -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable. + mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] @@ -355,7 +359,7 @@ class Elt a where mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + mshapeTreeIsEmpty :: Proxy a -> ShapeTree a -> Bool mshowShapeTree :: Proxy a -> ShapeTree a -> String @@ -380,9 +384,7 @@ class Elt a where -- of this class with those of 'Elt': some instances have an additional -- "known-shape" constraint. -- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. +-- This class is (currently) only required for `memptyArray` and 'mgenerate'. class Elt a => KnownElt a where -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArrayUnsafe :: IShX sh -> Mixed sh a @@ -397,11 +399,13 @@ class Elt a => KnownElt a where -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh + {-# INLINEABLE mindex #-} mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + {-# INLINEABLE mindexPartial #-} + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 + mfromListOuterSN sn l@(arr1 :| _) = + let sh = SKnown sn :$% mshape arr1 in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) @@ -440,7 +444,7 @@ instance Storable a => Elt (Primitive a) where => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = @@ -457,7 +461,7 @@ instance Storable a => Elt (Primitive a) where type ShapeTree (Primitive a) = () mshapeTree _ = () mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False + mshapeTreeIsEmpty _ () = False mshowShapeTree _ () = "()" marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x @@ -478,6 +482,8 @@ deriving via Primitive Bool instance Elt Bool deriving via Primitive Int instance Elt Int deriving via Primitive Int64 instance Elt Int64 deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive Int16 instance Elt Int16 +deriving via Primitive Int8 instance Elt Int8 deriving via Primitive CInt instance Elt CInt deriving via Primitive Double instance Elt Double deriving via Primitive Float instance Elt Float @@ -493,6 +499,8 @@ deriving via Primitive Bool instance KnownElt Bool deriving via Primitive Int instance KnownElt Int deriving via Primitive Int64 instance KnownElt Int64 deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive Int16 instance KnownElt Int16 +deriving via Primitive Int8 instance KnownElt Int8 deriving via Primitive CInt instance KnownElt CInt deriving via Primitive Double instance KnownElt Double deriving via Primitive Float instance KnownElt Float @@ -504,9 +512,9 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mfromListOuterSN sn l = + M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l)) mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) @@ -531,7 +539,7 @@ instance (Elt a, Elt b) => Elt (a, b) where type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshapeTreeIsEmpty _ (t1, t2) = mshapeTreeIsEmpty (Proxy @a) t1 && mshapeTreeIsEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b mvecsWrite sh i (x, y) (MV_Tup2 a b) = do @@ -557,20 +565,19 @@ instance Elt a => Elt (Mixed sh' a) where = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i + mindex (M_Nest _ arr) = mindexPartial arr mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + mfromListOuterSN sn l@(arr :| _) = + M_Nest (SKnown sn :$% mshape arr) + (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l)) mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) @@ -632,14 +639,14 @@ instance Elt a => Elt (Mixed sh' a) where | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh + | let sh' = shxDropSh @sh @sh' sh (mshape arr) , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') @@ -662,7 +669,8 @@ instance Elt a => Elt (Mixed sh' a) where mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + -- the array is empty if either there are no subarrays, or the subarrays themselves are empty + mshapeTreeIsEmpty _ (sh, t) = shxSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" @@ -692,7 +700,8 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) -memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a +-- TODO: should we provide a function that's just memptyArrayUnsafe but with a size==0 check? That may save someone a transpose somewhere +memptyArray :: forall sh a. KnownElt a => IShX sh -> Mixed (Just 0 : sh) a memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) mrank :: Elt a => Mixed sh a -> SNat (Rank sh) @@ -719,19 +728,19 @@ msize = shxSize . mshape -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. +-- +-- If your element type @a@ is a scalar, use the faster 'mgeneratePrim'. mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a mgenerate sh f = case shxEnum sh of [] -> memptyArrayUnsafe sh firstidx : restidxs -> let firstelem = f (ixxZero' sh) shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree + in if mshapeTreeIsEmpty (Proxy @a) shapetree then memptyArrayUnsafe sh else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. forM_ restidxs $ \idx -> do let val = f idx when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ @@ -739,18 +748,32 @@ mgenerate sh f = case shxEnum sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -msumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = +-- | An optimized special case of 'mgenerate', where the function results +-- are of a primitive type and so there's not need to check that all shapes +-- are equal. This is also generalized to an arbitrary @Num@ index type +-- compared to @mgenerate@. +{-# INLINE mgeneratePrim #-} +mgeneratePrim :: forall sh a i. (PrimElt a, Num i) + => IShX sh -> (IxX sh i -> a) -> Mixed sh a +mgeneratePrim sh f = + let g i = f (ixxFromLinear sh i) + in mfromVector sh $ VS.generate (shxSize sh) g + +msumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1PrimP (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive +msumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1Prim = fromPrimitive . msumOuter1PrimP @sh @n @a . toPrimitive + +msumAllPrimP :: (Storable a, NumElt a) => Mixed sh (Primitive a) -> a +msumAllPrimP (M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr +msumAllPrim arr = msumAllPrimP (toPrimitive arr) mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -781,23 +804,76 @@ mtoVectorP (M_Primitive _ v) = X.toVector v mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to +-- stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'. +mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuter l = mfromListOuterN (length l) l + +-- | See 'mfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable. +mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuterN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l) + Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")" + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the +-- list. +-- +-- If the elements are scalars, 'mfromList1Prim' is faster. mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? +mfromList1 = mfromListOuter . fmap mscalar + +-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a +mfromList1N n = mfromListOuterN n . fmap mscalar + +-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a +mfromList1SN sn = mfromListOuterSN sn . fmap mscalar -- This forall is there so that a simple type application can constrain the -- shape, in case the user wants to use OverloadedLists for the shape. +-- | If the elements are scalars, 'mfromListPrimLinear' is faster. mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) +mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l) -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list. +mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromList1Prim l = let ssh = SUnknown () :!% ZKX xarr = X.fromList1 ssh l in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a +mfromList1PrimN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l) + Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")" + +mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a +mfromList1PrimSN sn l = + let ssh = SKnown sn :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l) in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) mtoList :: Elt a => Mixed '[n] a -> [a] @@ -824,24 +900,54 @@ mzip a b munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) munzip (M_Tup2 a b) = (a, b) -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) - -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) - (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) - arr) +mrerankPrimP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed sh (Mixed sh1 (Primitive a)) -> Mixed sh (Mixed sh2 (Primitive b)) +mrerankPrimP sh2 f (M_Nest sh (M_Primitive shsh1 arr)) = + let sh1 = shxDropSh sh shsh1 + in M_Nest sh $ + M_Primitive (shxAppend sh sh2) + (X.rerank (ssxFromShX sh) (ssxFromShX sh1) (ssxFromShX sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => StaticShX sh -> IShX sh2 - -> (Mixed sh1 a -> Mixed sh2 b) - -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = - fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr +-- | If the shape of the outer array (@sh@) is empty (i.e. contains a zero), +-- then there is no way to deduce the full shape of the output array (more +-- precisely, the @sh2@ part): that could only come from calling @f@, and there +-- are no subarrays to call @f@ on. @orthotope@ errors out in this case; we +-- choose to fill the shape with zeros wherever we cannot deduce what it should +-- be. +-- +-- For example, if: +-- +-- @ +-- -- arr has shape [3, 0, 4] and the inner arrays have shape [2, 21] +-- arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 2, Nothing] Int) +-- f :: Mixed '[Just 2, Nothing] Int -> Mixed '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- mrerankPrim _ f arr :: Mixed '[Just 3, Just 0, Just 4] (Mixed '[Just 5, Nothing, Just 17] Float) +-- @ +-- +-- and the inner arrays of the result will have shape @[5, 0, 17]@. Note the +-- @0@ in this shape: we don't know if @f@ intended to return an array with +-- shape 0 here (it probably didn't), but there is no better number to put here +-- absent a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +mrerankPrim :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed sh (Mixed sh1 a) -> Mixed sh (Mixed sh2 b) +mrerankPrim sh2 f (M_Nest sh arr) = + let M_Nest sh' arr' = mrerankPrimP sh2 (toPrimitive . f . fromPrimitive) (M_Nest sh (toPrimitive arr)) + in M_Nest sh' (fromPrimitive arr') mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a @@ -853,21 +959,21 @@ mreplicate sh arr = Refl -> X.replicate sh (ssxAppend ssh' sshT)) arr -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) +mreplicatePrimP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicatePrimP sh x = M_Primitive sh (X.replicateScal sh x) -mreplicateScal :: forall sh a. PrimElt a +mreplicatePrim :: forall sh a. PrimElt a => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) +mreplicatePrim sh x = fromPrimitive (mreplicatePrimP sh x) -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = +msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr + +msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +msliceSN i n arr = let _ :$% sh = mshape arr in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr - mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr |
