diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped.hs')
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 99ad590..36ef24a 100644 --- a/src/Data/Array/Nested/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -56,7 +56,7 @@ ssize = shsSize . sshape sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IxS sh i -> ShS sh shsTakeIx _ _ ZIS = ZSS shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx @@ -70,7 +70,7 @@ sindexPartial sarr@(Shaped arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX)) -- | See 'mgeneratePrim'. {-# INLINE sgeneratePrim #-} @@ -81,6 +81,7 @@ sgeneratePrim sh f = in sfromVector sh $ VS.generate (shsSize sh) g -- | See the documentation of 'mlift'. +{-# INLINE slift #-} slift :: forall sh1 sh2 a. Elt a => ShS sh2 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) @@ -88,23 +89,28 @@ slift :: forall sh1 sh2 a. Elt a slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) -- | See the documentation of 'mlift'. +{-# INLINE slift2 #-} slift2 :: forall sh1 sh2 sh3 a. Elt a => ShS sh3 -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) +{-# INLINE ssumOuter1PrimP #-} ssumOuter1PrimP :: forall sh n a. (Storable a, NumElt a) => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) ssumOuter1PrimP (Shaped arr) = Shaped (msumOuter1PrimP arr) +{-# INLINEABLE ssumOuter1Prim #-} ssumOuter1Prim :: forall sh n a. (NumElt a, PrimElt a) => Shaped (n : sh) a -> Shaped sh a ssumOuter1Prim = sfromPrimitive . ssumOuter1PrimP . stoPrimitive +{-# INLINE ssumAllPrimP #-} ssumAllPrimP :: (PrimElt a, NumElt a) => Shaped n (Primitive a) -> a ssumAllPrimP (Shaped arr) = msumAllPrimP arr +{-# INLINE ssumAllPrim #-} ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a ssumAllPrim (Shaped arr) = msumAllPrim arr @@ -124,15 +130,19 @@ sappend = coerce mappend sscalar :: Elt a => a -> Shaped '[] a sscalar x = Shaped (mscalar x) +{-# INLINEABLE sfromVectorP #-} sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) +{-# INLINEABLE sfromVector #-} sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a sfromVector sh v = sfromPrimitive (sfromVectorP sh v) +{-# INLINEABLE stoVectorP #-} stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a stoVectorP = coerce mtoVectorP +{-# INLINEABLE stoVector #-} stoVector :: PrimElt a => Shaped sh a -> VS.Vector a stoVector = coerce mtoVector @@ -246,21 +256,20 @@ sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shape sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = - case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff - n@SNat -> sreshape (n :$$ ZSS) arr +sflatten arr = sreshape (shsProduct (sshape arr) :$$ ZSS) arr siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a siota sn = Shaped (miota sn) -- | Throws if the array is empty. sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) +sminIndexPrim (Shaped arr) = ixsFromIxX (mminIndexPrim arr) -- | Throws if the array is empty. smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) +smaxIndexPrim (Shaped arr) = ixsFromIxX (mmaxIndexPrim arr) +{-# INLINEABLE sdot1Inner #-} sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) @@ -272,6 +281,7 @@ sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) _ -> error "unreachable" +{-# INLINE sdot #-} -- | This has a temporary, suboptimal implementation in terms of 'mflatten'. -- Prefer 'sdot1Inner' if applicable. sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a |
