diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 56 |
1 files changed, 28 insertions, 28 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 679e73b..373e62d 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -402,7 +402,7 @@ instance Storable a => Elt (Primitive a) where mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. @@ -441,16 +441,16 @@ instance Storable a => Elt (Primitive a) where mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' sh2 = shxCast' sh1 ssh2 - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) + (X.transpose (ssxFromShX sh) perm arr) mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) mconcat l@(M_Primitive (_ :$% sh) _ :| _) = - let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) - in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result + let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) + in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result mrnf (M_Primitive sh a) = rnf sh `seq` rnf a @@ -467,7 +467,7 @@ instance Storable a => Elt (Primitive a) where :: forall sh' sh s. IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr + let arrsh = X.shape (ssxFromShX sh') arr offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) @@ -554,7 +554,7 @@ instance Elt a => Elt (Mixed sh' a) where -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) + = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a mindex (M_Nest _ arr) i = mindexPartial arr i @@ -583,7 +583,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -600,7 +600,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -618,7 +618,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) in fmap (M_Nest sh2) result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) f' sshT @@ -640,7 +640,7 @@ instance Elt a => Elt (Mixed sh' a) where -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') @@ -651,14 +651,14 @@ instance Elt a => Elt (Mixed sh' a) where mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) mconcat l@(M_Nest sh1 _ :| _) = let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result + in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -685,7 +685,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsUnsafeNew sh example | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh'))) where sh' = mshape example @@ -743,14 +743,14 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) + in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr +msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -758,7 +758,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 where sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 - ssh = ssxFromShape sh + ssh = ssxFromShX sh snm :: SMayNat () SNat (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () @@ -834,7 +834,7 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) mrerankP ssh sh2 f (M_Primitive sh arr) = let sh1 = shxDropSSX sh ssh in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) + (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) @@ -849,8 +849,8 @@ mrerank ssh sh2 f (toPrimitive -> arr) = mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') + let ssh' = ssxFromShX (mshape arr) + in mlift (ssxAppend (ssxFromShX sh) ssh') (\(sshT :: StaticShX shT) -> case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of Refl -> X.replicate sh (ssxAppend ssh' sshT)) @@ -866,18 +866,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr + in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') + mlift (ssxFromShX sh') + (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh') arr mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a @@ -889,12 +889,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) -- | Throws if the array is empty. mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) + ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr)) -- | Throws if the array is empty. mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) + ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a @@ -904,7 +904,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi = case sh1 of _ :$% _ | sh1 == sh2 - , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> + , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) -> fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" ZSX -> error "unreachable" |