From 4adbbd8e2e635cc4c647be40f0dd258668dd2452 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Fri, 17 May 2024 22:53:52 +0200 Subject: More WIP singletonisation --- src/Data/Array/Mixed.hs | 306 +++++++++++++++++++++++++++--------------------- 1 file changed, 174 insertions(+), 132 deletions(-) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 69c44ab..df506d6 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -8,6 +8,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -24,6 +25,7 @@ module Data.Array.Mixed where import qualified Data.Array.RankedS as S import qualified Data.Array.Ranked as ORB import Data.Coerce +import Data.Functor.Const import Data.Kind import Data.Proxy import Data.Type.Bool @@ -32,6 +34,7 @@ import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) import GHC.TypeError import GHC.TypeLits +import qualified GHC.TypeNats as TypeNats import Unsafe.Coerce (unsafeCoerce) @@ -87,45 +90,115 @@ type family Replicate n a where Replicate n a = a : Replicate (n - 1) a +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Show (f n)) => Show (ListX sh f) +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +unconsListX (i ::% shl') = Just (UnconsListXRes shl' i) +unconsListX ZX = Nothing + +fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +fmapListX _ ZX = ZX +fmapListX f (x ::% xs) = f x ::% fmapListX f xs + +foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +foldListX _ ZX = mempty +foldListX f (x ::% xs) = f x <> foldListX f xs + + type IxX :: [Maybe Nat] -> Type -> Type -data IxX sh i where - ZIX :: IxX '[] i - (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i - (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i -deriving instance Show i => Show (IxX sh i) -deriving instance Eq i => Eq (IxX sh i) -deriving instance Ord i => Ord (IxX sh i) -deriving instance Functor (IxX sh) -deriving instance Foldable (IxX sh) -infixr 3 :.@ -infixr 3 :.? +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Show, Eq, Ord) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} type IIxX sh = IxX sh Int +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = foldListX (f . getConst) l + + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + type ShX :: [Maybe Nat] -> Type -> Type -data ShX sh i where - ZSX :: ShX '[] i - (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i - (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i -deriving instance Show i => Show (ShX sh i) -deriving instance Eq i => Eq (ShX sh i) -deriving instance Ord i => Ord (ShX sh i) -deriving instance Functor (ShX sh) -deriving instance Foldable (ShX sh) -infixr 3 :$@ -infixr 3 :$? +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Show, Eq, Ord) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} type IShX sh = ShX sh Int +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l) + +lengthShX :: ShX sh i -> Int +lengthShX ZSX = 0 +lengthShX (_ :$% sh) = 1 + lengthShX sh + + -- | The part of a shape that is statically known. type StaticShX :: [Maybe Nat] -> Type -data StaticShX sh where - ZKSX :: StaticShX '[] - (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh) - (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh) -deriving instance Show (StaticShX sh) -infixr 3 :!$@ -infixr 3 :!$? +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Show, Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + type family Rank sh where Rank '[] = 0 @@ -136,70 +209,68 @@ newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show) zeroIxX :: StaticShX sh -> IIxX sh -zeroIxX ZKSX = ZIX -zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh -zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh +zeroIxX ZKX = ZIX +zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh zeroIxX' :: IShX sh -> IIxX sh zeroIxX' ZSX = ZIX -zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh -zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh +zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh -- This is a weird operation, so it has a long name completeShXzeros :: StaticShX sh -> IShX sh -completeShXzeros ZKSX = ZSX -completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh -completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh +completeShXzeros ZKX = ZSX +completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh +completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh -- TODO: generalise all these things to arbitrary @i@ ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend ZIX idx' = idx' -ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' -ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx' +ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx' shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') shAppend ZSX sh' = sh' -shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh' -shAppend (n :$? sh) sh' = n :$? shAppend sh sh' +shAppend (n :$% sh) sh' = n :$% shAppend sh sh' ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' -ixDrop sh ZIX = sh -ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx -ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx +ixDrop long ZIX = long +ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh' shDropIx sh ZIX = sh -shDropIx (_ :$@ sh) (_ :.@ idx) = shDropIx sh idx -shDropIx (_ :$? sh) (_ :.? idx) = shDropIx sh idx +shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx + +ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh' +ssxDropIx ssh ZIX = ssh +ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx shTail :: IShX (n : sh) -> IShX sh -shTail (_ :$@ sh) = sh -shTail (_ :$? sh) = sh +shTail (_ :$% sh) = sh + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKSX sh' = sh' -ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' -ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh' +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' shapeSize :: IShX sh -> Int shapeSize ZSX = 1 -shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh -shapeSize (n :$? sh) = n * shapeSize sh +shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh -- | This may fail if @sh@ has @Nothing@s in it. ssxToShape' :: StaticShX sh -> Maybe (IShX sh) -ssxToShape' ZKSX = Just ZSX -ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh -ssxToShape' (_ :!$? _) = Nothing +ssxToShape' ZKX = Just ZSX +ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh +ssxToShape' (SUnknown _ :!% _) = Nothing lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a lemReplicateSucc = unsafeCoerce Refl ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) -ssxReplicate SZ = ZKSX +ssxReplicate SZ = ZKX ssxReplicate (SS (n :: SNat n')) | Refl <- lemReplicateSucc @(Nothing @Nat) @n' - = () :!$? ssxReplicate n + = SUnknown () :!% ssxReplicate n fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of @@ -210,14 +281,10 @@ fromLinearIdx = \sh i -> case go sh i of -- returns (index in subarray, remaining index in enclosing array) go :: IShX sh -> Int -> (IIxX sh, Int) go ZSX i = (ZIX, i) - go (n :$@ sh) i = + go (n :$% sh) i = let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSNat' n - in (locali :.@ idx, upi) - go (n :$? sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali :.? idx, upi) + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) toLinearIdx :: IShX sh -> IIxX sh -> Int toLinearIdx = \sh i -> fst (go sh i) @@ -225,40 +292,32 @@ toLinearIdx = \sh i -> fst (go sh i) -- returns (index in subarray, size of subarray) go :: IShX sh -> IIxX sh -> (Int, Int) go ZSX ZIX = (0, 1) - go (n :$@ sh) (i :.@ ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSNat' n * sz) - go (n :$? sh) (i :.? ix) = + go (n :$% sh) (i :.% ix) = let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) + in (sz * i + lidx, fromSMayNat' n * sz) enumShape :: IShX sh -> [IIxX sh] enumShape = \sh -> go sh id [] where go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] go ZSX f = (f ZIX :) - go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]] - go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] shapeLshape :: IShX sh -> S.ShapeL shapeLshape ZSX = [] -shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh -shapeLshape (n :$? sh) = n : shapeLshape sh +shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh ssxLength :: StaticShX sh -> Int -ssxLength ZKSX = 0 -ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :!$? ssh) = 1 + ssxLength ssh +ssxLength ZKX = 0 +ssxLength (_ :!% ssh) = 1 + ssxLength ssh ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKSX = [] -ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom _ ZKX = [] +ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh staticShapeFrom :: IShX sh -> StaticShX sh -staticShapeFrom ZSX = ZKSX -staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh -staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh +staticShapeFrom ZSX = ZKX +staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 @@ -270,35 +329,18 @@ lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) lemKnownNatRank ZSX = Dict -lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX ZKSX = Dict -lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict -lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict - --- lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh --- lemKnownShapeX ZKSX = Dict --- lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict --- lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict - --- lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) --- lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' --- lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' --- | Dict <- lemAppKnownShapeX ssh ssh' --- = Dict --- lemAppKnownShapeX (() :!$? ssh) ssh' --- | Dict <- lemAppKnownShapeX ssh ssh' --- = Dict +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) where go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKSX [] = ZSX - go (n :!$@ ssh) (_ : l) = n :$@ go ssh l - go (() :!$? ssh) (n : l) = n :$? go ssh l + go ZKX [] = ZSX + go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l go _ _ = error "Invalid shapeL" fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a @@ -330,8 +372,7 @@ generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx -indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a index xarr i @@ -344,6 +385,15 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) +-- This should be a function in base +plusSNat :: SNat n -> SNat m -> SNat (n + m) +plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m) + append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append ssh (XArray a) (XArray b) @@ -442,42 +492,34 @@ lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl lemRankDropLen :: forall is sh. (Rank is <= Rank sh) => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKSX HNil = Refl -lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!$@ _) HNil = Refl -lemRankDropLen (_ :!$? _) HNil = Refl -lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0" +lemRankDropLen ZKX HNil = Refl +lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) HNil = Refl +lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l lemIndexSucc _ _ _ = unsafeCoerce Refl ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen HNil _ = ZKSX -ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh -ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh -ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" +ssxTakeLen HNil _ = ZKX +ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh +ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape" ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) ssxDropLen HNil sh = sh -ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh -ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh -ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" +ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape" ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute HNil _ = ZKSX +ssxPermute HNil _ = ZKX ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) -ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest -ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest -ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh') - = ssxIndex p pT i sh rest -ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh') +ssxIndex _ _ SZ (n :!% _) rest = n :!% rest +ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = ssxIndex p pT i sh rest -ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" +ssxIndex _ _ _ ZKX _ = error "Index into empty shape" -- | The list argument gives indices into the original dimension list. transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) @@ -526,7 +568,7 @@ sumInner :: forall sh sh' a. (Storable a, Num a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' | Refl <- lemAppNil @sh - = rerank ssh ssh' ZKSX (scalar . sumFull) + = rerank ssh ssh' ZKX (scalar . sumFull) sumOuter :: forall sh sh' a. (Storable a, Num a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a @@ -539,7 +581,7 @@ fromList1 :: forall n sh a. Storable a fromList1 ssh l | Dict <- lemKnownNatRankSSX ssh = case ssh of - m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> + SKnown m@GHC_SNat :!% _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ "does not match the type (" ++ show (natVal m) ++ ")" _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) @@ -575,4 +617,4 @@ reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') - = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) + = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthShX sh2) (S.shapeL arr)) arr) -- cgit v1.2.3-70-g09d2