aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 11:58:40 +0200
commita65306ba5d80891b20ac86fa3a3242f9497751e6 (patch)
tree834af370556a46bbeca807a92c31bef098b47a89
parentd8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff)
Refactor Mixed (modules, regular function names)
-rw-r--r--ox-arrays.cabal13
-rw-r--r--src/Data/Array/Mixed.hs757
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs (renamed from src/Data/Array/Nested/Internal/Arith.hs)6
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs (renamed from src/Data/Array/Nested/Internal/Arith/Foreign.hs)4
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists.hs (renamed from src/Data/Array/Nested/Internal/Arith/Lists.hs)4
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs (renamed from src/Data/Array/Nested/Internal/Arith/Lists/TH.hs)2
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs47
-rw-r--r--src/Data/Array/Mixed/Permutation.hs252
-rw-r--r--src/Data/Array/Mixed/Shape.hs455
-rw-r--r--src/Data/Array/Mixed/Types.hs110
-rw-r--r--src/Data/Array/Nested.hs10
-rw-r--r--src/Data/Array/Nested/Internal.hs326
-rw-r--r--test/Gen.hs3
-rw-r--r--test/Util.hs2
14 files changed, 1099 insertions, 892 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index e53815e..2356e72 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -10,12 +10,16 @@ extra-source-files: cbits/arith_lists.h
library
exposed-modules:
Data.Array.Mixed
+ Data.Array.Mixed.Internal.Arith
+ Data.Array.Mixed.Internal.Arith.Foreign
+ Data.Array.Mixed.Internal.Arith.Lists
+ Data.Array.Mixed.Internal.Arith.Lists.TH
+ Data.Array.Mixed.Lemmas
+ Data.Array.Mixed.Permutation
+ Data.Array.Mixed.Shape
+ Data.Array.Mixed.Types
Data.Array.Nested
Data.Array.Nested.Internal
- Data.Array.Nested.Internal.Arith
- Data.Array.Nested.Internal.Arith.Foreign
- Data.Array.Nested.Internal.Arith.Lists
- Data.Array.Nested.Internal.Arith.Lists.TH
build-depends:
base >=4.18 && <4.20,
deepseq,
@@ -38,6 +42,7 @@ test-suite test
other-modules:
Gen
Tests.C
+ Tests.Mixed
Util
build-depends:
ox-arrays,
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 4ae89a1..0100ec8 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -30,300 +30,23 @@ module Data.Array.Mixed where
import Control.DeepSeq (NFData(..))
import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
-import Data.Bifunctor (first)
import Data.Coerce
-import qualified Data.Foldable as Foldable
-import Data.Functor.Const
import Data.Kind
-import Data.List (sort)
-import Data.Monoid (Sum(..))
import Data.Proxy
-import Data.Type.Bool
import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
-import GHC.IsList (IsList)
-import qualified GHC.IsList as IsList
-import GHC.TypeError
import GHC.TypeLits
-import qualified GHC.TypeNats as TypeNats
-import Unsafe.Coerce (unsafeCoerce)
-import Data.Array.Nested.Internal.Arith
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
--- | Evidence for the constraint @c a@.
-data Dict c a where
- Dict :: c a => Dict c a
-
-fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
-
-pattern SZ :: () => (n ~ 0) => SNat n
-pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
- where SZ = SNat
-
-pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
-pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
- where SS = snatSucc
-
-{-# COMPLETE SZ, SS #-}
-
-snatSucc :: SNat n -> SNat (n + 1)
-snatSucc SNat = SNat
-
-data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
-snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
-snatPred snp1 =
- withKnownNat snp1 $
- case cmpNat (Proxy @1) (Proxy @np1) of
- LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
- GTI -> Nothing
-
-
--- | Type-level list append.
-type family l1 ++ l2 where
- '[] ++ l2 = l2
- (x : xs) ++ l2 = x : xs ++ l2
-
-lemAppNil :: l ++ '[] :~: l
-lemAppNil = unsafeCoerce Refl
-
-lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
-lemAppAssoc _ _ _ = unsafeCoerce Refl
-
-type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
-
-
-type role ListX nominal representational
-type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
-data ListX sh f where
- ZX :: ListX '[] f
- (::%) :: f n -> ListX sh f -> ListX (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
-infixr 3 ::%
-
-instance (forall n. Show (f n)) => Show (ListX sh f) where
- showsPrec _ = showListX shows
-
-instance (forall n. NFData (f n)) => NFData (ListX sh f) where
- rnf ZX = ()
- rnf (x ::% l) = rnf x `seq` rnf l
-
-data UnconsListXRes f sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
-unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
-unconsListX (i ::% shl') = Just (UnconsListXRes shl' i)
-unconsListX ZX = Nothing
-
-fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
-fmapListX _ ZX = ZX
-fmapListX f (x ::% xs) = f x ::% fmapListX f xs
-
-foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
-foldListX _ ZX = mempty
-foldListX f (x ::% xs) = f x <> foldListX f xs
-
-lengthListX :: ListX sh f -> Int
-lengthListX = getSum . foldListX (\_ -> Sum 1)
-
-snatLengthListX :: ListX sh f -> SNat (Rank sh)
-snatLengthListX ZX = SNat
-snatLengthListX (_ ::% l) | SNat <- snatLengthListX l = SNat
-
-showListX :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
-showListX f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListX sh' f -> ShowS
- go _ ZX = id
- go prefix (x ::% xs) = showString prefix . f x . go "," xs
-
-listXToList :: ListX sh' (Const i) -> [i]
-listXToList ZX = []
-listXToList (Const i ::% is) = i : listXToList is
-
-
-type role IxX nominal representational
-type IxX :: [Maybe Nat] -> Type -> Type
-newtype IxX sh i = IxX (ListX sh (Const i))
- deriving (Eq, Ord, Generic)
-
-pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
-pattern ZIX = IxX ZX
-
-pattern (:.%)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => i -> IxX sh i -> IxX sh1 i
-pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
- where i :.% IxX shl = IxX (Const i ::% shl)
-infixr 3 :.%
-
-{-# COMPLETE ZIX, (:.%) #-}
-
-type IIxX sh = IxX sh Int
-
-instance Show i => Show (IxX sh i) where
- showsPrec _ (IxX l) = showListX (\(Const i) -> shows i) l
-
-instance Functor (IxX sh) where
- fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l)
-
-instance Foldable (IxX sh) where
- foldMap f (IxX l) = foldListX (f . getConst) l
-
-instance NFData i => NFData (IxX sh i)
-
-
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
-deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
-deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
-deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
-
-instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
- rnf (SUnknown i) = rnf i
- rnf (SKnown x) = rnf x
-
-fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r
-fromSMayNat f _ (SUnknown i) = f i
-fromSMayNat _ g (SKnown s) = g s
-
-fromSMayNat' :: SMayNat Int SNat n -> Int
-fromSMayNat' = fromSMayNat id fromSNat'
-
-type role ShX nominal representational
-type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
- deriving (Eq, Ord, Generic)
-
-pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
-pattern ZSX = ShX ZX
-
-pattern (:$%)
- :: forall {sh1} {i}.
- forall n sh. (n : sh ~ sh1)
- => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
-pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i))
- where i :$% ShX shl = ShX (i ::% shl)
-infixr 3 :$%
-
-{-# COMPLETE ZSX, (:$%) #-}
-
-type IShX sh = ShX sh Int
-
-instance Show i => Show (ShX sh i) where
- showsPrec _ (ShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l
-
-instance Functor (ShX sh) where
- fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l)
-
-instance NFData i => NFData (ShX sh i) where
- rnf (ShX ZX) = ()
- rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
- rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
-
-lengthShX :: ShX sh i -> Int
-lengthShX (ShX l) = lengthListX l
-
-shXToList :: IShX sh -> [Int]
-shXToList ZSX = []
-shXToList (smn :$% sh) = fromSMayNat' smn : shXToList sh
-
-
--- | The part of a shape that is statically known.
-type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
- deriving (Eq, Ord)
-
-pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
-pattern ZKX = StaticShX ZX
-
-pattern (:!%)
- :: forall {sh1}.
- forall n sh. (n : sh ~ sh1)
- => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
-pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i))
- where i :!% StaticShX shl = StaticShX (i ::% shl)
-infixr 3 :!%
-
-{-# COMPLETE ZKX, (:!%) #-}
-
-instance Show (StaticShX sh) where
- showsPrec _ (StaticShX l) = showListX (fromSMayNat shows (shows . fromSNat)) l
-
-lengthStaticShX :: StaticShX sh -> Int
-lengthStaticShX (StaticShX l) = lengthListX l
-
-geqStaticShX :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
-geqStaticShX ZKX ZKX = Just Refl
-geqStaticShX (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- geqStaticShX sh sh'
- = Just Refl
-geqStaticShX (SUnknown () :!% sh) (SUnknown () :!% sh')
- | Just Refl <- geqStaticShX sh sh'
- = Just Refl
-geqStaticShX _ _ = Nothing
-
-
--- | Evidence for the static part of a shape. This pops up only when you are
--- polymorphic in the element type of an array.
-type KnownShX :: [Maybe Nat] -> Constraint
-class KnownShX sh where knownShX :: StaticShX sh
-instance KnownShX '[] where knownShX = ZKX
-instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
-instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
-
-
--- | Very untyped: only length is checked (at runtime).
-instance KnownShX sh => IsList (ListX sh (Const i)) where
- type Item (ListX sh (Const i)) = i
- fromList topl = go (knownShX @sh) topl
- where
- go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
- go ZKX [] = ZX
- go (_ :!% sh) (i : is) = Const i ::% go sh is
- go _ _ = error $ "IsList(ListX): Mismatched list length (type says "
- ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = listXToList
-
--- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
-instance KnownShX sh => IsList (IxX sh i) where
- type Item (IxX sh i) = i
- fromList = IxX . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length and known dimensions are checked (at runtime).
-instance KnownShX sh => IsList (ShX sh Int) where
- type Item (ShX sh Int) = Int
- fromList topl = ShX (go (knownShX @sh) topl)
- where
- go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat)
- go ZKX [] = ZX
- go (SKnown sn :!% sh) (i : is)
- | i == fromSNat' sn = SKnown sn ::% go sh is
- | otherwise = error $ "IsList(ShX): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is
- go _ _ = error $ "IsList(ShX): Mismatched list length (type says "
- ++ show (lengthStaticShX (knownShX @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = shXToList
-
-
-type family Rank sh where
- Rank '[] = 0
- Rank (_ : sh) = Rank sh + 1
-
type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (Rank sh) a)
deriving (Show, Eq, Generic)
@@ -333,180 +56,6 @@ deriving instance (Ord a, Storable a) => Ord (XArray '[] a)
instance NFData a => NFData (XArray sh a)
-zeroIxX :: StaticShX sh -> IIxX sh
-zeroIxX ZKX = ZIX
-zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh
-
-zeroIxX' :: IShX sh -> IIxX sh
-zeroIxX' ZSX = ZIX
-zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh
-
--- This is a weird operation, so it has a long name
-completeShXzeros :: StaticShX sh -> IShX sh
-completeShXzeros ZKX = ZSX
-completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh
-completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh
-
-listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
-listxAppend ZX idx' = idx'
-listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
-
-ixAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
-ixAppend = coerce (listxAppend @_ @(Const i))
-
-shAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shAppend = coerce (listxAppend @_ @(SMayNat i SNat))
-
-listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
-listxDrop long ZX = long
-listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
-
-ixDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
-ixDrop = coerce (listxDrop @(Const i) @(Const i))
-
-shDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
-shDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
-
-shDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
-shDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
-
-shDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
-shDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
-
-shTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
-shTakeSSX _ = flip go
- where
- go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
- go ZKX _ = ZSX
- go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
-
-ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
-ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
-
--- TODO: generalise all these things to arbitrary @i@
-shTail :: IShX (n : sh) -> IShX sh
-shTail (_ :$% sh) = sh
-
-ssxTail :: StaticShX (n : sh) -> StaticShX sh
-ssxTail (_ :!% ssh) = ssh
-
-shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
-shAppSplit _ ZKX idx = (ZSX, idx)
-shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx)
-
-ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
-ssxAppend ZKX sh' = sh'
-ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
-
-shapeSize :: IShX sh -> Int
-shapeSize ZSX = 1
-shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh
-
--- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
-ssxToShape' ZKX = Just ZSX
-ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh
-ssxToShape' (SUnknown _ :!% _) = Nothing
-
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerce Refl
-
-ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
-ssxReplicate SZ = ZKX
-ssxReplicate (SS (n :: SNat n'))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
- = SUnknown () :!% ssxReplicate n
-
-fromLinearIdx :: IShX sh -> Int -> IIxX sh
-fromLinearIdx = \sh i -> case go sh i of
- (idx, 0) -> idx
- _ -> error $ "fromLinearIdx: out of range (" ++ show i ++
- " in array of shape " ++ show sh ++ ")"
- where
- -- returns (index in subarray, remaining index in enclosing array)
- go :: IShX sh -> Int -> (IIxX sh, Int)
- go ZSX i = (ZIX, i)
- go (n :$% sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSMayNat' n
- in (locali :.% idx, upi)
-
-toLinearIdx :: IShX sh -> IIxX sh -> Int
-toLinearIdx = \sh i -> fst (go sh i)
- where
- -- returns (index in subarray, size of subarray)
- go :: IShX sh -> IIxX sh -> (Int, Int)
- go ZSX ZIX = (0, 1)
- go (n :$% sh) (i :.% ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSMayNat' n * sz)
-
-enumShape :: IShX sh -> [IIxX sh]
-enumShape = \sh -> go sh id []
- where
- go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZSX f = (f ZIX :)
- go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
-
-shapeLshape :: IShX sh -> S.ShapeL
-shapeLshape ZSX = []
-shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh
-
-ssxLength :: StaticShX sh -> Int
-ssxLength ZKX = 0
-ssxLength (_ :!% ssh) = 1 + ssxLength ssh
-
-ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKX = []
-ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
-
-type Flatten sh = Flatten' 1 sh
-
-type family Flatten' acc sh where
- Flatten' acc '[] = Just acc
- Flatten' acc (Nothing : sh) = Nothing
- Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
-
-flattenSSX :: StaticShX sh -> SMayNat () SNat (Flatten sh)
-flattenSSX = go (SNat @1)
- where
- go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
- go acc ZKX = SKnown acc
- go _ (SUnknown () :!% _) = SUnknown ()
- go acc (SKnown sn :!% sh) = go (mulSNat acc sn) sh
-
-flattenSh :: IShX sh -> SMayNat Int SNat (Flatten sh)
-flattenSh = go (SNat @1)
- where
- go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
- go acc ZSX = SKnown acc
- go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
- go acc (SKnown sn :$% sh) = go (mulSNat acc sn) sh
-
- goUnknown :: Int -> IShX sh -> Int
- goUnknown acc ZSX = acc
- goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
- goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
-
-staticShapeFrom :: IShX sh -> StaticShX sh
-staticShapeFrom ZSX = ZKX
-staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh
-
-lemRankApp :: StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
-lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
-
-lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
-lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-
-lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
-lemKnownNatRank ZSX = Dict
-lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
-
-lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
-lemKnownNatRankSSX ZKX = Dict
-lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
@@ -519,7 +68,7 @@ shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownNatRank sh
- = XArray (S.fromVector (shapeLshape sh) v)
+ = XArray (S.fromVector (shxToList sh) v)
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr
@@ -527,23 +76,18 @@ toVector (XArray arr) = S.toVector arr
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar
-eqShX :: IShX sh1 -> IShX sh2 -> Bool
-eqShX ZSX ZSX = True
-eqShX (n :$% sh1) (m :$% sh2) = fromSMayNat' n == fromSMayNat' m && eqShX sh1 sh2
-eqShX _ _ = False
-
-- | Will throw if the array does not have the casted-to shape.
cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> StaticShX sh'
-> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
cast ssh1 sh2 ssh' (XArray arr)
| Refl <- lemRankApp ssh1 ssh'
- , Refl <- lemRankApp (staticShapeFrom sh2) ssh'
+ , Refl <- lemRankApp (ssxFromShape sh2) ssh'
= let arrsh :: IShX sh1
- (arrsh, _) = shAppSplit (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
- in if eqShX arrsh sh2
- then XArray arr
- else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
+ (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr))
+ in case shxEqual arrsh sh2 of
+ Just _ -> XArray arr
+ Nothing -> error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")"
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
@@ -551,24 +95,24 @@ unScalar (XArray a) = S.unScalar a
replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a
replicate sh ssh' (XArray arr)
| Dict <- lemKnownNatRankSSX ssh'
- , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh) ssh')
- , Refl <- lemRankApp (staticShapeFrom sh) ssh'
- = XArray (S.stretch (shapeLshape sh ++ S.shapeL arr) $
- S.reshape (map (const 1) (shapeLshape sh) ++ S.shapeL arr) $
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh) ssh')
+ , Refl <- lemRankApp (ssxFromShape sh) ssh'
+ = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $
+ S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) $
arr)
replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
replicateScal sh x
| Dict <- lemKnownNatRank sh
- = XArray (S.constant (shapeLshape sh) x)
+ = XArray (S.constant (shxToList sh) x)
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
-generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
+generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh)
-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--- XArray . S.fromVector (shapeLshape sh)
--- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
+-- XArray . S.fromVector (shxShapeL sh)
+-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh)
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
@@ -580,24 +124,6 @@ index xarr i
= let XArray arr' = indexPartial xarr i :: XArray '[] a
in S.unScalar arr'
-type family AddMaybe n m where
- AddMaybe Nothing _ = Nothing
- AddMaybe (Just _) Nothing = Nothing
- AddMaybe (Just n) (Just m) = Just (n + m)
-
--- This should be a function in base
-plusSNat :: SNat n -> SNat m -> SNat (n + m)
-plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce
-
--- This should be a function in base
-mulSNat :: SNat n -> SNat m -> SNat (n * m)
-mulSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n * TypeNats.fromSNat m) unsafeCoerce
-
-smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
-smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
-smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
-smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m)
-
append :: forall n m sh a. Storable a
=> StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append ssh (XArray a) (XArray b)
@@ -639,9 +165,9 @@ rerank :: forall sh sh1 sh2 a b.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f xarr@(XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
- in if any (== 0) (shapeLshape sh)
- then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
+ = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr)
+ in if any (== 0) (shxToList sh)
+ then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
() | Dict <- lemKnownNatRankSSX ssh
, Dict <- lemKnownNatRankSSX ssh2
@@ -666,9 +192,9 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)
- = let (sh, _) = shAppSplit (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
- in if any (== 0) (shapeLshape sh)
- then XArray (S.fromList (shapeLshape (shAppend sh (completeShXzeros ssh2))) [])
+ = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1)
+ in if any (== 0) (shxToList sh)
+ then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) [])
else case () of
() | Dict <- lemKnownNatRankSSX ssh
, Dict <- lemKnownNatRankSSX ssh2
@@ -678,211 +204,14 @@ rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2)
(\a b -> let XArray r = f (XArray a) (XArray b) in r)
arr1 arr2)
-type family Elem x l where
- Elem x '[] = 'False
- Elem x (x : _) = 'True
- Elem x (_ : ys) = Elem x ys
-
-type family AllElem' as bs where
- AllElem' '[] bs = 'True
- AllElem' (a : as) bs = Elem a bs && AllElem' as bs
-
-type AllElem as bs = Assert (AllElem' as bs)
- (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
-
-type family Count i n where
- Count n n = '[]
- Count i n = i : Count (i + 1) n
-
-type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
-
-type family Index i sh where
- Index 0 (n : sh) = n
- Index i (_ : sh) = Index (i - 1) sh
-
-type family Permute is sh where
- Permute '[] sh = '[]
- Permute (i : is) sh = Index i sh : Permute is sh
-
-type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
-
-data HList f list where
- HNil :: HList f '[]
- HCons :: f a -> HList f l -> HList f (a : l)
-infixr 5 `HCons`
-deriving instance (forall a. Show (f a)) => Show (HList f list)
-deriving instance (forall a. Eq (f a)) => Eq (HList f list)
-
-foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
-foldHList _ HNil = mempty
-foldHList f (x `HCons` l) = f x <> foldHList f l
-
-snatLengthHList :: HList f list -> SNat (Rank list)
-snatLengthHList HNil = SNat
-snatLengthHList (_ `HCons` l) | SNat <- snatLengthHList l = SNat
-
-permFromList :: [Int] -> (forall list. HList SNat list -> r) -> r
-permFromList [] k = k HNil
-permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
- Just sn -> permFromList xs $ \list -> k (sn `HCons` list)
- Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
-
-permToList :: HList SNat list -> [Int]
-permToList = foldHList (pure . fromSNat')
-
-type family TakeLen ref l where
- TakeLen '[] l = '[]
- TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
-
-type family DropLen ref l where
- DropLen '[] l = l
- DropLen (_ : ref) (_ : xs) = DropLen ref xs
-
-lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
-lemRankPermute _ HNil = Refl
-lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl
-
-lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
- => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
-lemRankDropLen ZKX HNil = Refl
-lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
-lemRankDropLen (_ :!% _) HNil = Refl
-lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0"
-
-lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
-lemIndexSucc _ _ _ = unsafeCoerce Refl
-
-listxTakeLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (TakeLen is sh) f
-listxTakeLen HNil _ = ZX
-listxTakeLen (_ `HCons` is) (n ::% sh) = n ::% listxTakeLen is sh
-listxTakeLen (_ `HCons` _) ZX = error "Permutation longer than shape"
-
-listxDropLen :: forall f is sh. HList SNat is -> ListX sh f -> ListX (DropLen is sh) f
-listxDropLen HNil sh = sh
-listxDropLen (_ `HCons` is) (_ ::% sh) = listxDropLen is sh
-listxDropLen (_ `HCons` _) ZX = error "Permutation longer than shape"
-
-listxPermute :: forall f is sh. HList SNat is -> ListX sh f -> ListX (Permute is sh) f
-listxPermute HNil _ = ZX
-listxPermute (i `HCons` (is :: HList SNat is')) (sh :: ListX sh f) = listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
-
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
-listxIndex _ _ SZ (n ::% _) rest = n ::% rest
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh rest
-listxIndex _ _ _ ZX _ = error "Index into empty shape"
-
-listxPermutePrefix :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f
-listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
-
-ssxTakeLen :: forall is sh. HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
-
-ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
-
-ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
-
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
-ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
-
-ssxPermutePrefix :: HList SNat is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
-
-shPermutePrefix :: HList SNat is -> IShX sh -> IShX (PermutePrefix is sh)
-shPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
-
--- TODO: test this thing more properly
-invertPermutation :: HList SNat is
- -> (forall is'.
- Permutation is'
- => HList SNat is'
- -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
- -> r)
- -> r
-invertPermutation = \perm k ->
- genPerm perm $ \(invperm :: HList SNat is') ->
- let sn = snatLengthHList invperm
- in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
- (Just Refl, Just Refl) ->
- k invperm
- (\ssh -> case provePermInverse perm invperm ssh of
- Just eq -> eq
- Nothing -> error $ "invertPermutation: did not generate inverse? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm)
- _ -> error $ "invertPermutation: did not generate permutation? perm = " ++ show perm
- ++ " ; invperm = " ++ show invperm
- where
- genPerm :: HList SNat is -> (forall is'. HList SNat is' -> r) -> r
- genPerm perm =
- let permList = foldHList (pure . fromSNat) perm
- in toHList $ map snd (sort (zip permList [0..]))
- where
- toHList :: [Natural] -> (forall is'. HList SNat is' -> r) -> r
- toHList [] k = k HNil
- toHList (n : ns) k = toHList ns $ \l -> TypeNats.withSomeSNat n $ \sn -> k (HCons sn l)
-
- lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
- lemElemCount _ _ = unsafeCoerce Refl
-
- lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
- lemCount _ _ = unsafeCoerce Refl
-
- lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
- lemElem _ _ = unsafeCoerce Refl
-
- provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> HList SNat is'
- -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
- provePerm1 _ _ HNil = Just (Refl)
- provePerm1 p rtop@SNat (HCons sn@SNat perm)
- | Just Refl <- provePerm1 p rtop perm
- = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
- (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
- _ -> Nothing
- | otherwise
- = Nothing
-
- provePerm2 :: SNat i -> SNat n -> HList SNat is' -> Maybe (AllElem' (Count i n) is' :~: True)
- provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
- case cmpNat i n of
- EQI -> Just Refl
- LTI | Refl <- lemCount i n
- , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
- -> checkElem i perm
- | otherwise -> Nothing
- GTI -> error "unreachable"
- where
- checkElem :: SNat i -> HList SNat is' -> Maybe (Elem i is' :~: True)
- checkElem _ HNil = Nothing
- checkElem i@SNat (HCons k@SNat perm :: HList SNat is') =
- case sameNat i k of
- Just Refl -> Just Refl
- Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
- | otherwise -> Nothing
-
- provePermInverse :: HList SNat is -> HList SNat is' -> StaticShX sh -> Maybe (Permute is' (Permute is sh) :~: sh)
- provePermInverse perm perminv ssh = geqStaticShX (ssxPermute perminv (ssxPermute perm ssh)) ssh
-
-applyPermX :: forall f is sh. HList SNat is -> ListX sh f -> ListX (PermutePrefix is sh) f
-applyPermX perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
-
-applyPermIxX :: forall i is sh. HList SNat is -> IxX sh i -> IxX (PermutePrefix is sh) i
-applyPermIxX = coerce (applyPermX @(Const i))
-
-applyPermShX :: forall i is sh. HList SNat is -> ShX sh i -> ShX (PermutePrefix is sh) i
-applyPermShX = coerce (applyPermX @(SMayNat i SNat))
-
-class KnownNatList l where makeNatList :: HList SNat l
-instance KnownNatList '[] where makeNatList = HNil
-instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList
+class KnownNatList l where makeNatList :: Perm l
+instance KnownNatList '[] where makeNatList = PNil
+instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `PCons` makeNatList
-- | The list argument gives indices into the original dimension list.
-transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh)
+transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh)
=> StaticShX sh
- -> HList SNat is
+ -> Perm is
-> XArray sh a
-> XArray (PermutePrefix is sh) a
transpose ssh perm (XArray arr)
@@ -890,7 +219,7 @@ transpose ssh perm (XArray arr)
, Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh)
, Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
, Refl <- lemRankDropLen ssh perm
- = XArray (S.transpose (permToList perm) arr)
+ = XArray (S.transpose (permToList' perm) arr)
-- | The list argument gives indices into the original dimension list.
--
@@ -929,14 +258,14 @@ sumInner :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (_, sh') = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
- sh'F = flattenSh sh' :$% ZSX
- ssh'F = staticShapeFrom sh'F
+ = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ sh'F = shxFlatten sh' :$% ZSX
+ ssh'F = ssxFromShape sh'F
go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a
go (XArray arr')
| Refl <- lemRankApp ssh ssh'F
- , let sn = snatLengthListX (let StaticShX l = ssh in l)
+ , let sn = listxLengthSNat (let StaticShX l = ssh in l)
= XArray (numEltSum1Inner sn arr')
in go $
@@ -949,10 +278,10 @@ sumOuter :: forall sh sh' a. (Storable a, NumElt a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' arr
| Refl <- lemAppNil @sh
- = let (sh, _) = shAppSplit (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
- shF = flattenSh sh :$% ZSX
- in sumInner ssh' (staticShapeFrom shF) $
- transpose2 (staticShapeFrom shF) ssh' $
+ = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr)
+ shF = shxFlatten sh :$% ZSX
+ in sumInner ssh' (ssxFromShape shF) $
+ transpose2 (ssxFromShape shF) ssh' $
reshapePartial ssh ssh' shF $
arr
@@ -988,7 +317,7 @@ toList1 (XArray arr) = S.toList arr
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
| Dict <- lemKnownNatRank sh
- = XArray (S.constant (shapeLshape sh)
+ = XArray (S.constant (shxToList sh)
(error "Data.Array.Mixed.empty: shape was not empty"))
slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
@@ -1005,14 +334,14 @@ reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray s
reshape ssh1 sh2 (XArray arr)
| Dict <- lemKnownNatRankSSX ssh1
, Dict <- lemKnownNatRank sh2
- = XArray (S.reshape (shapeLshape sh2) arr)
+ = XArray (S.reshape (shxToList sh2) arr)
-- | Throws if the given array and the target shape do not have the same number of elements.
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
- = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthStaticShX ssh1) (S.shapeL arr)) arr)
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShape sh2) ssh')
+ = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr)
-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo).
iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a
diff --git a/src/Data/Array/Nested/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 95fcfcf..cf6820b 100644
--- a/src/Data/Array/Nested/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -6,7 +6,7 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-module Data.Array.Nested.Internal.Arith where
+module Data.Array.Mixed.Internal.Arith where
import Control.Monad (forM, guard)
import qualified Data.Array.Internal as OI
@@ -24,8 +24,8 @@ import GHC.TypeLits
import Language.Haskell.TH
import System.IO.Unsafe
-import Data.Array.Nested.Internal.Arith.Foreign
-import Data.Array.Nested.Internal.Arith.Lists
+import Data.Array.Mixed.Internal.Arith.Foreign
+import Data.Array.Mixed.Internal.Arith.Lists
liftVEltwise1 :: Storable a
diff --git a/src/Data/Array/Nested/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index ac83188..6fc7229 100644
--- a/src/Data/Array/Nested/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -1,6 +1,6 @@
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Nested.Internal.Arith.Foreign where
+module Data.Array.Mixed.Internal.Arith.Foreign where
import Control.Monad
import Data.Int
@@ -9,7 +9,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Language.Haskell.TH
-import Data.Array.Nested.Internal.Arith.Lists
+import Data.Array.Mixed.Internal.Arith.Lists
$(fmap concat . forM typesList $ \arithtype -> do
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists.hs b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
index ce2836d..a284bc1 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists.hs
@@ -1,12 +1,12 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
-module Data.Array.Nested.Internal.Arith.Lists where
+module Data.Array.Mixed.Internal.Arith.Lists where
import Data.Char
import Data.Int
import Language.Haskell.TH
-import Data.Array.Nested.Internal.Arith.Lists.TH
+import Data.Array.Mixed.Internal.Arith.Lists.TH
data ArithType = ArithType
diff --git a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
index 7142dfa..8b7d05f 100644
--- a/src/Data/Array/Nested/Internal/Arith/Lists/TH.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Lists/TH.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE TemplateHaskellQuotes #-}
-module Data.Array.Nested.Internal.Arith.Lists.TH where
+module Data.Array.Mixed.Internal.Arith.Lists.TH where
import Control.Monad
import Control.Monad.IO.Class
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
new file mode 100644
index 0000000..30ec9c0
--- /dev/null
+++ b/src/Data/Array/Mixed/Lemmas.hs
@@ -0,0 +1,47 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE DataKinds #-}
+module Data.Array.Mixed.Lemmas where
+
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+
+
+lemRankApp :: forall sh1 sh2.
+ StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
+lemRankApp ZKX _ = Refl
+lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2
+ = lem2 (Proxy @(Rank sh1T)) Proxy Proxy $
+ lem (Proxy @(Rank sh2)) (Proxy @(Rank sh1T)) (Proxy @(Rank (sh1T ++ sh2))) $
+ lemRankApp ssh1 ssh2
+ where
+ lem :: proxy a -> proxy b -> proxy c
+ -> c :~: b + a
+ -> b + a :~: c
+ lem _ _ _ Refl = Refl
+
+ lem2 :: proxy a -> proxy b -> proxy c
+ -> (a + b :~: c)
+ -> c + 1 :~: (a + 1 + b)
+ lem2 _ _ _ Refl = Refl
+
+lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
+lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this
+
+lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
+
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX ZKX = Dict
+lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
new file mode 100644
index 0000000..2710018
--- /dev/null
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -0,0 +1,252 @@
+{-# LANGUAGE ConstraintKinds #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Permutation where
+
+import Data.Coerce (coerce)
+import Data.Functor.Const
+import Data.List (sort)
+import Data.Proxy
+import Data.Type.Bool
+import Data.Type.Equality
+import Data.Type.Ord
+import GHC.TypeError
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+
+
+-- * Permutations
+
+-- | A "backward" permutation of a dimension list. The operation on the
+-- dimension list is most similar to 'Data.Vector.backpermute'; see 'Permute'
+-- for code that implements this.
+data Perm list where
+ PNil :: Perm '[]
+ PCons :: SNat a -> Perm l -> Perm (a : l)
+infixr 5 `PCons`
+deriving instance Show (Perm list)
+deriving instance Eq (Perm list)
+
+permLengthSNat :: Perm list -> SNat (Rank list)
+permLengthSNat PNil = SNat
+permLengthSNat (_ `PCons` l) | SNat <- permLengthSNat l = SNat
+
+permFromList :: [Int] -> (forall list. Perm list -> r) -> r
+permFromList [] k = k PNil
+permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case
+ Just sn -> permFromList xs $ \list -> k (sn `PCons` list)
+ Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x
+
+permToList :: Perm list -> [Natural]
+permToList PNil = mempty
+permToList (x `PCons` l) = TN.fromSNat x : permToList l
+
+permToList' :: Perm list -> [Int]
+permToList' = map fromIntegral . permToList
+
+
+-- ** Applying permutations
+
+type family Elem x l where
+ Elem x '[] = 'False
+ Elem x (x : _) = 'True
+ Elem x (_ : ys) = Elem x ys
+
+type family AllElem' as bs where
+ AllElem' '[] bs = 'True
+ AllElem' (a : as) bs = Elem a bs && AllElem' as bs
+
+type AllElem as bs = Assert (AllElem' as bs)
+ (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
+
+type family Count i n where
+ Count n n = '[]
+ Count i n = i : Count (i + 1) n
+
+type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+
+type family Index i sh where
+ Index 0 (n : sh) = n
+ Index i (_ : sh) = Index (i - 1) sh
+
+type family Permute is sh where
+ Permute '[] sh = '[]
+ Permute (i : is) sh = Index i sh : Permute is sh
+
+type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
+
+type family TakeLen ref l where
+ TakeLen '[] l = '[]
+ TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
+
+type family DropLen ref l where
+ DropLen '[] l = l
+ DropLen (_ : ref) (_ : xs) = DropLen ref xs
+
+listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f
+listxTakeLen PNil _ = ZX
+listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh
+listxTakeLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
+
+listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f
+listxDropLen PNil sh = sh
+listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh
+listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
+
+listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
+listxPermute PNil _ = ZX
+listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
+ listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
+
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
+listxIndex _ _ SZ (n ::% _) rest = n ::% rest
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listxIndex p pT i sh rest
+listxIndex _ _ _ ZX _ = error "Index into empty shape"
+
+listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
+listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
+
+ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i
+ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
+
+ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+
+ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+
+ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
+
+ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+
+shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
+shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+
+
+-- * Operations on permutations
+
+-- TODO: test this thing more properly
+permInverse :: Perm is
+ -> (forall is'.
+ IsPermutation is'
+ => Perm is'
+ -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh)
+ -> r)
+ -> r
+permInverse = \perm k ->
+ genPerm perm $ \(invperm :: Perm is') ->
+ let sn = permLengthSNat invperm
+ in case (provePerm1 (Proxy @is') sn invperm, provePerm2 (SNat @0) sn invperm) of
+ (Just Refl, Just Refl) ->
+ k invperm
+ (\ssh -> case provePermInverse perm invperm ssh of
+ Just eq -> eq
+ Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm)
+ _ -> error $ "permInverse: did not generate permutation? perm = " ++ show perm
+ ++ " ; invperm = " ++ show invperm
+ where
+ genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r
+ genPerm perm =
+ let permList = permToList' perm
+ in toHList $ map snd (sort (zip permList [0..]))
+ where
+ toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r
+ toHList [] k = k PNil
+ toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l)
+
+ lemElemCount :: (0 <= n, Compare n m ~ LT) => proxy n -> proxy m -> Elem n (Count 0 m) :~: True
+ lemElemCount _ _ = unsafeCoerceRefl
+
+ lemCount :: (OrdCond (Compare i n) True False True ~ True) => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n
+ lemCount _ _ = unsafeCoerceRefl
+
+ lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True
+ lemElem _ _ = unsafeCoerceRefl
+
+ provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is'
+ -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True)
+ provePerm1 _ _ PNil = Just (Refl)
+ provePerm1 p rtop@SNat (PCons sn@SNat perm)
+ | Just Refl <- provePerm1 p rtop perm
+ = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of
+ (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl
+ _ -> Nothing
+ | otherwise
+ = Nothing
+
+ provePerm2 :: SNat i -> SNat n -> Perm is'
+ -> Maybe (AllElem' (Count i n) is' :~: True)
+ provePerm2 = \i@(SNat :: SNat i) n@SNat perm ->
+ case cmpNat i n of
+ EQI -> Just Refl
+ LTI | Refl <- lemCount i n
+ , Just Refl <- provePerm2 (SNat @(i + 1)) n perm
+ -> checkElem i perm
+ | otherwise -> Nothing
+ GTI -> error "unreachable"
+ where
+ checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True)
+ checkElem _ PNil = Nothing
+ checkElem i@SNat (PCons k@SNat perm :: Perm is') =
+ case sameNat i k of
+ Just Refl -> Just Refl
+ Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl
+ | otherwise -> Nothing
+
+ provePermInverse :: Perm is -> Perm is' -> StaticShX sh
+ -> Maybe (Permute is' (Permute is sh) :~: sh)
+ provePermInverse perm perminv ssh =
+ ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
+
+type family MapSucc is where
+ MapSucc '[] = '[]
+ MapSucc (i : is) = i + 1 : MapSucc is
+
+permShift1 :: Perm l -> Perm (0 : MapSucc l)
+permShift1 = (SNat @0 `PCons`) . permMapSucc
+ where
+ permMapSucc :: Perm l -> Perm (MapSucc l)
+ permMapSucc PNil = PNil
+ permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns
+
+
+-- * Lemmas
+
+lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ PNil = Refl
+lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl
+
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKX PNil = Refl
+lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% _) PNil = Refl
+lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0"
+
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l
+ -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerceRefl
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
new file mode 100644
index 0000000..a16da76
--- /dev/null
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -0,0 +1,455 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Shape where
+
+import Control.DeepSeq (NFData(..))
+import qualified Data.Foldable as Foldable
+import Data.Functor.Const
+import Data.Kind (Type, Constraint)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.Generics (Generic)
+import GHC.IsList (IsList)
+import qualified GHC.IsList as IsList
+import GHC.TypeLits
+
+import Data.Array.Mixed.Types
+import Data.Coerce
+import Data.Bifunctor (first)
+
+
+-- | The length of a type-level list. If the argument is a shape, then the
+-- result is the rank of that shape.
+type family Rank sh where
+ Rank '[] = 0
+ Rank (_ : sh) = Rank sh + 1
+
+
+-- * Mixed lists
+
+type role ListX nominal representational
+type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
+data ListX sh f where
+ ZX :: ListX '[] f
+ (::%) :: f n -> ListX sh f -> ListX (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
+infixr 3 ::%
+
+instance (forall n. Show (f n)) => Show (ListX sh f) where
+ showsPrec _ = listxShow shows
+
+instance (forall n. NFData (f n)) => NFData (ListX sh f) where
+ rnf ZX = ()
+ rnf (x ::% l) = rnf x `seq` rnf l
+
+data UnconsListXRes f sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
+listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
+listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
+listxUncons ZX = Nothing
+
+listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
+listxFmap _ ZX = ZX
+listxFmap f (x ::% xs) = f x ::% listxFmap f xs
+
+listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
+listxFold _ ZX = mempty
+listxFold f (x ::% xs) = f x <> listxFold f xs
+
+listxLength :: ListX sh f -> Int
+listxLength = getSum . listxFold (\_ -> Sum 1)
+
+listxLengthSNat :: ListX sh f -> SNat (Rank sh)
+listxLengthSNat ZX = SNat
+listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat
+
+listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
+listxShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListX sh' f -> ShowS
+ go _ ZX = id
+ go prefix (x ::% xs) = showString prefix . f x . go "," xs
+
+listxToList :: ListX sh' (Const i) -> [i]
+listxToList ZX = []
+listxToList (Const i ::% is) = i : listxToList is
+
+listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
+listxAppend ZX idx' = idx'
+listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'
+
+listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
+listxDrop long ZX = long
+listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short
+
+
+-- * Mixed indices
+
+-- | This is a newtype over 'ListX'.
+type role IxX nominal representational
+type IxX :: [Maybe Nat] -> Type -> Type
+newtype IxX sh i = IxX (ListX sh (Const i))
+ deriving (Eq, Ord, Generic)
+
+pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
+pattern ZIX = IxX ZX
+
+pattern (:.%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxX sh i -> IxX sh1 i
+pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
+ where i :.% IxX shl = IxX (Const i ::% shl)
+infixr 3 :.%
+
+{-# COMPLETE ZIX, (:.%) #-}
+
+type IIxX sh = IxX sh Int
+
+instance Show i => Show (IxX sh i) where
+ showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l
+
+instance Functor (IxX sh) where
+ fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)
+
+instance Foldable (IxX sh) where
+ foldMap f (IxX l) = listxFold (f . getConst) l
+
+instance NFData i => NFData (IxX sh i)
+
+ixxZero :: StaticShX sh -> IIxX sh
+ixxZero ZKX = ZIX
+ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh
+
+ixxZero' :: IShX sh -> IIxX sh
+ixxZero' ZSX = ZIX
+ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh
+
+ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
+ixxAppend = coerce (listxAppend @_ @(Const i))
+
+ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
+ixxDrop = coerce (listxDrop @(Const i) @(Const i))
+
+ixxFromLinear :: IShX sh -> Int -> IIxX sh
+ixxFromLinear = \sh i -> case go sh i of
+ (idx, 0) -> idx
+ _ -> error $ "ixxFromLinear: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
+ where
+ -- returns (index in subarray, remaining index in enclosing array)
+ go :: IShX sh -> Int -> (IIxX sh, Int)
+ go ZSX i = (ZIX, i)
+ go (n :$% sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` fromSMayNat' n
+ in (locali :.% idx, upi)
+
+ixxToLinear :: IShX sh -> IIxX sh -> Int
+ixxToLinear = \sh i -> fst (go sh i)
+ where
+ -- returns (index in subarray, size of subarray)
+ go :: IShX sh -> IIxX sh -> (Int, Int)
+ go ZSX ZIX = (0, 1)
+ go (n :$% sh) (i :.% ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, fromSMayNat' n * sz)
+
+
+-- * Mixed shapes
+
+data SMayNat i f n where
+ SUnknown :: i -> SMayNat i f Nothing
+ SKnown :: f n -> SMayNat i f (Just n)
+deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
+deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
+deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+
+instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+ rnf (SUnknown i) = rnf i
+ rnf (SKnown x) = rnf x
+
+fromSMayNat :: (n ~ Nothing => i -> r)
+ -> (forall m. n ~ Just m => f m -> r)
+ -> SMayNat i f n -> r
+fromSMayNat f _ (SUnknown i) = f i
+fromSMayNat _ g (SKnown s) = g s
+
+fromSMayNat' :: SMayNat Int SNat n -> Int
+fromSMayNat' = fromSMayNat id fromSNat'
+
+type family AddMaybe n m where
+ AddMaybe Nothing _ = Nothing
+ AddMaybe (Just _) Nothing = Nothing
+ AddMaybe (Just n) (Just m) = Just (n + m)
+
+smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
+smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+
+
+-- | This is a newtype over 'ListX'.
+type role ShX nominal representational
+type ShX :: [Maybe Nat] -> Type -> Type
+newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
+ deriving (Eq, Ord, Generic)
+
+pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
+pattern ZSX = ShX ZX
+
+pattern (:$%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
+ where i :$% ShX shl = ShX (i ::% shl)
+infixr 3 :$%
+
+{-# COMPLETE ZSX, (:$%) #-}
+
+type IShX sh = ShX sh Int
+
+instance Show i => Show (ShX sh i) where
+ showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+
+instance Functor (ShX sh) where
+ fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)
+
+instance NFData i => NFData (ShX sh i) where
+ rnf (ShX ZX) = ()
+ rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
+ rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)
+
+shxLength :: ShX sh i -> Int
+shxLength (ShX l) = listxLength l
+
+-- | This is more than @geq@: it also checks that the integers (the unknown
+-- dimensions) are the same.
+shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqual ZSX ZSX = Just Refl
+shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh')
+ | i == j
+ , Just Refl <- shxEqual sh sh'
+ = Just Refl
+shxEqual _ _ = Nothing
+
+-- | The number of elements in an array described by this shape.
+shxSize :: IShX sh -> Int
+shxSize ZSX = 1
+shxSize (n :$% sh) = fromSMayNat' n * shxSize sh
+
+shxToList :: IShX sh -> [Int]
+shxToList ZSX = []
+shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh
+
+shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
+shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+
+shxTail :: ShX (n : sh) i -> ShX sh i
+shxTail (_ :$% sh) = sh
+
+shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
+shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+
+shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
+shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+
+shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
+shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+
+shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
+shxTakeSSX _ = flip go
+ where
+ go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
+ go ZKX _ = ZSX
+ go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh
+
+-- This is a weird operation, so it has a long name
+shxCompleteZeros :: StaticShX sh -> IShX sh
+shxCompleteZeros ZKX = ZSX
+shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh
+shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh
+
+shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
+shxSplitApp _ ZKX idx = (ZSX, idx)
+shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx)
+
+shxEnum :: IShX sh -> [IIxX sh]
+shxEnum = \sh -> go sh id []
+ where
+ go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
+ go ZSX f = (f ZIX :)
+ go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+
+
+-- * Static mixed shapes
+
+-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
+type StaticShX :: [Maybe Nat] -> Type
+newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
+ deriving (Eq, Ord)
+
+pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
+pattern ZKX = StaticShX ZX
+
+pattern (:!%)
+ :: forall {sh1}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
+ where i :!% StaticShX shl = StaticShX (i ::% shl)
+infixr 3 :!%
+
+{-# COMPLETE ZKX, (:!%) #-}
+
+instance Show (StaticShX sh) where
+ showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
+
+ssxLength :: StaticShX sh -> Int
+ssxLength (StaticShX l) = listxLength l
+
+-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
+-- class of the @some@ package.
+ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
+ssxGeq ZKX ZKX = Just Refl
+ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- ssxGeq sh sh'
+ = Just Refl
+ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh')
+ | Just Refl <- ssxGeq sh sh'
+ = Just Refl
+ssxGeq _ _ = Nothing
+
+ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
+ssxAppend ZKX sh' = sh'
+ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
+
+ssxTail :: StaticShX (n : sh) -> StaticShX sh
+ssxTail (_ :!% ssh) = ssh
+
+ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
+ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+
+-- | This may fail if @sh@ has @Nothing@s in it.
+ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
+ssxToShX' ZKX = Just ZSX
+ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh
+ssxToShX' (SUnknown _ :!% _) = Nothing
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = SUnknown () :!% ssxReplicate n
+
+ssxIotaFrom :: Int -> StaticShX sh -> [Int]
+ssxIotaFrom _ ZKX = []
+ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
+
+ssxFromShape :: IShX sh -> StaticShX sh
+ssxFromShape ZSX = ZKX
+ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
+
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShX :: [Maybe Nat] -> Constraint
+class KnownShX sh where knownShX :: StaticShX sh
+instance KnownShX '[] where knownShX = ZKX
+instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
+instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX
+
+
+-- * Flattening
+
+type Flatten sh = Flatten' 1 sh
+
+type family Flatten' acc sh where
+ Flatten' acc '[] = Just acc
+ Flatten' acc (Nothing : sh) = Nothing
+ Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
+
+-- This function is currently unused
+ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten = go (SNat @1)
+ where
+ go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go acc ZKX = SKnown acc
+ go _ (SUnknown () :!% _) = SUnknown ()
+ go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
+
+shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten = go (SNat @1)
+ where
+ go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go acc ZSX = SKnown acc
+ go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
+ go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
+
+ goUnknown :: Int -> IShX sh -> Int
+ goUnknown acc ZSX = acc
+ goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh
+ goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh
+
+
+-- | Very untyped: only length is checked (at runtime).
+instance KnownShX sh => IsList (ListX sh (Const i)) where
+ type Item (ListX sh (Const i)) = i
+ fromList topl = go (knownShX @sh) topl
+ where
+ go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
+ go ZKX [] = ZX
+ go (_ :!% sh) (i : is) = Const i ::% go sh is
+ go _ _ = error $ "IsList(ListX): Mismatched list length (type says "
+ ++ show (ssxLength (knownShX @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = listxToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShX sh => IsList (IxX sh i) where
+ type Item (IxX sh i) = i
+ fromList = IxX . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and known dimensions are checked (at runtime).
+instance KnownShX sh => IsList (ShX sh Int) where
+ type Item (ShX sh Int) = Int
+ fromList topl = ShX (go (knownShX @sh) topl)
+ where
+ go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat)
+ go ZKX [] = ZX
+ go (SKnown sn :!% sh) (i : is)
+ | i == fromSNat' sn = SKnown sn ::% go sh is
+ | otherwise = error $ "IsList(ShX): Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is
+ go _ _ = error $ "IsList(ShX): Mismatched list length (type says "
+ ++ show (ssxLength (knownShX @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = shxToList
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
new file mode 100644
index 0000000..d77513f
--- /dev/null
+++ b/src/Data/Array/Mixed/Types.hs
@@ -0,0 +1,110 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE NoStarIsType #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed.Types (
+ -- * Reified evidence of a type class
+ Dict(..),
+
+ -- * Type-level naturals
+ pattern SZ, pattern SS,
+ fromSNat',
+ snatPlus, snatMul,
+
+ -- * Type-level lists
+ type (++),
+ lemAppNil,
+ lemAppAssoc,
+ Replicate,
+ lemReplicateSucc,
+
+ -- * Unsafe
+ unsafeCoerceRefl,
+) where
+
+import Data.Type.Equality
+import Data.Proxy
+import GHC.TypeLits
+import qualified GHC.TypeNats as TN
+import qualified Unsafe.Coerce
+
+
+-- | Evidence for the constraint @c a@.
+data Dict c a where
+ Dict :: c a => Dict c a
+
+fromSNat' :: SNat n -> Int
+fromSNat' = fromIntegral . fromSNat
+
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
+-- This should be a function in base
+snatPlus :: SNat n -> SNat m -> SNat (n + m)
+snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
+
+-- This should be a function in base
+snatMul :: SNat n -> SNat m -> SNat (n * m)
+snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
+
+
+-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
+-- only typecheck for actual type equalities. One cannot, e.g. accidentally
+-- write this:
+--
+-- @
+-- foo :: Proxy a -> Proxy b -> a :~: b
+-- foo = unsafeCoerceRefl
+-- @
+--
+-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce',
+-- but would have resulted in interesting memory errors at runtime.
+unsafeCoerceRefl :: a :~: b
+unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl
+
+
+-- | Type-level list append.
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerceRefl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerceRefl
+
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index e3af0ee..1a4e094 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -58,8 +58,9 @@ module Data.Array.Nested (
type (++),
Storable,
SNat, pattern SNat,
- HList,
- Permutation,
+ pattern SZ, pattern SS,
+ Perm(..),
+ IsPermutation,
KnownNatList(..),
listSToList,
shSToList,
@@ -69,7 +70,10 @@ module Data.Array.Nested (
import Prelude hiding (mappend)
import Data.Array.Mixed
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
import Data.Array.Nested.Internal
-import Data.Array.Nested.Internal.Arith
import Foreign.Storable
import GHC.TypeLits
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 712c5f1..0870789 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -60,7 +60,11 @@ import Unsafe.Coerce
import Data.Array.Mixed
import qualified Data.Array.Mixed as X
-import Data.Array.Nested.Internal.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Types
-- Invariant in the API
@@ -123,19 +127,19 @@ lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat SZ = ZKX
-ssxFromSNat (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
-lemRankReplicate :: SNat n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
lemRankReplicate SZ = Refl
lemRankReplicate (SS (n :: SNat nm1))
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
, Refl <- lemRankReplicate n
= Refl
-lemRankMapJust :: forall sh. ShS sh -> X.Rank (MapJust sh) :~: X.Rank sh
+lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh
lemRankMapJust ZSS = Refl
lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
@@ -146,9 +150,9 @@ lemReplicatePlusApp sn _ _ = go sn
go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
go (SS (n :: SNat n'm1))
- | Refl <- X.lemReplicateSucc @a @n'm1
+ | Refl <- lemReplicateSucc @a @n'm1
, Refl <- go n
- = sym (X.lemReplicateSucc @a @(n'm1 + m))
+ = sym (lemReplicateSucc @a @(n'm1 + m))
lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
lemLeqPlus _ _ _ = Refl
@@ -156,17 +160,17 @@ lemLeqPlus _ _ _ = Refl
lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
lemLeqSuccSucc _ _ = unsafeCoerce Refl
-lemDropLenApp :: X.Rank l1 <= X.Rank l2
+lemDropLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
- -> X.DropLen l1 l2 ++ rest :~: X.DropLen l1 (l2 ++ rest)
+ -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
lemDropLenApp _ _ _ = unsafeCoerce Refl
-lemTakeLenApp :: X.Rank l1 <= X.Rank l2
+lemTakeLenApp :: Rank l1 <= Rank l2
=> Proxy l1 -> Proxy l2 -> Proxy rest
- -> X.TakeLen l1 l2 :~: X.TakeLen l1 (l2 ++ rest)
+ -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
lemTakeLenApp _ _ _ = unsafeCoerce Refl
-srankSh :: ShX sh f -> SNat (X.Rank sh)
+srankSh :: ShX sh f -> SNat (Rank sh)
srankSh ZSX = SNat
srankSh (_ :$% sh) | SNat <- srankSh sh = SNat
@@ -585,11 +589,11 @@ class Elt a where
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
-> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
- mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
- mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
- => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
-- ====== PRIVATE METHODS ====== --
@@ -635,20 +639,20 @@ class Elt a => KnownElt a where
instance Storable a => Elt (Primitive a) where
mshape (M_Primitive sh _) = sh
mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
mfromListOuter l@(arr1 :| _) =
let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l)))
- mtoListOuter (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toListOuter arr)
+ in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
-> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
mlift ssh2 f (M_Primitive _ a)
- | Refl <- X.lemAppNil @sh1
- , Refl <- X.lemAppNil @sh2
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
, let result = f ZKX a
= M_Primitive (X.shape ssh2 result) result
@@ -657,36 +661,36 @@ instance Storable a => Elt (Primitive a) where
-> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
-> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
- | Refl <- X.lemAppNil @sh1
- , Refl <- X.lemAppNil @sh2
- , Refl <- X.lemAppNil @sh3
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , Refl <- lemAppNil @sh3
, let result = f ZKX a b
= M_Primitive (X.shape ssh3 result) result
- mcast :: forall sh1 sh2 sh'. X.Rank sh1 ~ X.Rank sh2
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
- let (_, sh') = shAppSplit (Proxy @sh') ssh1 sh1'
- in M_Primitive (shAppend sh2 sh') (X.cast ssh1 sh2 (X.staticShapeFrom sh') arr)
+ let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
mtranspose perm (M_Primitive sh arr) =
- M_Primitive (X.shPermutePrefix perm sh)
- (X.transpose (X.staticShapeFrom sh) perm arr)
+ M_Primitive (shxPermutePrefix perm sh)
+ (X.transpose (ssxFromShape sh) perm arr)
mshapeTree _ = ()
mshapeTreeEq _ () () = True
mshapeTreeEmpty _ () = False
mshowShapeTree _ () = "()"
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
-- TODO: this use of toVector is suboptimal
mvecsWritePartial
:: forall sh' sh s.
IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (X.staticShapeFrom sh') arr
- offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh))
- VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr)
+ let arrsh = X.shape (ssxFromShape sh') arr
+ offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
@@ -701,7 +705,7 @@ deriving via Primitive () instance Elt ()
instance Storable a => KnownElt (Primitive a) where
memptyArray sh = M_Primitive sh (X.empty sh)
- mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh)
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-- [PRIMITIVE ELEMENT TYPES LIST]
@@ -755,7 +759,7 @@ instance Elt a => Elt (Mixed sh' a) where
-- moverlongShape method, a prefix of which is mshape.
mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
mshape (M_Nest sh arr)
- = fst (shAppSplit (Proxy @sh') (X.staticShapeFrom sh) (mshape arr))
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
mindex (M_Nest _ arr) i = mindexPartial arr i
@@ -763,8 +767,8 @@ instance Elt a => Elt (Mixed sh' a) where
mindexPartial :: forall sh1 sh2.
Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
mindexPartial (M_Nest sh arr) i
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (X.shDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
mscalar = M_Nest ZSX
@@ -773,95 +777,95 @@ instance Elt a => Elt (Mixed sh' a) where
M_Nest (SUnknown (length l) :$% mshape arr)
(mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
- mtoListOuter (M_Nest sh arr) = map (M_Nest (X.shTail sh)) (mtoListOuter arr)
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
mlift :: forall sh1 sh2.
StaticShX sh2
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
mlift ssh2 f (M_Nest sh1 arr) =
- let result = mlift (X.ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shAppSplit (Proxy @sh') ssh2 (mshape result)
+ let result = mlift (ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
in M_Nest sh2 result
where
- ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr)))
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
f' sshT
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- = f (X.ssxAppend ssh' sshT)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
mlift2 :: forall sh1 sh2 sh3.
StaticShX sh3
-> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
-> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
- let result = mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shAppSplit (Proxy @sh') ssh3 (mshape result)
+ let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
in M_Nest sh3 result
where
- ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') (X.staticShapeFrom sh1) (mshape arr1)))
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
f' sshT
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
- = f (X.ssxAppend ssh' sshT)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
- mcast :: forall sh1 sh2 shT. X.Rank sh1 ~ X.Rank sh2
+ mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
=> StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
mcast ssh1 sh2 _ (M_Nest sh1T arr)
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
- = let (_, shT) = shAppSplit (Proxy @shT) ssh1 sh1T
- in M_Nest (shAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
-
- mtranspose :: forall is sh. (X.Permutation is, X.Rank is <= X.Rank sh)
- => HList SNat is -> Mixed sh (Mixed sh' a)
- -> Mixed (X.PermutePrefix is sh) (Mixed sh' a)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh (Mixed sh' a)
+ -> Mixed (PermutePrefix is sh) (Mixed sh' a)
mtranspose perm (M_Nest sh arr)
- | let sh' = X.shDropSh @sh @sh' (mshape arr) sh
- , Refl <- X.lemRankApp (X.staticShapeFrom sh) (X.staticShapeFrom sh')
- , Refl <- lemLeqPlus (Proxy @(X.Rank is)) (Proxy @(X.Rank sh)) (Proxy @(X.Rank sh'))
- , Refl <- X.lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
+ , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
, Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
, Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- = M_Nest (X.shPermutePrefix perm sh)
+ = M_Nest (shxPermutePrefix perm sh)
(mtranspose perm arr)
mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr)))))
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
- mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+ mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
mvecsWritePartial :: forall sh1 sh2 s.
IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
-> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
-> ST s ()
mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (X.shAppend sh sh') vecs
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
- memptyArray sh = M_Nest sh (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh'))))
+ memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
mvecsUnsafeNew sh example
- | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh')))
+ | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
where
sh' = mshape example
- mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+ mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-- | Create an array given a size and a function that computes the element at a
@@ -882,10 +886,10 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
-- array. The type of 'mgenerate' allows this requirement to be broken very
-- easily, hence the runtime check.
mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case X.enumShape sh of
+mgenerate sh f = case shxEnum sh of
[] -> memptyArray sh
firstidx : restidxs ->
- let firstelem = f (X.zeroIxX' sh)
+ let firstelem = f (ixxZero' sh)
shapetree = mshapeTree firstelem
in if mshapeTreeEmpty (Proxy @a) shapetree
then memptyArray sh
@@ -905,28 +909,28 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
msumOuter1P (M_Primitive (n :$% sh) arr) =
let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (X.staticShapeFrom sh) arr)
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Mixed (n : sh) a -> Mixed sh a
msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
mappend :: forall n m sh a. Elt a
- => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
where
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
- ssh = X.staticShapeFrom sh
- snm :: SMayNat () SNat (X.AddMaybe n m)
+ ssh = ssxFromShape sh
+ snm :: SMayNat () SNat (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
- (SKnown n, SKnown m) -> SKnown (X.plusSNat n m)
+ (SKnown n, SKnown m) -> SKnown (snatPlus n m)
f :: forall sh' b. Storable b
- => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b
- f ssh' = X.append (X.ssxAppend ssh ssh')
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (ssxAppend ssh ssh')
mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
@@ -971,9 +975,9 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
-> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
-> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shDropSSX sh ssh
- in M_Primitive (X.shAppend (shTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (X.staticShapeFrom sh1) (X.staticShapeFrom sh2)
+ let sh1 = shxDropSSX sh ssh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
+ (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
(\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
arr)
@@ -988,10 +992,10 @@ mrerank ssh sh2 f (toPrimitive -> arr) =
mreplicate :: forall sh sh' a. Elt a
=> IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
mreplicate sh arr =
- let ssh' = X.staticShapeFrom (mshape arr)
- in mlift (X.ssxAppend (X.staticShapeFrom sh) ssh')
+ let ssh' = ssxFromShape (mshape arr)
+ in mlift (ssxAppend (ssxFromShape sh) ssh')
(\(sshT :: StaticShX shT) ->
- case X.lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
Refl -> X.replicate sh (ssxAppend ssh' sshT))
arr
@@ -1005,18 +1009,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
mslice i n arr =
let _ :$% sh = mshape arr
- in mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr
+ in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr
+msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr
+mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
mreshape sh' arr =
- mlift (X.staticShapeFrom sh')
- (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh')
+ mlift (ssxFromShape sh')
+ (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
arr
miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
@@ -1095,26 +1099,26 @@ instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
log1pexp = mliftNumElt1 floatEltLog1pexp
log1mexp = mliftNumElt1 floatEltLog1mexp
-mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (X.Rank sh) a
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
mtoRanked arr
- | Refl <- X.lemAppNil @sh
- , Refl <- X.lemAppNil @(Replicate (X.Rank sh) (Nothing @Nat))
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat))
, Refl <- lemRankReplicate (srankSh (mshape arr))
- = Ranked (mcast (X.staticShapeFrom (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
+ = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
where
- convSh :: IShX sh' -> IShX (Replicate (X.Rank sh') Nothing)
+ convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
convSh ZSX = ZSX
convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @(X.Rank sh'T)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
= SUnknown (fromSMayNat' smn) :$% convSh sh
-mcastToShaped :: forall sh sh' a. (Elt a, X.Rank sh ~ X.Rank sh')
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
=> Mixed sh a -> ShS sh' -> Shaped sh' a
mcastToShaped arr targetsh
- | Refl <- X.lemAppNil @sh
- , Refl <- X.lemAppNil @(MapJust sh')
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(MapJust sh')
, Refl <- lemRankMapJust targetsh
- = Shaped (mcast (X.staticShapeFrom (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
+ = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
@@ -1418,7 +1422,7 @@ zeroIxR :: SNat n -> IIxR n
zeroIxR SZ = ZIR
zeroIxR (SS n) = 0 :.: zeroIxR n
-ixCvtXR :: IIxX sh -> IIxR (X.Rank sh)
+ixCvtXR :: IIxX sh -> IIxR (Rank sh)
ixCvtXR ZIX = ZIR
ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
@@ -1429,7 +1433,7 @@ shCvtXR' ZSX =
shCvtXR' (n :$% (idx :: IShX sh))
| Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
castWith (subst2 (lem1 @sh Refl))
- (X.fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
where
lem1 :: forall sh' n' k.
k : sh' :~: Replicate n' Nothing
@@ -1443,13 +1447,13 @@ shCvtXR' (n :$% (idx :: IShX sh))
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m))
+ castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
(n :.% ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m))
+ castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
(SUnknown n :$% shCvtRX idx)
shapeSizeR :: IShR n -> Int
@@ -1506,7 +1510,7 @@ rsumOuter1P :: forall n a.
(Storable a, NumElt a)
=> Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= Ranked (msumOuter1P arr)
rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
@@ -1559,7 +1563,7 @@ rappend :: forall n a. Elt a
rappend arr1 arr2
| sn@SNat <- snatFromShR (rshape arr1)
, Dict <- lemKnownReplicate sn
- , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
= coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
arr1 arr2
@@ -1582,7 +1586,7 @@ rtoVector = coerce mtoVector
rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromListOuter l
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
@@ -1593,7 +1597,7 @@ rfromList1Prim l = Ranked (mfromList1Prim l)
rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
rtoListOuter (Ranked arr)
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
rtoList1 :: Elt a => Ranked 1 a -> [a]
@@ -1677,7 +1681,7 @@ rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
rslice i n arr
- | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
= rlift (snatFromShR (rshape arr))
(\_ -> X.sliceU i n)
arr
@@ -1686,7 +1690,7 @@ rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
rrev1 arr =
rlift (snatFromShR (rshape arr))
(\(_ :: StaticShX sh') ->
- case X.lemReplicateSucc @(Nothing @Nat) @n of
+ case lemReplicateSucc @(Nothing @Nat) @n of
Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
arr
@@ -1707,12 +1711,12 @@ rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing
rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr)
rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
-rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
-rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (X.staticShapeFrom (X.shape (ssxFromSNat sn) arr)) arr)
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-rcastToShaped :: Elt a => Ranked (X.Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
rcastToShaped (Ranked arr) targetsh
| Refl <- lemRankReplicate (srankSh (shCvtSX targetsh))
, Refl <- lemRankMapJust targetsh
@@ -1809,7 +1813,7 @@ shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
shapeSizeS :: ShS sh -> Int
shapeSizeS ZSS = 1
-shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh
+shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh
sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
@@ -1838,14 +1842,14 @@ slift :: forall sh1 sh2 a. Elt a
=> ShS sh2
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
-slift sh2 f (Shaped arr) = Shaped (mlift (X.staticShapeFrom (shCvtSX sh2)) f arr)
+slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
-- | See the documentation of 'mlift'.
slift2 :: forall sh1 sh2 sh3 a. Elt a
=> ShS sh3
-> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
-slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (X.staticShapeFrom (shCvtSX sh3)) f arr1 arr2)
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
=> Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
@@ -1855,28 +1859,28 @@ ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
-lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh)
-lemCommMapJustTakeLen HNil _ = Refl
-lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
-lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty"
+lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemCommMapJustTakeLen PNil _ = Refl
+lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
+lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty"
-lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh)
-lemCommMapJustDropLen HNil _ = Refl
-lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
-lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty"
+lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemCommMapJustDropLen PNil _ = Refl
+lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
+lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty"
-lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh)
+lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
lemCommMapJustIndex SZ (_ :$$ _) = Refl
lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
| Refl <- lemCommMapJustIndex i sh
- , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
- , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= Refl
lemCommMapJustIndex _ ZSS = error "Index of empty"
-lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh)
-lemCommMapJustPermute HNil _ = Refl
-lemCommMapJustPermute (i `HCons` is) sh
+lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemCommMapJustPermute PNil _ = Refl
+lemCommMapJustPermute (i `PCons` is) sh
| Refl <- lemCommMapJustPermute is sh
, Refl <- lemCommMapJustIndex i sh
= Refl
@@ -1885,53 +1889,53 @@ listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
listsAppend ZS idx' = idx'
listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-listsTakeLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.TakeLen is sh) f
-listsTakeLen HNil _ = ZS
-listsTakeLen (_ `HCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh
-listsTakeLen (_ `HCons` _) ZS = error "Permutation longer than shape"
+listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLen PNil _ = ZS
+listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh
+listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsDropLen :: forall f is sh. HList SNat is -> ListS sh f -> ListS (DropLen is sh) f
-listsDropLen HNil sh = sh
-listsDropLen (_ `HCons` is) (_ ::$ sh) = listsDropLen is sh
-listsDropLen (_ `HCons` _) ZS = error "Permutation longer than shape"
+listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLen PNil sh = sh
+listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh
+listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape"
-listsPermute :: forall f is sh. HList SNat is -> ListS sh f -> ListS (X.Permute is sh) f
-listsPermute HNil _ = ZS
-listsPermute (i `HCons` (is :: HList SNat is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh)
+listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute PNil _ = ZS
+listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh)
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (X.Permute is shT) f -> ListS (X.Index i sh : X.Permute is shT) f
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f
listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest
listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest
- | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= listsIndex p pT i sh rest
listsIndex _ _ _ ZS _ = error "Index into empty shape"
-shsTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
+shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
shsTakeLen = coerce (listsTakeLen @SNat)
-shsPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
+shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
shsPermute = coerce (listsPermute @SNat)
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT)
shsIndex pis pshT = coerce (listsIndex @SNat pis pshT)
-applyPermS :: forall f is sh. HList SNat is -> ListS sh f -> ListS (PermutePrefix is sh) f
+applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh)
-applyPermIxS :: forall i is sh. HList SNat is -> IxS sh i -> IxS (PermutePrefix is sh) i
+applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
applyPermIxS = coerce (applyPermS @(Const i))
-applyPermShS :: forall is sh. HList SNat is -> ShS sh -> ShS (PermutePrefix is sh)
+applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
applyPermShS = coerce (applyPermS @SNat)
-stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a)
- => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
+stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
+ => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
stranspose perm sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
, Refl <- lemCommMapJustTakeLen perm (sshape sarr)
, Refl <- lemCommMapJustDropLen perm (sshape sarr)
, Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr))
- , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(X.DropLen is sh))
+ , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
= Shaped (mtranspose perm arr)
sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
@@ -1969,7 +1973,7 @@ stoList1 = map sunScalar . stoListOuter
sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
sfromListPrim sn l
- | Refl <- X.lemAppNil @'[Just n]
+ | Refl <- lemAppNil @'[Just n]
= let ssh = SUnknown () :!% ZKX
xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
@@ -1989,7 +1993,7 @@ srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
srerankP sh sh2 f sarr@(Shaped arr)
| Refl <- lemCommMapJustApp sh (Proxy @sh1)
, Refl <- lemCommMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (X.staticShapeFrom (shTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (X.staticShapeFrom (shCvtSX sh))))
+ = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
(shCvtSX sh2)
(\a -> let Shaped r = f (Shaped a) in r)
arr)
@@ -2033,12 +2037,12 @@ sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (X.staticShapeFrom (shCvtSX sh)) arr)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (X.staticShapeFrom (shCvtSX sh)) arr)
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
-stoRanked :: Elt a => Shaped sh a -> Ranked (X.Rank sh) a
+stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
stoRanked sarr@(Shaped arr)
| Refl <- lemRankMapJust (sshape sarr)
= mtoRanked arr
diff --git a/test/Gen.hs b/test/Gen.hs
index 2d2a30b..9652963 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -18,7 +18,8 @@ import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN
-import Data.Array.Mixed (fromSNat', pattern SZ, pattern SS)
+import Data.Array.Mixed
+import Data.Array.Mixed.Types
import Data.Array.Nested
import Hedgehog
diff --git a/test/Util.hs b/test/Util.hs
index 1249bf9..9afa922 100644
--- a/test/Util.hs
+++ b/test/Util.hs
@@ -11,7 +11,7 @@ module Util where
import qualified Data.Array.RankedS as OR
import GHC.TypeLits
-import Data.Array.Mixed (fromSNat')
+import Data.Array.Mixed.Types (fromSNat')
import Data.Array.Nested