diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-05-20 22:55:59 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-20 22:55:59 +0200 | 
| commit | beec5f9ab3f37da78bb95e2631871b4e47b46c57 (patch) | |
| tree | 8efeae83646526ec7bdeaeb1674fa62f3cea2676 /src | |
| parent | 8d495b7e6c21fc843f0538711c2203dfb213b7e1 (diff) | |
Better {from,to}List set
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 28 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 123 | 
3 files changed, 102 insertions, 57 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 2f23903..1dc6b58 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -728,18 +728,30 @@ sumOuter ssh ssh'    | Refl <- lemAppNil @sh    = sumInner ssh' ssh . transpose2 ssh ssh' -fromList1 :: forall n sh a. Storable a -          => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromList1 ssh l +fromListOuter :: forall n sh a. Storable a +              => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l    | Dict <- lemKnownNatRankSSX ssh    = case ssh of -      SKnown m@SNat :!% _ | natVal m /= fromIntegral (length l) -> -        error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ -                "does not match the type (" ++ show (natVal m) ++ ")" +      SKnown m :!% _ | fromSNat' m /= length l -> +        error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ +                "does not match the type (" ++ show (fromSNat' m) ++ ")"        _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) -toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a] -toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = +  let n = length l +  in case ssh of +       SKnown m :!% _ | fromSNat' m /= n -> +         error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ +                 "does not match the type (" ++ show (fromSNat' m) ++ ")" +       _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr  -- | Throws if the given shape is not, in fact, empty.  empty :: forall sh a. Storable a => IShX sh -> XArray sh a diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 7298918..e804748 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -9,7 +9,7 @@ module Data.Array.Nested (    rshape, rindex, rindexPartial, rgenerate, rsumOuter1,    rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,    rrerank, -  rreplicate, rfromList, rfromList1, rtoList, rtoList1, +  rreplicate, rfromListOuter, rfromList1, rtoListOuter, rtoList1,    rslice, rrev1, rreshape,    -- ** Lifting orthotope operations to 'Ranked' arrays    rlift, @@ -26,7 +26,7 @@ module Data.Array.Nested (    sshape, sindex, sindexPartial, sgenerate, ssumOuter1,    stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,    srerank, -  sreplicate, sfromList, sfromList1, stoList, stoList1, +  sreplicate, sfromListOuter, sfromList1, stoListOuter, stoList1,    sslice, srev1, sreshape,    -- ** Lifting orthotope operations to 'Shaped' arrays    slift, @@ -40,13 +40,13 @@ module Data.Array.Nested (    KnownShX(..), StaticShX(..),    mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,    mrerank, -  mreplicate, mfromList, mtoList, mslice, mrev1, mreshape, +  mreplicate, mfromList1, mtoList1, mslice, mrev1, mreshape,    -- ** Conversions    masXArrayPrim, mfromXArrayPrim,    mtoRanked, mcastToShaped,    -- * Array elements -  Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), +  Elt(mshape, mindex, mindexPartial, mscalar, mfromListOuter, mtoListOuter, mlift, mlift2),    PrimElt,    Primitive(..), diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index feb0662..ab67dcc 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -529,11 +529,10 @@ class Elt a where    -- @Just m@ and @m@ does not equal the length of the list, a runtime error is    -- thrown.    -- -  -- If you want a single-dimensional array from your list, map 'mscalar' -  -- first. -  mfromList1 :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a +  -- Consider also 'mfromListPrim', which can avoid intermediate arrays. +  mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a -  mtoList1 :: Mixed (n : sh) a -> [Mixed sh a] +  mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]    -- | Note: this library makes no particular guarantees about the shapes of    -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the @@ -603,10 +602,10 @@ instance Storable a => Elt (Primitive a) where    mindex (M_Primitive _ a) i = Primitive (X.index a i)    mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i)    mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) -  mfromList1 l@(arr1 :| _) = +  mfromListOuter l@(arr1 :| _) =      let sh = SUnknown (length l) :$% mshape arr1 -    in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) -  mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr) +    in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) +  mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr)    mlift :: forall sh1 sh2.             StaticShX sh2 @@ -679,10 +678,10 @@ instance (Elt a, Elt b) => Elt (a, b) where    mindex (M_Tup2 a b) i = (mindex a i, mindex b i)    mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)    mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) -  mfromList1 l = -    M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l)) -           (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l)) -  mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b) +  mfromListOuter l = +    M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) +           (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) +  mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)    mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)    mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) @@ -728,12 +727,12 @@ instance Elt a => Elt (Mixed sh' a) where    mscalar = M_Nest ZSX -  mfromList1 :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) -  mfromList1 l@(arr :| _) = +  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) +  mfromListOuter l@(arr :| _) =      M_Nest (SUnknown (length l) :$% mshape arr) -           (mfromList1 ((\(M_Nest _ a) -> a) <$> l)) +           (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) -  mtoList1 (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoList1 arr) +  mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr)    mlift :: forall sh1 sh2.             StaticShX sh2 @@ -890,11 +889,22 @@ mtoVectorP (M_Primitive _ v) = X.toVector v  mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a  mtoVector arr = mtoVectorP (coerce toPrimitive arr) -mfromList :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList = mfromList1 . fmap mscalar +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar  -- TODO: optimise? -mtoList :: Elt a => Mixed '[n] a -> [a] -mtoList = map munScalar . mtoList1 +mtoList1 :: Elt a => Mixed '[n] a -> [a] +mtoList1 = map munScalar . mtoListOuter + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = +  let ssh = SUnknown () :!% ZKX +      xarr = X.fromList1 ssh l +  in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = +  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) +  in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)  munScalar :: Elt a => Mixed '[] a -> a  munScalar arr = mindex arr ZIX @@ -1045,12 +1055,12 @@ instance Elt a => Elt (Ranked n a) where    mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) -  mfromList1 :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) -  mfromList1 l = M_Ranked (mfromList1 (coerce l)) +  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) +  mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) -  mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] -  mtoList1 (M_Ranked arr) = -    coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr) +  mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] +  mtoListOuter (M_Ranked arr) = +    coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)    mlift :: forall sh1 sh2.             StaticShX sh2 @@ -1143,12 +1153,12 @@ instance Elt a => Elt (Shaped sh a) where    mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) -  mfromList1 :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) -  mfromList1 l = M_Shaped (mfromList1 (coerce l)) +  mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) +  mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) -  mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] -  mtoList1 (M_Shaped arr) -    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr) +  mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] +  mtoListOuter (M_Shaped arr) +    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)    mlift :: forall sh1 sh2.             StaticShX sh2 @@ -1387,21 +1397,32 @@ rtoVectorP = coerce mtoVectorP  rtoVector :: PrimElt a => Ranked n a -> VS.Vector a  rtoVector = coerce mtoVector -rfromList1 :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromList1 l +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l    | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n -  = Ranked (mfromList1 (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) +  = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) -rfromList :: Elt a => NonEmpty a -> Ranked 1 a -rfromList l = Ranked (mfromList l) +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) -rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoList (Ranked arr) +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr)    | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n -  = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) +  = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)  rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoList +rtoList1 = map runScalar . rtoListOuter + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = +  let ssh = SUnknown () :!% ZKX +      xarr = X.fromList1 ssh l +  in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = +  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) +  in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)  rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a  rfromOrthotope sn arr @@ -1651,17 +1672,29 @@ stoVectorP = coerce mtoVectorP  stoVector :: PrimElt a => Shaped sh a -> VS.Vector a  stoVector = coerce mtoVector -sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l)) +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) -sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 -stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoList (Shaped arr) = coerce (mtoList1 arr) +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr)  stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoList +stoList1 = map sunScalar . stoListOuter + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l +  | Refl <- X.lemAppNil @'[Just n] +  = let ssh = SUnknown () :!% ZKX +        xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) +    in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = +  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) +  in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)  sunScalar :: Elt a => Shaped '[] a -> a  sunScalar arr = sindex arr ZIS | 
