diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-17 22:16:11 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-17 22:21:31 +0100 |
| commit | ba798129655503d7e69de271d956cceaef4cef56 (patch) | |
| tree | a906aaf03581060307a18d8347b8e4e5cef82a54 /src/Data/Array/Nested/Mixed.hs | |
| parent | 0766e22df98179ce7debb179e544716bccfbca24 (diff) | |
Provide explicit-length versions of fromList functions
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 105 |
1 files changed, 76 insertions, 29 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 6b152f7..a2787b8 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -7,6 +7,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -307,15 +308,9 @@ class Elt a where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- Consider also 'mfromListPrim', which can avoid intermediate arrays. - mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + -- | See 'mfromListOuter'. If the list does not have the given length, a + -- runtime error is thrown. 'mfromListPrimSN' is faster if applicable. + mfromListOuterSN :: forall sh n. SNat n -> NonEmpty (Mixed sh a) -> Mixed (Just n : sh) a mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] @@ -407,8 +402,8 @@ instance Storable a => Elt (Primitive a) where mindex (M_Primitive _ a) i = Primitive (X.index a i) mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) - mfromListOuter l@(arr1 :| _) = - let sh = SUnknown (length l) :$% mshape arr1 + mfromListOuterSN sn l@(arr1 :| _) = + let sh = SKnown sn :$% mshape arr1 in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) @@ -515,9 +510,9 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromListOuter l = - M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) - (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mfromListOuterSN sn l = + M_Tup2 (mfromListOuterSN sn ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuterSN sn ((\(M_Tup2 _ y) -> y) <$> l)) mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) @@ -578,10 +573,9 @@ instance Elt a => Elt (Mixed sh' a) where mscalar = M_Nest ZSX - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) - mfromListOuter l@(arr :| _) = - M_Nest (SUnknown (length l) :$% mshape arr) - (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + mfromListOuterSN sn l@(arr :| _) = + M_Nest (SKnown sn :$% mshape arr) + (mfromListOuterSN sn ((\(M_Nest _ a) -> a) <$> l)) mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) @@ -793,23 +787,76 @@ mtoVectorP (M_Primitive _ v) = X.toVector v mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (toPrimitive arr) +-- | All arrays in the list, even subarrays inside @a@, must have the same +-- shape; if they do not, a runtime error will be thrown. See the +-- documentation of 'mgenerate' for more information about this restriction. +-- +-- Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromListOuterN' or 'mfromListOuterSN' to be able to +-- stream the list. +-- +-- If your array is 1-dimensional and contains scalars, use 'mfromList1Prim'. +mfromListOuter :: Elt a => NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuter l = mfromListOuterN (length l) l + +-- | See 'mfromListOuter'. If the list does not have the given length, a +-- runtime error is thrown. 'mfromList1PrimN' is faster if applicable. +mfromListOuterN :: Elt a => Int -> NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +mfromListOuterN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromListOuterSN sn l) + Nothing -> error $ "mfromListOuterN: length negative (" ++ show n ++ ")" + +-- | Because the length of the 'NonEmpty' list is unknown, its spine must be +-- materialised in memory in order to compute its length. If its length is +-- already known, use 'mfromList1N' or 'mfromList1SN' to be able to stream the +-- list. +-- +-- If the elements are scalars, 'mfromList1Prim' is faster. mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? +mfromList1 = mfromListOuter . fmap mscalar + +-- | If the elements are scalars, 'mfromList1PrimN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1N :: Elt a => Int -> NonEmpty a -> Mixed '[Nothing] a +mfromList1N n = mfromListOuterN n . fmap mscalar + +-- | If the elements are scalars, 'mfromList1PrimSN' is faster. A runtime error +-- is thrown if the list length does not match the given length. +mfromList1SN :: Elt a => SNat n -> NonEmpty a -> Mixed '[Just n] a +mfromList1SN sn = mfromListOuterSN sn . fmap mscalar -- This forall is there so that a simple type application can constrain the -- shape, in case the user wants to use OverloadedLists for the shape. +-- | If the elements are scalars, 'mfromListPrimLinear' is faster. mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) +mfromListLinear sh l = mreshape sh (mfromList1N (shxSize sh) l) -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = +-- | Because the length of the list is unknown, its spine must be materialised +-- in memory in order to compute its length. If its length is already known, +-- use 'mfromList1PrimN' or 'mfromList1PrimSN' to be able to stream the list. +mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromList1Prim l = let ssh = SUnknown () :!% ZKX xarr = X.fromList1 ssh l in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromList1PrimN :: PrimElt a => Int -> [a] -> Mixed '[Nothing] a +mfromList1PrimN n l = + withSomeSNat (fromIntegral n) $ \case + Just sn -> mcastPartial (SKnown sn :!% ZKX) (SUnknown () :!% ZKX) Proxy (mfromList1PrimSN sn l) + Nothing -> error $ "mfromList1PrimN: length negative (" ++ show n ++ ")" + +mfromList1PrimSN :: PrimElt a => SNat n -> [a] -> Mixed '[Just n] a +mfromList1PrimSN sn l = + let ssh = SKnown sn :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: forall sh a. PrimElt a => IShX sh -> [a] -> Mixed sh a mfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + let M_Primitive _ xarr = toPrimitive (mfromList1PrimN (shxSize sh) l) in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) mtoList :: Elt a => Mixed '[n] a -> [a] @@ -872,14 +919,14 @@ mreplicateScal :: forall sh a. PrimElt a => IShX sh -> a -> Mixed sh a mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = +msliceN :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceN i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr + +msliceSN :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +msliceSN i n arr = let _ :$% sh = mshape arr in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr - mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr |
