diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-17 22:53:52 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-17 22:53:52 +0200 |
commit | 4adbbd8e2e635cc4c647be40f0dd258668dd2452 (patch) | |
tree | 1f89ce0adc26ed98e80e759f2bf403b107d667e1 /src | |
parent | 06625c89089044b064bbc6cf36ea4e83199c19a4 (diff) |
More WIP singletonisation
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Mixed.hs | 306 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 268 |
2 files changed, 309 insertions, 265 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 69c44ab..df506d6 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -8,6 +8,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -24,6 +25,7 @@ module Data.Array.Mixed where import qualified Data.Array.RankedS as S import qualified Data.Array.Ranked as ORB import Data.Coerce +import Data.Functor.Const import Data.Kind import Data.Proxy import Data.Type.Bool @@ -32,6 +34,7 @@ import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) import GHC.TypeError import GHC.TypeLits +import qualified GHC.TypeNats as TypeNats import Unsafe.Coerce (unsafeCoerce) @@ -87,45 +90,115 @@ type family Replicate n a where Replicate n a = a : Replicate (n - 1) a +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Show (f n)) => Show (ListX sh f) +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +unconsListX :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +unconsListX (i ::% shl') = Just (UnconsListXRes shl' i) +unconsListX ZX = Nothing + +fmapListX :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +fmapListX _ ZX = ZX +fmapListX f (x ::% xs) = f x ::% fmapListX f xs + +foldListX :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +foldListX _ ZX = mempty +foldListX f (x ::% xs) = f x <> foldListX f xs + + type IxX :: [Maybe Nat] -> Type -> Type -data IxX sh i where - ZIX :: IxX '[] i - (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i - (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i -deriving instance Show i => Show (IxX sh i) -deriving instance Eq i => Eq (IxX sh i) -deriving instance Ord i => Ord (IxX sh i) -deriving instance Functor (IxX sh) -deriving instance Foldable (IxX sh) -infixr 3 :.@ -infixr 3 :.? +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Show, Eq, Ord) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (unconsListX -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} type IIxX sh = IxX sh Int +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (fmapListX (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = foldListX (f . getConst) l + + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +fromSMayNat :: (n ~ Nothing => i -> r) -> (forall m. n ~ Just m => f m -> r) -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + type ShX :: [Maybe Nat] -> Type -> Type -data ShX sh i where - ZSX :: ShX '[] i - (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i - (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i -deriving instance Show i => Show (ShX sh i) -deriving instance Eq i => Eq (ShX sh i) -deriving instance Ord i => Ord (ShX sh i) -deriving instance Functor (ShX sh) -deriving instance Foldable (ShX sh) -infixr 3 :$@ -infixr 3 :$? +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Show, Eq, Ord) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (unconsListX -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} type IShX sh = ShX sh Int +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (fmapListX (fromSMayNat (SUnknown . f) SKnown) l) + +lengthShX :: ShX sh i -> Int +lengthShX ZSX = 0 +lengthShX (_ :$% sh) = 1 + lengthShX sh + + -- | The part of a shape that is statically known. type StaticShX :: [Maybe Nat] -> Type -data StaticShX sh where - ZKSX :: StaticShX '[] - (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh) - (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh) -deriving instance Show (StaticShX sh) -infixr 3 :!$@ -infixr 3 :!$? +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Show, Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (unconsListX -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + type family Rank sh where Rank '[] = 0 @@ -136,70 +209,68 @@ newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show) zeroIxX :: StaticShX sh -> IIxX sh -zeroIxX ZKSX = ZIX -zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh -zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh +zeroIxX ZKX = ZIX +zeroIxX (_ :!% ssh) = 0 :.% zeroIxX ssh zeroIxX' :: IShX sh -> IIxX sh zeroIxX' ZSX = ZIX -zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh -zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh +zeroIxX' (_ :$% sh) = 0 :.% zeroIxX' sh -- This is a weird operation, so it has a long name completeShXzeros :: StaticShX sh -> IShX sh -completeShXzeros ZKSX = ZSX -completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh -completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh +completeShXzeros ZKX = ZSX +completeShXzeros (SUnknown () :!% ssh) = SUnknown 0 :$% completeShXzeros ssh +completeShXzeros (SKnown n :!% ssh) = SKnown n :$% completeShXzeros ssh -- TODO: generalise all these things to arbitrary @i@ ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend ZIX idx' = idx' -ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' -ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx' +ixAppend (i :.% idx) idx' = i :.% ixAppend idx idx' shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') shAppend ZSX sh' = sh' -shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh' -shAppend (n :$? sh) sh' = n :$? shAppend sh sh' +shAppend (n :$% sh) sh' = n :$% shAppend sh sh' ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' -ixDrop sh ZIX = sh -ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx -ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx +ixDrop long ZIX = long +ixDrop long (_ :.% short) = case long of _ :.% long' -> ixDrop long' short shDropIx :: IShX (sh ++ sh') -> IIxX sh -> IShX sh' shDropIx sh ZIX = sh -shDropIx (_ :$@ sh) (_ :.@ idx) = shDropIx sh idx -shDropIx (_ :$? sh) (_ :.? idx) = shDropIx sh idx +shDropIx sh (_ :.% idx) = case sh of _ :$% sh' -> shDropIx sh' idx + +ssxDropIx :: StaticShX (sh ++ sh') -> IIxX sh -> StaticShX sh' +ssxDropIx ssh ZIX = ssh +ssxDropIx ssh (_ :.% idx) = case ssh of _ :!% ssh' -> ssxDropIx ssh' idx shTail :: IShX (n : sh) -> IShX sh -shTail (_ :$@ sh) = sh -shTail (_ :$? sh) = sh +shTail (_ :$% sh) = sh + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKSX sh' = sh' -ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' -ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh' +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' shapeSize :: IShX sh -> Int shapeSize ZSX = 1 -shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh -shapeSize (n :$? sh) = n * shapeSize sh +shapeSize (n :$% sh) = fromSMayNat' n * shapeSize sh -- | This may fail if @sh@ has @Nothing@s in it. ssxToShape' :: StaticShX sh -> Maybe (IShX sh) -ssxToShape' ZKSX = Just ZSX -ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh -ssxToShape' (_ :!$? _) = Nothing +ssxToShape' ZKX = Just ZSX +ssxToShape' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShape' sh +ssxToShape' (SUnknown _ :!% _) = Nothing lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a lemReplicateSucc = unsafeCoerce Refl ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) -ssxReplicate SZ = ZKSX +ssxReplicate SZ = ZKX ssxReplicate (SS (n :: SNat n')) | Refl <- lemReplicateSucc @(Nothing @Nat) @n' - = () :!$? ssxReplicate n + = SUnknown () :!% ssxReplicate n fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of @@ -210,14 +281,10 @@ fromLinearIdx = \sh i -> case go sh i of -- returns (index in subarray, remaining index in enclosing array) go :: IShX sh -> Int -> (IIxX sh, Int) go ZSX i = (ZIX, i) - go (n :$@ sh) i = + go (n :$% sh) i = let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSNat' n - in (locali :.@ idx, upi) - go (n :$? sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali :.? idx, upi) + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) toLinearIdx :: IShX sh -> IIxX sh -> Int toLinearIdx = \sh i -> fst (go sh i) @@ -225,40 +292,32 @@ toLinearIdx = \sh i -> fst (go sh i) -- returns (index in subarray, size of subarray) go :: IShX sh -> IIxX sh -> (Int, Int) go ZSX ZIX = (0, 1) - go (n :$@ sh) (i :.@ ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSNat' n * sz) - go (n :$? sh) (i :.? ix) = + go (n :$% sh) (i :.% ix) = let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) + in (sz * i + lidx, fromSMayNat' n * sz) enumShape :: IShX sh -> [IIxX sh] enumShape = \sh -> go sh id [] where go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] go ZSX f = (f ZIX :) - go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]] - go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] shapeLshape :: IShX sh -> S.ShapeL shapeLshape ZSX = [] -shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh -shapeLshape (n :$? sh) = n : shapeLshape sh +shapeLshape (n :$% sh) = fromSMayNat' n : shapeLshape sh ssxLength :: StaticShX sh -> Int -ssxLength ZKSX = 0 -ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :!$? ssh) = 1 + ssxLength ssh +ssxLength ZKX = 0 +ssxLength (_ :!% ssh) = 1 + ssxLength ssh ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKSX = [] -ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom _ ZKX = [] +ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh staticShapeFrom :: IShX sh -> StaticShX sh -staticShapeFrom ZSX = ZKSX -staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh -staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh +staticShapeFrom ZSX = ZKX +staticShapeFrom (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% staticShapeFrom sh lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 @@ -270,35 +329,18 @@ lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) lemKnownNatRank ZSX = Dict -lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX ZKSX = Dict -lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict -lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict - --- lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh --- lemKnownShapeX ZKSX = Dict --- lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict --- lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict - --- lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) --- lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' --- lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' --- | Dict <- lemAppKnownShapeX ssh ssh' --- = Dict --- lemAppKnownShapeX (() :!$? ssh) ssh' --- | Dict <- lemAppKnownShapeX ssh ssh' --- = Dict +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) where go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKSX [] = ZSX - go (n :!$@ ssh) (_ : l) = n :$@ go ssh l - go (() :!$? ssh) (n : l) = n :$? go ssh l + go ZKX [] = ZSX + go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l go _ _ = error "Invalid shapeL" fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a @@ -330,8 +372,7 @@ generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx -indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a index xarr i @@ -344,6 +385,15 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) +-- This should be a function in base +plusSNat :: SNat n -> SNat m -> SNat (n + m) +plusSNat n m = TypeNats.withSomeSNat (TypeNats.fromSNat n + TypeNats.fromSNat m) unsafeCoerce + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (plusSNat n m) + append :: forall n m sh a. Storable a => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append ssh (XArray a) (XArray b) @@ -442,42 +492,34 @@ lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl lemRankDropLen :: forall is sh. (Rank is <= Rank sh) => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is -lemRankDropLen ZKSX HNil = Refl -lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl -lemRankDropLen (_ :!$@ _) HNil = Refl -lemRankDropLen (_ :!$? _) HNil = Refl -lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0" +lemRankDropLen ZKX HNil = Refl +lemRankDropLen (_ :!% sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) HNil = Refl +lemRankDropLen ZKX (_ `HCons` _) = error "1 <= 0" lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l lemIndexSucc _ _ _ = unsafeCoerce Refl ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen HNil _ = ZKSX -ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh -ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh -ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" +ssxTakeLen HNil _ = ZKX +ssxTakeLen (_ `HCons` is) (n :!% sh) = n :!% ssxTakeLen is sh +ssxTakeLen (_ `HCons` _) ZKX = error "Permutation longer than shape" ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) ssxDropLen HNil sh = sh -ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh -ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh -ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" +ssxDropLen (_ `HCons` is) (_ :!% sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` _) ZKX = error "Permutation longer than shape" ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute HNil _ = ZKSX +ssxPermute HNil _ = ZKX ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) -ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest -ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest -ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh') - = ssxIndex p pT i sh rest -ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest - | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh') +ssxIndex _ _ SZ (n :!% _) rest = n :!% rest +ssxIndex p pT (SS (i :: SNat i')) ((_ :: SMayNat () SNat n) :!% (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') = ssxIndex p pT i sh rest -ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" +ssxIndex _ _ _ ZKX _ = error "Index into empty shape" -- | The list argument gives indices into the original dimension list. transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh) @@ -526,7 +568,7 @@ sumInner :: forall sh sh' a. (Storable a, Num a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' | Refl <- lemAppNil @sh - = rerank ssh ssh' ZKSX (scalar . sumFull) + = rerank ssh ssh' ZKX (scalar . sumFull) sumOuter :: forall sh sh' a. (Storable a, Num a) => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a @@ -539,7 +581,7 @@ fromList1 :: forall n sh a. Storable a fromList1 ssh l | Dict <- lemKnownNatRankSSX ssh = case ssh of - m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> + SKnown m@GHC_SNat :!% _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ "does not match the type (" ++ show (natVal m) ++ ")" _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) @@ -575,4 +617,4 @@ reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX reshapePartial ssh1 ssh' sh2 (XArray arr) | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') - = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) + = XArray (S.reshape (shapeLshape sh2 ++ drop (lengthShX sh2) (S.shapeL arr)) arr) diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index d2883a7..e7e2fd6 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -23,11 +23,10 @@ {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# OPTIONS -Wno-unused-imports #-} + {-| TODO: -* We should be more consistent in whether functions take a 'StaticShX' - argument or a 'KnownShapeX' constraint. - * Allow downtyping certain dimensions, and write conversions between Mixed, Ranked and Shaped @@ -89,7 +88,7 @@ import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate) +import Data.Array.Mixed import qualified Data.Array.Mixed as X @@ -179,9 +178,8 @@ lemReplicatePlusApp _ _ _ = go (natSing @n) = sym (X.lemReplicateSucc @a @(n'm1 + m)) shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') -shAppSplit _ ZKSX idx = (ZSX, idx) -shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx) -shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx) +shAppSplit _ ZKX idx = (ZSX, idx) +shAppSplit p (_ :!% ssh) (i :$% idx) = first (i :$%) (shAppSplit p ssh idx) -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -197,11 +195,11 @@ class PrimElt a where fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - default fromPrimitive :: Coercible (Mixed' sh a) (Mixed' sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive (Mixed sh m) = Mixed sh (coerce m) + default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive = coerce - default toPrimitive :: Coercible (Mixed' sh (Primitive a)) (Mixed' sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive (Mixed sh m) = Mixed sh (coerce m) + default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive = coerce -- [PRIMITIVE ELEMENT TYPES LIST] instance PrimElt Int @@ -218,37 +216,31 @@ instance PrimElt () -- -- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type -- class. -data Mixed sh a = Mixed (IShX sh) (Mixed' sh a) -deriving instance Show (Mixed' sh a) => Show (Mixed sh a) - -unMixed :: Mixed sh a -> Mixed' sh a -unMixed (Mixed _ arr) = arr - -type Mixed' :: [Maybe Nat] -> Type -> Type -data family Mixed' sh a +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a -- NOTE: When opening up the Mixed abstraction, you might see dimension sizes -- that you're not supposed to see. In particular, you might see (nonempty) -- sizes of the elements of an empty array, which is information that should -- ostensibly not exist; the full array is still empty. -newtype instance Mixed' sh (Primitive a) = M_Primitive (XArray sh a) +data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) deriving (Show) -- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed' sh Int = M_Int (XArray sh Int) +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show) -newtype instance Mixed' sh Double = M_Double (XArray sh Double) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show) -newtype instance Mixed' sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) -- no content, orthotope optimises this (via Vector) deriving (Show) -- etc. -data instance Mixed' sh (a, b) = M_Tup2 !(Mixed' sh a) !(Mixed' sh b) -deriving instance (Show (Mixed' sh a), Show (Mixed' sh b)) => Show (Mixed' sh (a, b)) +data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) +deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) -- etc. -newtype instance Mixed' sh1 (Mixed sh2 a) = M_Nest (Mixed' (sh1 ++ sh2) a) -deriving instance Show (Mixed' (sh1 ++ sh2) a) => Show (Mixed' sh1 (Mixed sh2 a)) +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(StaticShX sh1) !(Mixed (sh1 ++ sh2) a) +deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -- | Internal helper data family mirroring 'Mixed' that consists of mutable @@ -279,7 +271,7 @@ type family ShapeTree a where ShapeTree () = () ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed' sh a) = (IShX sh, ShapeTree a) + ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) ShapeTree (Ranked n a) = (IShR n, ShapeTree a) ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) @@ -357,37 +349,37 @@ class Elt a where -- Arrays of scalars are basically just arrays of scalars. instance Storable a => Elt (Primitive a) where - mshape (Mixed sh _) = sh - mindex (Mixed _ (M_Primitive a)) i = Primitive (X.index a i) - mindexPartial (Mixed sh (M_Primitive a)) i = Mixed (X.shDropIx sh i) (M_Primitive (X.indexPartial a i)) - mscalar (Primitive x) = Mixed ZSX (M_Primitive (X.scalar x)) + mshape (M_Primitive sh _) = sh + mindex (M_Primitive _ a) i = Primitive (X.index a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (X.shDropIx sh i) (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromList1 sn l@(arr1 :| _) = - let sh = sn :$@ mshape arr1 - in Mixed sh (M_Primitive (X.fromList1 (X.staticShapeFrom sh) (map (coerce . unMixed) (toList l)))) - mtoList1 (Mixed sh (M_Primitive arr)) = map (Mixed (X.shTail sh) . coerce) (X.toList1 arr) + let sh = SKnown sn :$% mshape arr1 + in M_Primitive sh (X.fromList1 (X.staticShapeFrom sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoList1 (M_Primitive sh arr) = map (M_Primitive (X.shTail sh)) (X.toList1 arr) mlift :: forall sh1 sh2. StaticShX sh2 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift ssh2 f (Mixed _ (M_Primitive a)) + mlift ssh2 f (M_Primitive _ a) | Refl <- X.lemAppNil @sh1 , Refl <- X.lemAppNil @sh2 - , let result = f ZKSX a - = Mixed (X.shape ssh2 result) (M_Primitive result) + , let result = f ZKX a + = M_Primitive (X.shape ssh2 result) result mlift2 :: forall sh1 sh2 sh3. StaticShX sh3 -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 ssh3 f (Mixed _ (M_Primitive a)) (Mixed _ (M_Primitive b)) + mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) | Refl <- X.lemAppNil @sh1 , Refl <- X.lemAppNil @sh2 , Refl <- X.lemAppNil @sh3 - , let result = f ZKSX a b - = Mixed (X.shape ssh3 result) (M_Primitive result) + , let result = f ZKX a b + = M_Primitive (X.shape ssh3 result) result - memptyArray sh = Mixed sh (M_Primitive (X.empty sh)) + memptyArray sh = M_Primitive sh (X.empty sh) mshapeTree _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False @@ -400,18 +392,14 @@ instance Storable a => Elt (Primitive a) where mvecsWritePartial :: forall sh' sh s. IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (Mixed sh' (M_Primitive arr)) (MV_Primitive v) = do + mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do let arrsh = X.shape (X.staticShapeFrom sh') arr offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' arrsh)) VS.copy (VSM.slice offset (X.shapeSize arrsh) v) (X.toVector arr) - mvecsFreeze sh (MV_Primitive v) = Mixed sh . M_Primitive . X.fromVector sh <$> VS.freeze v + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v -- [PRIMITIVE ELEMENT TYPES LIST] - - - -TODO -- should rewrite methods of Elt class to take ' in their name, and work on Mixed' instead of Mixed (and take explicit StaticShapeX). Then wrap all of the public functions to work on Mixed. Then don't export the contents of Elt from Nested.hs, and export the wrappers instead. This also makes the haddocks more consistent. deriving via Primitive Int instance Elt Int deriving via Primitive Double instance Elt Double deriving via Primitive () instance Elt () @@ -422,11 +410,12 @@ instance (Elt a, Elt b) => Elt (a, b) where mindex (M_Tup2 a b) i = (mindex a i, mindex b i) mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromList1 l = M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l)) - (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l)) + mfromList1 n l = + M_Tup2 (mfromList1 n ((\(M_Tup2 x _) -> x) <$> l)) + (mfromList1 n ((\(M_Tup2 _ y) -> y) <$> l)) mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b) - mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) - mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) + mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) @@ -443,66 +432,74 @@ instance (Elt a, Elt b) => Elt (a, b) where mvecsWritePartial sh i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + -- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where +instance (Elt a, KnownShX sh') => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest arr) - | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) + mshape (M_Nest ssh arr) + = fst (shAppSplit (Proxy @sh') ssh (mshape arr)) - mindex (M_Nest arr) i = mindexPartial arr i + mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a + mindex (M_Nest _ arr) i = mindexPartial arr i mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest arr) i + mindexPartial (M_Nest ssh arr) i | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (X.ssxDropIx ssh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - mscalar = M_Nest + mscalar = M_Nest ZKX mfromList1 :: forall n sh. SNat n -> NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Just n : sh) (Mixed sh' a) - mfromList1 l - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh')) - = M_Nest (mfromList1 (coerce l)) + mfromList1 sn l@(arr :| _) = + M_Nest (SKnown sn :!% X.staticShapeFrom (mshape arr)) + (mfromList1 sn ((\(M_Nest _ a) -> a) <$> l)) - mtoList1 (M_Nest arr) = coerce (mtoList1 arr) + mtoList1 (M_Nest ssh arr) = map (M_Nest (X.ssxTail ssh)) (mtoList1 arr) mlift :: forall sh1 sh2. - (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift f (M_Nest arr) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - = M_Nest (mlift f' arr) + mlift ssh2 f (M_Nest ssh1 arr) = M_Nest ssh2 (mlift (X.ssxAppend ssh2 ssh') f' arr) where + ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr))) + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' _ + f' sshT | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) - = f (Proxy @(sh' ++ shT)) + = f (X.ssxAppend ssh' sshT) mlift2 :: forall sh1 sh2 sh3. - (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + StaticShX sh3 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 f (M_Nest arr1) (M_Nest arr2) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh')) - = M_Nest (mlift2 f' arr1 arr2) + mlift2 ssh3 f (M_Nest ssh1 arr1) (M_Nest _ arr2) = M_Nest ssh3 (mlift2 (X.ssxAppend ssh3 ssh') f' arr1 arr2) where + ssh' = X.staticShapeFrom (snd (shAppSplit (Proxy @sh') ssh1 (mshape arr1))) + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' _ + f' sshT | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) - = f (Proxy @(sh' ++ shT)) + = f (X.ssxAppend ssh' sshT) - memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh')))) + memptyArray sh = M_Nest (X.staticShapeFrom sh) (memptyArray (X.shAppend sh (X.completeShXzeros (knownShX @sh')))) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (X.staticShapeFrom (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -512,12 +509,11 @@ instance Elt a => Elt (Mixed sh' a) where mvecsUnsafeNew sh example | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example)) - (mindex example (X.zeroIxX (knownShapeX @sh'))) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh sh') (mindex example (X.zeroIxX (X.staticShapeFrom sh'))) where sh' = mshape example - mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs @@ -525,12 +521,11 @@ instance Elt a => Elt (Mixed sh' a) where IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () - mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs + mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (X.shAppend sh12 sh') idx arr vecs - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest (X.staticShapeFrom sh) <$> mvecsFreeze (X.shAppend sh sh') vecs -- | Create an array given a size and a function that computes the element at a @@ -572,27 +567,36 @@ mgenerate sh f = case X.enumShape sh of mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a -mtranspose perm - | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh))) - = mlift $ \(Proxy @sh') -> - X.rerankTop (knownShapeX @sh) (knownShapeX @(X.PermutePrefix is sh)) (knownShapeX @sh') - (X.transpose perm) +mtranspose perm arr = + let ssh = X.staticShapeFrom (mshape arr) + sshPP = X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm ssh)) (X.ssxDropLen perm ssh) + in mlift sshPP (\ssh' -> X.rerankTop ssh sshPP ssh' (X.transpose ssh perm)) arr mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a -mappend = mlift2 go - where go :: forall sh' b. Storable b - => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b - go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append +mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 + where + sn :$% sh = mshape arr1 + sm :$% _ = mshape arr2 + ssh = X.staticShapeFrom sh + snm :: SMayNat () SNat (X.AddMaybe n m) + snm = case (sn, sm) of + (SUnknown{}, _) -> SUnknown () + (SKnown{}, SUnknown{}) -> SUnknown () + (SKnown n, SKnown m) -> SKnown (X.plusSNat n m) + + f :: forall sh' b. Storable b + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b + f ssh' = X.append (X.ssxAppend ssh ssh') mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive (X.fromVector sh v) +mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) mfromVector :: forall sh a. (Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a mfromVector sh v = fromPrimitive (mfromVectorP sh v) mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive v) = X.toVector v +mtoVectorP (M_Primitive _ v) = X.toVector v mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a mtoVector arr = mtoVectorP (coerce toPrimitive arr) @@ -607,64 +611,60 @@ munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX mconstantP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mconstantP sh x = M_Primitive (X.constant sh x) +mconstantP sh x = M_Primitive sh (X.constant sh x) mconstant :: forall sh a. (Storable a, PrimElt a) => IShX sh -> a -> Mixed sh a mconstant sh x = fromPrimitive (mconstantP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n = withKnownNat n $ mlift $ \_ -> X.slice i n +mslice i n arr = + let _ :$% sh = mshape arr + in withKnownNat n $ mlift (SKnown n :!% X.staticShapeFrom sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n = mlift $ \_ -> X.sliceU i n +msliceU i n arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.sliceU i n) arr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 = mlift $ \_ -> X.rev1 +mrev1 arr = mlift (X.staticShapeFrom (mshape arr)) (\_ -> X.rev1) arr mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' = mlift $ \(_ :: Proxy shIn) -> - X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh' +mreshape sh' arr = + mlift (X.staticShapeFrom sh') + (\sshIn -> X.reshapePartial (X.staticShapeFrom (mshape arr)) sshIn sh') + arr -masXArrayPrimP :: Mixed sh (Primitive a) -> XArray sh a -masXArrayPrimP (M_Primitive arr) = arr +masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) +masXArrayPrimP (M_Primitive sh arr) = (sh, arr) -masXArrayPrim :: PrimElt a => Mixed sh a -> XArray sh a +masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) masXArrayPrim = masXArrayPrimP . toPrimitive -mfromXArrayPrimP :: XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP :: IShX sh -> XArray sh a -> Mixed sh (Primitive a) mfromXArrayPrimP = M_Primitive -mfromXArrayPrim :: PrimElt a => XArray sh a -> Mixed sh a -mfromXArrayPrim = fromPrimitive . mfromXArrayPrimP +mfromXArrayPrim :: PrimElt a => IShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP -mliftPrim :: Storable a +mliftPrim :: (Storable a, PrimElt a) => (a -> a) - -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) + -> Mixed sh a -> Mixed sh a +mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) -mliftPrim2 :: Storable a +mliftPrim2 :: (Storable a, PrimElt a) => (a -> a -> a) - -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = - M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) + -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = + fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) -instance (Storable a, Num a) => Num (Mixed sh (Primitive a)) where +instance (Storable a, Num a, PrimElt a) => Num (Mixed sh a) where (+) = mliftPrim2 (+) (-) = mliftPrim2 (-) (*) = mliftPrim2 (*) negate = mliftPrim negate abs = mliftPrim abs signum = mliftPrim signum - fromInteger n = - case X.ssxToShape' (knownShapeX @sh) of - Just sh -> M_Primitive (X.constant sh (fromInteger n)) - Nothing -> error "Data.Array.Nested.fromIntegral: \ - \Unknown components in shape, use explicit mconstant" - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Mixed sh (Primitive Int) instance Num (Mixed sh Int) -deriving via Mixed sh (Primitive Double) instance Num (Mixed sh Double) + fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mconstant" -- | A rank-typed array: the number of dimensions of the array (its /rank/) is @@ -694,10 +694,10 @@ newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) -- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed' sh (Ranked n a) = M_Ranked (Mixed' sh (Mixed (Replicate n Nothing) a)) -deriving instance Show (Mixed' sh (Mixed (Replicate n Nothing) a)) => Show (Mixed' sh (Ranked n a)) -newtype instance Mixed' sh (Shaped sh' a) = M_Shaped (Mixed' sh (Mixed (MapJust sh' ) a)) -deriving instance Show (Mixed' sh (Mixed (MapJust sh' ) a)) => Show (Mixed' sh (Shaped sh' a)) +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) +deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a)) newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) @@ -804,6 +804,8 @@ instance (Elt a, KnownNat n) => Elt (Ranked n a) where -} -- | The shape of a shape-typed array given as a list of 'SNat' values. +TODO -- write ListS and implement IxS and ShS in terms of it. +TODO -- for ListR and ListS, write an uncons function like for ListX and implement the cons pattern synonyms in terms of it directly, instead of using a separate uncons function for both types. data ShS sh where ZSS :: ShS '[] (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh) |