aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r--src/Data/Array/Nested/Internal.hs115
1 files changed, 47 insertions, 68 deletions
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 7d98975..a3c2b6d 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -1390,28 +1390,19 @@ sindexPartial sarr@(Shaped arr) idx =
sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
-{-
-- | See the documentation of 'mlift'.
-slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
- => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+slift :: forall sh1 sh2 a. Elt a
+ => ShS sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
-slift f (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh2)
- = Shaped (mlift f arr)
+slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr)
-ssumOuter1P :: forall sh n a.
- (Storable a, Num a, KnownNat n, KnownShape sh)
+ssumOuter1P :: forall sh n a. (Storable a, Num a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped
- . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a))
- . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh))
- . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a)
- $ arr
-
-ssumOuter1 :: forall sh n a.
- (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh)
+ssumOuter1P (Shaped (M_Primitive (SKnown sn :$% sh) arr)) =
+ Shaped (M_Primitive sh (X.sumOuter (SKnown sn :!% ZKX) (X.staticShapeFrom sh) arr))
+
+ssumOuter1 :: forall sh n a. (Storable a, Num a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
@@ -1457,30 +1448,27 @@ shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
= shIndex p pT i sh rest
shIndex _ _ _ ZSS _ = error "Index into empty shape"
-stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
-stranspose perm (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- , Refl <- lemRankMapJust (Proxy @sh)
- , Refl <- lemCommMapJustTakeLen perm (knownShape @sh)
- , Refl <- lemCommMapJustDropLen perm (knownShape @sh)
- , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh))
- , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh))
+stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
+ => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
+stranspose perm sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ , Refl <- lemCommMapJustTakeLen perm (sshape sarr)
+ , Refl <- lemCommMapJustDropLen perm (sshape sarr)
+ , Refl <- lemCommMapJustPermute perm (shTakeLen perm (sshape sarr))
+ , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
= Shaped (mtranspose perm arr)
-sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)
- => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
-sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend = coerce mappend
sscalar :: Elt a => a -> Shaped '[] a
sscalar x = Shaped (mscalar x)
-sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a)
-sfromVectorP v
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v)
+sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
-sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a
-sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v)
+sfromVector :: (Storable a, PrimElt a) => ShS sh -> VS.Vector a -> Shaped sh a
+sfromVector sh v = coerce fromPrimitive (sfromVectorP sh v)
stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
stoVectorP = coerce mtoVectorP
@@ -1488,14 +1476,11 @@ stoVectorP = coerce mtoVectorP
stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a
stoVector = coerce mtoVector
-sfromList1 :: forall n sh a. (KnownNat n, KnownShape sh, Elt a)
- => NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromList1 l
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mfromList1 (coerce l))
+sfromList1 :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromList1 sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromList1 (coerce l))
-sfromList :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a
-sfromList = Shaped . mfromList1 . fmap mscalar
+sfromList :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList
stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
stoList (Shaped arr) = coerce (mtoList1 arr)
@@ -1506,37 +1491,31 @@ stoList1 = map sunScalar . stoList
sunScalar :: Elt a => Shaped '[] a -> a
sunScalar arr = sindex arr ZIS
-sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a)
-sconstantP x
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mconstantP (shCvtSX (knownShape @sh)) x)
+sconstantP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sconstantP sh x = Shaped (mconstantP (shCvtSX sh) x)
-sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a)
- => a -> Shaped sh a
-sconstant x = coerce fromPrimitive (sconstantP @sh x)
+sconstant :: (Storable a, PrimElt a) => ShS sh -> a -> Shaped sh a
+sconstant sh x = coerce fromPrimitive (sconstantP sh x)
-sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
-sslice i n@SNat = slift $ \_ -> X.slice i n
+sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n@SNat arr =
+ let _ :$$ sh = sshape arr
+ in slift (n :$$ sh) (\_ -> X.slice i n) arr
-srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a
-srev1 = slift $ \_ -> X.rev1
+srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
+srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
-sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a)
- => ShS sh' -> Shaped sh a -> Shaped sh' a
-sreshape sh' (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- , Dict <- lemKnownMapJust (Proxy @sh')
- = Shaped (mreshape (shCvtSX sh') arr)
+sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
-sasXArrayPrimP :: Shaped sh (Primitive a) -> XArray (MapJust sh) a
-sasXArrayPrimP (Shaped arr) = masXArrayPrimP arr
+sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr)
-sasXArrayPrim :: PrimElt a => Shaped sh a -> XArray (MapJust sh) a
-sasXArrayPrim (Shaped arr) = masXArrayPrim arr
+sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
-sfromXArrayPrimP :: XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP = Shaped . mfromXArrayPrimP
+sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr)
-sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim = Shaped . mfromXArrayPrim
--}
+sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr)