aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Mixed.hs56
1 files changed, 28 insertions, 28 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 679e73b..373e62d 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -402,7 +402,7 @@ instance Storable a => Elt (Primitive a) where
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuter l@(arr1 :| _) =
let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l)))
mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
mlift :: forall sh1 sh2.
@@ -441,16 +441,16 @@ instance Storable a => Elt (Primitive a) where
mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) =
let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
sh2 = shxCast' sh1 ssh2
- in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr)
mtranspose perm (M_Primitive sh arr) =
M_Primitive (shxPermutePrefix perm sh)
- (X.transpose (ssxFromShape sh) perm arr)
+ (X.transpose (ssxFromShX sh) perm arr)
mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a)
mconcat l@(M_Primitive (_ :$% sh) _ :| _) =
- let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l)
- in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result
+ let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l)
+ in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result
mrnf (M_Primitive sh a) = rnf sh `seq` rnf a
@@ -467,7 +467,7 @@ instance Storable a => Elt (Primitive a) where
:: forall sh' sh s.
IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (ssxFromShape sh') arr
+ let arrsh = X.shape (ssxFromShX sh') arr
offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
@@ -554,7 +554,7 @@ instance Elt a => Elt (Mixed sh' a) where
-- moverlongShape method, a prefix of which is mshape.
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) i = mindexPartial arr i
@@ -583,7 +583,7 @@ instance Elt a => Elt (Mixed sh' a) where
(sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
@@ -600,7 +600,7 @@ instance Elt a => Elt (Mixed sh' a) where
(sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
@@ -618,7 +618,7 @@ instance Elt a => Elt (Mixed sh' a) where
(sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result))
in fmap (M_Nest sh2) result
where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+ ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1)))
f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b)
f' sshT
@@ -640,7 +640,7 @@ instance Elt a => Elt (Mixed sh' a) where
-> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
| let sh' = shxDropSh @sh @sh' (mshape arr) sh
- , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh')
, Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
, Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
, Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
@@ -651,14 +651,14 @@ instance Elt a => Elt (Mixed sh' a) where
mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
mconcat l@(M_Nest sh1 _ :| _) =
let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l)
- in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result
+ in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result
mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr
type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr)))))
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
@@ -685,7 +685,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
mvecsUnsafeNew sh example
| shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh')))
where
sh' = mshape example
@@ -743,14 +743,14 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1P (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr)
msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
-msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr
+msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr
mappend :: forall n m sh a. Elt a
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
@@ -758,7 +758,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
where
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
- ssh = ssxFromShape sh
+ ssh = ssxFromShX sh
snm :: SMayNat () SNat (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
@@ -834,7 +834,7 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
mrerankP ssh sh2 f (M_Primitive sh arr) =
let sh1 = shxDropSSX sh ssh
in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
+ (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2)
(\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
arr)
@@ -849,8 +849,8 @@ mrerank ssh sh2 f (toPrimitive -> arr) =
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
mreplicate sh arr =
- let ssh' = ssxFromShape (mshape arr)
- in mlift (ssxAppend (ssxFromShape sh) ssh')
+ let ssh' = ssxFromShX (mshape arr)
+ in mlift (ssxAppend (ssxFromShX sh) ssh')
(\(sshT :: StaticShX shT) ->
case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
Refl -> X.replicate sh (ssxAppend ssh' sshT))
@@ -866,18 +866,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
mslice i n arr =
let _ :$% sh = mshape arr
- in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
+ in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr
msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
+msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
+mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr
mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
mreshape sh' arr =
- mlift (ssxFromShape sh')
- (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
+ mlift (ssxFromShX sh')
+ (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh')
arr
mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a
@@ -889,12 +889,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-- | Throws if the array is empty.
mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
+ ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr))
-- | Throws if the array is empty.
mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
+ ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr))
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
@@ -904,7 +904,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
= case sh1 of
_ :$% _
| sh1 == sh2
- , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
+ , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) ->
fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
| otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")"
ZSX -> error "unreachable"