aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-17 22:53:52 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-17 22:53:52 +0200
commit4adbbd8e2e635cc4c647be40f0dd258668dd2452 (patch)
tree1f89ce0adc26ed98e80e759f2bf403b107d667e1 /src/Data/Array/Mixed.hs
parent06625c89089044b064bbc6cf36ea4e83199c19a4 (diff)
More WIP singletonisation
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs306
1 files changed, 174 insertions, 132 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 69c44ab..df506d6 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -8,6 +8,7 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -24,6 +25,7 @@ module Data.Array.Mixed where
import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
import Data.Coerce
+import Data.Functor.Const
import Data.Kind
import Data.Proxy
import Data.Type.Bool
@@ -32,6 +34,7 @@ import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
import GHC.TypeError
import GHC.TypeLits
+import qualified GHC.TypeNats as TypeNats
import Unsafe.Coerce (unsafeCoerce)
@@ -87,45 +90,115 @@ type family Replicate n a where
Replicate n a = a : Replicate (n - 1) a
+type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
+data ListX sh f where
+ ZX :: ListX '[] f
+ (::%) :: f n -> ListX sh f -> ListX (n : sh) f
+deriving instance (forall n. Show (f n)) => Show (ListX sh f)
+deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
+infixr 3 ::%
+
+data UnconsListXRes f sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)
+unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
+unconsListX (i ::% shl') = Just (UnconsListXRes shl' i)
+unconsListX ZX = Nothing
+
+fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
+fmapListX _ ZX = ZX
+fmapListX f (x ::% xs) = f x ::% fmapListX f xs
+
+foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
+foldListX _ ZX = mempty
+foldListX f (x ::% xs) = f x <> foldListX f xs
+
+
type IxX :: [Maybe Nat] -> Type -> Type
-data IxX sh i where
- ZIX :: IxX '[] i
- (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i
- (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i
-deriving instance Show i => Show (IxX sh i)
-deriving instance Eq i => Eq (IxX sh i)
-deriving instance Ord i => Ord (IxX sh i)
-deriving instance Functor (IxX sh)
-deriving instance Foldable (IxX sh)
-infixr 3 :.@
-infixr 3 :.?
+newtype IxX sh i = IxX (ListX sh (Const i))
+ deriving (Show, Eq, Ord)
+
+pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
+pattern ZIX = IxX ZX
+
+pattern (:.%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => i -> IxX sh i -> IxX sh1 i
+pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
+ where i :.% IxX shl = IxX (Const i ::% shl)
+infixr 3 :.%
+
+{-# COMPLETE ZIX, (:.%) #-}
type IIxX sh = IxX sh Int
+instance Functor (IxX sh) where
+ fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l)
+
+instance Foldable (IxX sh) where
+ foldMap f (IxX l) = foldListX (f . getConst) l
+
+
+data SMayNat i f n where
+ SUnknown :: i -> SMayNat i f Nothing
+ SKnown :: f n -> SMayNat i f (Just n)
+deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
+deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
+deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+
+fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r
+fromSMayNat f _ (SUnknown i) = f i
+fromSMayNat _ g (SKnown s) = g s
+
+fromSMayNat' :: SMayNat Int SNat n -> Int
+fromSMayNat' = fromSMayNat id fromSNat'
+
type ShX :: [Maybe Nat] -> Type -> Type
-data ShX sh i where
- ZSX :: ShX '[] i
- (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i
- (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i
-deriving instance Show i => Show (ShX sh i)
-deriving instance Eq i => Eq (ShX sh i)
-deriving instance Ord i => Ord (ShX sh i)
-deriving instance Functor (ShX sh)
-deriving instance Foldable (ShX sh)
-infixr 3 :$@
-infixr 3 :$?
+newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
+ deriving (Show, Eq, Ord)
+
+pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
+pattern ZSX = ShX ZX
+
+pattern (:$%)
+ :: forall {sh1} {i}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
+pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i))
+ where i :$% ShX shl = ShX (i ::% shl)
+infixr 3 :$%
+
+{-# COMPLETE ZSX, (:$%) #-}
type IShX sh = ShX sh Int
+instance Functor (ShX sh) where
+ fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l)
+
+lengthShX :: ShX sh i -> Int
+lengthShX ZSX = 0
+lengthShX (_ :$% sh) = 1 + lengthShX sh
+
+
-- | The part of a shape that is statically known.
type StaticShX :: [Maybe Nat] -> Type
-data StaticShX sh where
- ZKSX :: StaticShX '[]
- (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh)
- (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh)
-deriving instance Show (StaticShX sh)
-infixr 3 :!$@
-infixr 3 :!$?
+newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
+ deriving (Show, Eq, Ord)
+
+pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
+pattern ZKX = StaticShX ZX
+
+pattern (:!%)
+ :: forall {sh1}.
+ forall n sh. (n : sh ~ sh1)
+ => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
+pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i))
+ where i :!% StaticShX shl = StaticShX (i ::% shl)
+infixr 3 :!%
+
+{-# COMPLETE ZKX, (:!%) #-}
+
type family Rank sh where
Rank '[] = 0
@@ -136,70 +209,68 @@ newtype XArray sh a = XArray (S.Array (Rank sh) a)
deriving (Show)
zeroIxX :: StaticShX sh -> IIxX sh
-zeroIxX ZKSX = ZIX
-zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh
-zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh
+zeroIxX ZKX = ZIX
+zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh
zeroIxX' :: IShX sh -> IIxX sh
zeroIxX' ZSX = ZIX
-zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh
-zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh
+zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh
-- This is a weird operation, so it has a long name
completeShXzeros :: StaticShX sh -> IShX sh
-completeShXzeros ZKSX = ZSX
-completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh
-completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh
+completeShXzeros ZKX = ZSX
+completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh
+completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh
-- TODO: generalise all these things to arbitrary @i@
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend ZIX idx' = idx'
-ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'
-ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx'
+ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx'
shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
shAppend ZSX sh' = sh'
-shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh'
-shAppend (n :$? sh) sh' = n :$? shAppend sh sh'
+shAppend (n :$% sh) sh' = n :$% shAppend sh sh'
ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
-ixDrop sh ZIX = sh
-ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx
-ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx
+ixDrop long ZIX = long
+ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short
shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh'
shDropIx sh ZIX = sh
-shDropIx (_ :$@ sh) (_ :.@ idx) = shDropIx sh idx
-shDropIx (_ :$? sh) (_ :.? idx) = shDropIx sh idx
+shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx
+
+ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh'
+ssxDropIx ssh ZIX = ssh
+ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx
shTail :: IShX (n : sh) -> IShX sh
-shTail (_ :$@ sh) = sh
-shTail (_ :$? sh) = sh
+shTail (_ :$% sh) = sh
+
+ssxTail :: StaticShX (n : sh) -> StaticShX sh
+ssxTail (_ :!% ssh) = ssh
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
-ssxAppend ZKSX sh' = sh'
-ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh'
-ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh'
+ssxAppend ZKX sh' = sh'
+ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
shapeSize :: IShX sh -> Int
shapeSize ZSX = 1
-shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh
-shapeSize (n :$? sh) = n * shapeSize sh
+shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh
-- | This may fail if @sh@ has @Nothing@s in it.
ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
-ssxToShape' ZKSX = Just ZSX
-ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
-ssxToShape' (_ :!$? _) = Nothing
+ssxToShape' ZKX = Just ZSX
+ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh
+ssxToShape' (SUnknown _ :!% _) = Nothing
lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
lemReplicateSucc = unsafeCoerce Refl
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
-ssxReplicate SZ = ZKSX
+ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n'))
| Refl <- lemReplicateSucc @(Nothing @Nat) @n'
- = () :!$? ssxReplicate n
+ = SUnknown () :!% ssxReplicate n
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
@@ -210,14 +281,10 @@ fromLinearIdx = \sh i -> case go sh i of
-- returns (index in subarray, remaining index in enclosing array)
go :: IShX sh -> Int -> (IIxX sh, Int)
go ZSX i = (ZIX, i)
- go (n :$@ sh) i =
+ go (n :$% sh) i =
let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` fromSNat' n
- in (locali :.@ idx, upi)
- go (n :$? sh) i =
- let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` n
- in (locali :.? idx, upi)
+ (upi, locali) = i' `quotRem` fromSMayNat' n
+ in (locali :.% idx, upi)
toLinearIdx :: IShX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
@@ -225,40 +292,32 @@ toLinearIdx = \sh i -> fst (go sh i)
-- returns (index in subarray, size of subarray)
go :: IShX sh -> IIxX sh -> (Int, Int)
go ZSX ZIX = (0, 1)
- go (n :$@ sh) (i :.@ ix) =
- let (lidx, sz) = go sh ix
- in (sz * i + lidx, fromSNat' n * sz)
- go (n :$? sh) (i :.? ix) =
+ go (n :$% sh) (i :.% ix) =
let (lidx, sz) = go sh ix
- in (sz * i + lidx, n * sz)
+ in (sz * i + lidx, fromSMayNat' n * sz)
enumShape :: IShX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
where
go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
go ZSX f = (f ZIX :)
- go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]]
- go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]
+ go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
shapeLshape :: IShX sh -> S.ShapeL
shapeLshape ZSX = []
-shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh
-shapeLshape (n :$? sh) = n : shapeLshape sh
+shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh
ssxLength :: StaticShX sh -> Int
-ssxLength ZKSX = 0
-ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh
-ssxLength (_ :!$? ssh) = 1 + ssxLength ssh
+ssxLength ZKX = 0
+ssxLength (_ :!% ssh) = 1 + ssxLength ssh
ssxIotaFrom :: Int -> StaticShX sh -> [Int]
-ssxIotaFrom _ ZKSX = []
-ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
-ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom _ ZKX = []
+ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh
staticShapeFrom :: IShX sh -> StaticShX sh
-staticShapeFrom ZSX = ZKSX
-staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
-staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh
+staticShapeFrom ZSX = ZKX
+staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh
lemRankApp :: StaticShX sh1 -> StaticShX sh2
-> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
@@ -270,35 +329,18 @@ lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
-lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict
-lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
-lemKnownNatRankSSX ZKSX = Dict
-lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
-lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
-
--- lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
--- lemKnownShapeX ZKSX = Dict
--- lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
--- lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict
-
--- lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
--- lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh'
--- lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'
--- | Dict <- lemAppKnownShapeX ssh ssh'
--- = Dict
--- lemAppKnownShapeX (() :!$? ssh) ssh'
--- | Dict <- lemAppKnownShapeX ssh ssh'
--- = Dict
+lemKnownNatRankSSX ZKX = Dict
+lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh
shape = \ssh (XArray arr) -> go ssh (S.shapeL arr)
where
go :: StaticShX sh' -> [Int] -> IShX sh'
- go ZKSX [] = ZSX
- go (n :!$@ ssh) (_ : l) = n :$@ go ssh l
- go (() :!$? ssh) (n : l) = n :$? go ssh l
+ go ZKX [] = ZSX
+ go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l
go _ _ = error "Invalid shapeL"
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
@@ -330,8 +372,7 @@ generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
-indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx
-indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx
+indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
@@ -344,6 +385,15 @@ type family AddMaybe n m where
AddMaybe (Just _) Nothing = Nothing
AddMaybe (Just n) (Just m) = Just (n + m)
+-- This should be a function in base
+plusSNat :: SNat n -> SNat m -> SNat (n + m)
+plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce
+
+smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
+smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m)
+
append :: forall n m sh a. Storable a
=> StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append ssh (XArray a) (XArray b)
@@ -442,42 +492,34 @@ lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl
lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
=> StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
-lemRankDropLen ZKSX HNil = Refl
-lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
-lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
-lemRankDropLen (_ :!$@ _) HNil = Refl
-lemRankDropLen (_ :!$? _) HNil = Refl
-lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0"
+lemRankDropLen ZKX HNil = Refl
+lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!% _) HNil = Refl
+lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0"
lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
lemIndexSucc _ _ _ = unsafeCoerce Refl
ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen HNil _ = ZKSX
-ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh
-ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh
-ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+ssxTakeLen HNil _ = ZKX
+ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh
+ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape"
ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
ssxDropLen HNil sh = sh
-ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh
-ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh
-ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape"
ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute HNil _ = ZKSX
+ssxPermute HNil _ = ZKX
ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
-ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest
-ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest
-ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh')
- = ssxIndex p pT i sh rest
-ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh')
+ssxIndex _ _ SZ (n :!% _) rest = n :!% rest
+ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
= ssxIndex p pT i sh rest
-ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"
+ssxIndex _ _ _ ZKX _ = error "Index into empty shape"
-- | The list argument gives indices into the original dimension list.
transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh)
@@ -526,7 +568,7 @@ sumInner :: forall sh sh' a. (Storable a, Num a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
| Refl <- lemAppNil @sh
- = rerank ssh ssh' ZKSX (scalar . sumFull)
+ = rerank ssh ssh' ZKX (scalar . sumFull)
sumOuter :: forall sh sh' a. (Storable a, Num a)
=> StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
@@ -539,7 +581,7 @@ fromList1 :: forall n sh a. Storable a
fromList1 ssh l
| Dict <- lemKnownNatRankSSX ssh
= case ssh of
- m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
+ SKnown m@GHC_SNat :!% _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
"does not match the type (" ++ show (natVal m) ++ ")"
_ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
@@ -575,4 +617,4 @@ reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX
reshapePartial ssh1 ssh' sh2 (XArray arr)
| Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
, Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
- = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr)
+ = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthShX sh2) (S.shapeL arr)) arr)