aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Shaped.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 22:47:52 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 22:47:52 +0200
commit8b59d8ef4ff97936f2a753d1ce345e0404c26b2b (patch)
tree947f75cb43982fbdb551dc329f036b0591f3c2b2 /src/Data/Array/Nested/Shaped.hs
parentf0752d67cd188f438280e1f0c692dc1f5f14a190 (diff)
Clearer module purposes
Thanks Mikolaj for discussion
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
-rw-r--r--src/Data/Array/Nested/Shaped.hs379
1 files changed, 0 insertions, 379 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
deleted file mode 100644
index 934433e..0000000
--- a/src/Data/Array/Nested/Shaped.hs
+++ /dev/null
@@ -1,379 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE ImportQualifiedPost #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Shaped where
-
-import Prelude hiding (mappend)
-
-import Control.DeepSeq (NFData)
-import Control.Monad.ST
-import Data.Bifunctor (first)
-import Data.Coerce (coerce)
-import Data.Kind (Type)
-import Data.List.NonEmpty (NonEmpty)
-import Data.Proxy
-import Data.Type.Equality
-import Data.Vector.Storable qualified as VS
-import Foreign.Storable (Storable)
-import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
-import GHC.TypeLits
-
-import Data.Array.Mixed (XArray)
-import Data.Array.Mixed qualified as X
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Types
-import Data.Array.Nested.Lemmas
-import Data.Array.Nested.Mixed
-import Data.Array.Nested.Shape
-
-
--- | A shape-typed array: the full shape of the array (the sizes of its
--- dimensions) is represented on the type level as a list of 'Nat's. Note that
--- these are "GHC.TypeLits" naturals, because we do not need induction over
--- them and we want very large arrays to be possible.
---
--- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
--- and 'Shaped' itself is again an instance of 'Elt' as well.
---
--- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
-newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
-deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
-deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a)
-deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a)
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
-deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))
-
-newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
-
-instance Elt a => Elt (Shaped sh a) where
- mshape (M_Shaped arr) = mshape arr
- mindex (M_Shaped arr) i = Shaped (mindex arr i)
-
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
-
- mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
-
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
-
- mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoListOuter (M_Shaped arr)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift ssh2 f (M_Shaped arr) =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
- mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
- coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
- mlift2 ssh3 f arr1 arr2
-
- mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
-
- mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
-
- type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
-
- mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
-
- mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
-instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
- memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArray i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArray i
-
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-
-
-arithPromoteShaped :: forall sh a. PrimElt a
- => (forall shx. Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a
-arithPromoteShaped = coerce
-
-arithPromoteShaped2 :: forall sh a. PrimElt a
- => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a -> Shaped sh a
-arithPromoteShaped2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
- (+) = arithPromoteShaped2 (+)
- (-) = arithPromoteShaped2 (-)
- (*) = arithPromoteShaped2 (*)
- negate = arithPromoteShaped negate
- abs = arithPromoteShaped abs
- signum = arithPromoteShaped signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
-
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
- recip = arithPromoteShaped recip
- (/) = arithPromoteShaped2 (/)
-
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
- exp = arithPromoteShaped exp
- log = arithPromoteShaped log
- sqrt = arithPromoteShaped sqrt
- (**) = arithPromoteShaped2 (**)
- logBase = arithPromoteShaped2 logBase
- sin = arithPromoteShaped sin
- cos = arithPromoteShaped cos
- tan = arithPromoteShaped tan
- asin = arithPromoteShaped asin
- acos = arithPromoteShaped acos
- atan = arithPromoteShaped atan
- sinh = arithPromoteShaped sinh
- cosh = arithPromoteShaped cosh
- tanh = arithPromoteShaped tanh
- asinh = arithPromoteShaped asinh
- acosh = arithPromoteShaped acosh
- atanh = arithPromoteShaped atanh
- log1p = arithPromoteShaped GHC.Float.log1p
- expm1 = arithPromoteShaped GHC.Float.expm1
- log1pexp = arithPromoteShaped GHC.Float.log1pexp
- log1mexp = arithPromoteShaped GHC.Float.log1mexp
-
-
-sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shCvtXS' (mshape arr)
-
-sindex :: Elt a => Shaped sh a -> IIxS sh -> a
-sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
-shsTakeIx _ _ ZIS = ZSS
-shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
-
-sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
-sindexPartial sarr@(Shaped arr) idx =
- Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
- (ixCvtSX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
-sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
-
--- | See the documentation of 'mlift'.
-slift :: forall sh1 sh2 a. Elt a
- => ShS sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a
-slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
-
--- | See the documentation of 'mlift'.
-slift2 :: forall sh1 sh2 sh3 a. Elt a
- => ShS sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
-slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
-
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
-
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
-
-stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
- => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
-stranspose perm sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- , Refl <- lemMapJustTakeLen perm (sshape sarr)
- , Refl <- lemMapJustDropLen perm (sshape sarr)
- , Refl <- lemMapJustPermute perm (shsTakeLen perm (sshape sarr))
- , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
- = Shaped (mtranspose perm arr)
-
-sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
-sappend = coerce mappend
-
-sscalar :: Elt a => a -> Shaped '[] a
-sscalar x = Shaped (mscalar x)
-
-sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
-sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
-
-sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
-sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
-
-stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
-stoVectorP = coerce mtoVectorP
-
-stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
-stoVector = coerce mtoVector
-
-sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))
-
-sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1
-
-sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim
-
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
-
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
-
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
-
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
-
-sunScalar :: Elt a => Shaped '[] a -> a
-sunScalar arr = sindex arr ZIS
-
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh1)
- , Refl <- lemMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
- (shCvtSX sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
-
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
-
-sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
-sreplicate sh (Shaped arr)
- | Refl <- lemMapJustApp sh (Proxy @sh')
- = Shaped (mreplicate (shCvtSX sh) arr)
-
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
-
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
-
-sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
-sslice i n@SNat arr =
- let _ :$$ sh = sshape arr
- in slift (n :$$ sh) (\_ -> X.slice i n) arr
-
-srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
-srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
-
-sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
-sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
-
-siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
-siota sn = Shaped (miota sn)
-
-sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
-sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr)
-
-sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
-sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
-
-sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
-
-sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
-
-sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
-sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
-
-stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
-stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
-
-mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => Mixed sh a -> ShS sh' -> Shaped sh' a
-mcastToShaped arr targetsh
- | Refl <- lemAppNil @sh
- , Refl <- lemAppNil @(MapJust sh')
- , Refl <- lemRankMapJust targetsh
- = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)