diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-14 11:38:22 +0200 | 
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-05-14 11:38:22 +0200 | 
| commit | ffa8dacb1d7ea53438f784bf5f8b425b8cd48f46 (patch) | |
| tree | 154026b503f334bf14851a3826d1562c5cfb9d08 /src/Data/Array/Mixed | |
| parent | 978919b5be0fe74f5da1e071e97557d2e30f0ad2 (diff) | |
Split and uniformly rename Shape modules
Diffstat (limited to 'src/Data/Array/Mixed')
| -rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 611 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/XArray.hs | 2 | 
4 files changed, 3 insertions, 614 deletions
| diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index 560f762..ca82573 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -13,7 +13,7 @@ import Data.Type.Equality  import GHC.TypeLits  import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape +import Data.Array.Nested.Mixed.Shape  import Data.Array.Mixed.Types diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index cedfa22..22672cb 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -29,7 +29,7 @@ import GHC.TypeError  import GHC.TypeLits  import GHC.TypeNats qualified as TN -import Data.Array.Mixed.Shape +import Data.Array.Nested.Mixed.Shape  import Data.Array.Mixed.Types diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs deleted file mode 100644 index eb8434f..0000000 --- a/src/Data/Array/Mixed/Shape.hs +++ /dev/null @@ -1,611 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE NoStarIsType #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed.Shape where - -import Control.DeepSeq (NFData(..)) -import Data.Bifunctor (first) -import Data.Coerce -import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product -import Data.Kind (Constraint, Type) -import Data.Monoid (Sum(..)) -import Data.Proxy -import Data.Type.Equality -import GHC.Exts (withDict) -import GHC.Generics (Generic) -import GHC.IsList (IsList) -import GHC.IsList qualified as IsList -import GHC.TypeLits - -import Data.Array.Mixed.Types - - --- | The length of a type-level list. If the argument is a shape, then the --- result is the rank of that shape. -type family Rank sh where -  Rank '[] = 0 -  Rank (_ : sh) = Rank sh + 1 - - --- * Mixed lists - -type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where -  ZX :: ListX '[] f -  (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) -infixr 3 ::% - -instance (forall n. Show (f n)) => Show (ListX sh f) where -  showsPrec _ = listxShow shows - -instance (forall n. NFData (f n)) => NFData (ListX sh f) where -  rnf ZX = () -  rnf (x ::% l) = rnf x `seq` rnf l - -data UnconsListXRes f sh1 = -  forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) -listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) -listxUncons ZX = Nothing - --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqType ZX ZX = Just Refl -listxEqType (n ::% sh) (m ::% sh') -  | Just Refl <- testEquality n m -  , Just Refl <- listxEqType sh sh' -  = Just Refl -listxEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqual ZX ZX = Just Refl -listxEqual (n ::% sh) (m ::% sh') -  | Just Refl <- testEquality n m -  , n == m -  , Just Refl <- listxEqual sh sh' -  = Just Refl -listxEqual _ _ = Nothing - -listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -listxFmap _ ZX = ZX -listxFmap f (x ::% xs) = f x ::% listxFmap f xs - -listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFold _ ZX = mempty -listxFold f (x ::% xs) = f x <> listxFold f xs - -listxLength :: ListX sh f -> Int -listxLength = getSum . listxFold (\_ -> Sum 1) - -listxRank :: ListX sh f -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat - -listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS -listxShow f l = showString "[" . go "" l . showString "]" -  where -    go :: String -> ListX sh' f -> ShowS -    go _ ZX = id -    go prefix (x ::% xs) = showString prefix . f x . go "," xs - -listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) -listxFromList topssh topl = go topssh topl -  where -    go :: StaticShX sh' -> [i] -> ListX sh' (Const i) -    go ZKX [] = ZX -    go (_ :!% sh) (i : is) = Const i ::% go sh is -    go _ _ = error $ "listxFromList: Mismatched list length (type says " -                       ++ show (ssxLength topssh) ++ ", list has length " -                       ++ show (length topl) ++ ")" - -listxToList :: ListX sh' (Const i) -> [i] -listxToList ZX = [] -listxToList (Const i ::% is) = i : listxToList is - -listxHead :: ListX (mn ': sh) f -> f mn -listxHead (i ::% _) = i - -listxTail :: ListX (n : sh) i -> ListX sh i -listxTail (_ ::% sh) = sh - -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f -listxAppend ZX idx' = idx' -listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' - -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short - -listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f -listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh -listxInit (_ ::% ZX) = ZX - -listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) -listxLast (_ ::% sh@(_ ::% _)) = listxLast sh -listxLast (x ::% ZX) = x - -listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) -listxZip ZX ZX = ZX -listxZip (i ::% irest) (j ::% jrest) = -  Pair i j ::% listxZip irest jrest - -listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g -             -> ListX sh h -listxZipWith _ ZX ZX = ZX -listxZipWith f (i ::% is) (j ::% js) = -  f i j ::% listxZipWith f is js - - --- * Mixed indices - --- | This is a newtype over 'ListX'. -type role IxX nominal representational -type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) -  deriving (Eq, Ord, Generic) - -pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i -pattern ZIX = IxX ZX - -pattern (:.%) -  :: forall {sh1} {i}. -     forall n sh. (n : sh ~ sh1) -  => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) -  where i :.% IxX shl = IxX (Const i ::% shl) -infixr 3 :.% - -{-# COMPLETE ZIX, (:.%) #-} - -type IIxX sh = IxX sh Int - -instance Show i => Show (IxX sh i) where -  showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l - -instance Functor (IxX sh) where -  fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) - -instance Foldable (IxX sh) where -  foldMap f (IxX l) = listxFold (f . getConst) l - -instance NFData i => NFData (IxX sh i) - -ixxLength :: IxX sh i -> Int -ixxLength (IxX l) = listxLength l - -ixxRank :: IxX sh i -> SNat (Rank sh) -ixxRank (IxX l) = listxRank l - -ixxZero :: StaticShX sh -> IIxX sh -ixxZero ZKX = ZIX -ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh - -ixxZero' :: IShX sh -> IIxX sh -ixxZero' ZSX = ZIX -ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh - -ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i -ixxFromList = coerce (listxFromList @_ @i) - -ixxHead :: IxX (n : sh) i -> i -ixxHead (IxX list) = getConst (listxHead list) - -ixxTail :: IxX (n : sh) i -> IxX sh i -ixxTail (IxX list) = IxX (listxTail list) - -ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixxAppend = coerce (listxAppend @_ @(Const i)) - -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i -ixxDrop = coerce (listxDrop @(Const i) @(Const i)) - -ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i -ixxInit = coerce (listxInit @(Const i)) - -ixxLast :: forall n sh i. IxX (n : sh) i -> i -ixxLast = coerce (listxLast @(Const i)) - -ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) -ixxZip ZIX ZIX = ZIX -ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js - -ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k -ixxZipWith _ ZIX ZIX = ZIX -ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js - -ixxFromLinear :: IShX sh -> Int -> IIxX sh -ixxFromLinear = \sh i -> case go sh i of -  (idx, 0) -> idx -  _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ -               " in array of shape " ++ show sh ++ ")" -  where -    -- returns (index in subarray, remaining index in enclosing array) -    go :: IShX sh -> Int -> (IIxX sh, Int) -    go ZSX i = (ZIX, i) -    go (n :$% sh) i = -      let (idx, i') = go sh i -          (upi, locali) = i' `quotRem` fromSMayNat' n -      in (locali :.% idx, upi) - -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) -  where -    -- returns (index in subarray, size of subarray) -    go :: IShX sh -> IIxX sh -> (Int, Int) -    go ZSX ZIX = (0, 1) -    go (n :$% sh) (i :.% ix) = -      let (lidx, sz) = go sh ix -      in (sz * i + lidx, fromSMayNat' n * sz) - - --- * Mixed shapes - -data SMayNat i f n where -  SUnknown :: i -> SMayNat i f Nothing -  SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) - -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where -  rnf (SUnknown i) = rnf i -  rnf (SKnown x) = rnf x - -instance TestEquality f => TestEquality (SMayNat i f) where -  testEquality SUnknown{} SUnknown{} = Just Refl -  testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl -  testEquality _ _ = Nothing - -fromSMayNat :: (n ~ Nothing => i -> r) -            -> (forall m. n ~ Just m => f m -> r) -            -> SMayNat i f n -> r -fromSMayNat f _ (SUnknown i) = f i -fromSMayNat _ g (SKnown s) = g s - -fromSMayNat' :: SMayNat Int SNat n -> Int -fromSMayNat' = fromSMayNat id fromSNat' - -type family AddMaybe n m where -  AddMaybe Nothing _ = Nothing -  AddMaybe (Just _) Nothing = Nothing -  AddMaybe (Just n) (Just m) = Just (n + m) - -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) -smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) -smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) - - --- | This is a newtype over 'ListX'. -type role ShX nominal representational -type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) -  deriving (Eq, Ord, Generic) - -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX - -pattern (:$%) -  :: forall {sh1} {i}. -     forall n sh. (n : sh ~ sh1) -  => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) -  where i :$% ShX shl = ShX (i ::% shl) -infixr 3 :$% - -{-# COMPLETE ZSX, (:$%) #-} - -type IShX sh = ShX sh Int - -instance Show i => Show (ShX sh i) where -  showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l - -instance Functor (ShX sh) where -  fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) - -instance NFData i => NFData (ShX sh i) where -  rnf (ShX ZX) = () -  rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) -  rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) - --- | This checks only whether the types are equal; unknown dimensions might --- still differ. This corresponds to 'testEquality', except on the penultimate --- type parameter. -shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') -shxEqType ZSX ZSX = Just Refl -shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') -  | Just Refl <- sameNat n m -  , Just Refl <- shxEqType sh sh' -  = Just Refl -shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') -  | Just Refl <- shxEqType sh sh' -  = Just Refl -shxEqType _ _ = Nothing - --- | This checks whether all dimensions have the same value. This is more than --- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the --- @some@ package (except on the penultimate type parameter). -shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') -shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') -  | Just Refl <- sameNat n m -  , Just Refl <- shxEqual sh sh' -  = Just Refl -shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') -  | i == j -  , Just Refl <- shxEqual sh sh' -  = Just Refl -shxEqual _ _ = Nothing - -shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l - -shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listxRank l - --- | The number of elements in an array described by this shape. -shxSize :: IShX sh -> Int -shxSize ZSX = 1 -shxSize (n :$% sh) = fromSMayNat' n * shxSize sh - -shxFromList :: StaticShX sh -> [Int] -> ShX sh Int -shxFromList topssh topl = go topssh topl -  where -    go :: StaticShX sh' -> [Int] -> ShX sh' Int -    go ZKX [] = ZSX -    go (SKnown sn :!% sh) (i : is) -      | i == fromSNat' sn = SKnown sn :$% go sh is -      | otherwise = error $ "shxFromList: Value does not match typing (type says " -                              ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" -    go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is -    go _ _ = error $ "shxFromList: Mismatched list length (type says " -                       ++ show (ssxLength topssh) ++ ", list has length " -                       ++ show (length topl) ++ ")" - -shxToList :: IShX sh -> [Int] -shxToList ZSX = [] -shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh - -shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) - -shxHead :: ShX (n : sh) i -> SMayNat i SNat n -shxHead (ShX list) = listxHead list - -shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listxTail list) - -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) - -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) - -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) - -shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) - -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) - -shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go -  where -    go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i -    go ZKX _ = ZSX -    go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh - -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) -           -> ShX sh i -> ShX sh j -> ShX sh k -shxZipWith _ ZSX ZSX = ZSX -shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js - --- This is a weird operation, so it has a long name -shxCompleteZeros :: StaticShX sh -> IShX sh -shxCompleteZeros ZKX = ZSX -shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh -shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh - -shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) -shxSplitApp _ ZKX idx = (ZSX, idx) -shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) - -shxEnum :: IShX sh -> [IIxX sh] -shxEnum = \sh -> go sh id [] -  where -    go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] -    go ZSX f = (f ZIX :) -    go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] - -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n   :$% sh) (SKnown m    :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m    :!% ssh) | n == fromSNat' m              = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n   :$% sh) (SUnknown () :!% ssh)                                 = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh)                                 = (SUnknown n :$%) <$> shxCast sh ssh -shxCast _ _ = Nothing - --- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of -  Just sh' -> sh' -  Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" - - --- * Static mixed shapes - --- | The part of a shape that is statically known. (A newtype over 'ListX'.) -type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) -  deriving (Eq, Ord) - -pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX - -pattern (:!%) -  :: forall {sh1}. -     forall n sh. (n : sh ~ sh1) -  => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) -  where i :!% StaticShX shl = StaticShX (i ::% shl) -infixr 3 :!% - -{-# COMPLETE ZKX, (:!%) #-} - -instance Show (StaticShX sh) where -  showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l - -instance NFData (StaticShX sh) where -  rnf (StaticShX ZX) = () -  rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) -  rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) - -instance TestEquality StaticShX where -  testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 - -ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listxLength l - -ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listxRank l - --- | @ssxEqType = 'testEquality'@. Provided for consistency. -ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') -ssxEqType = testEquality - -ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' - -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n -ssxHead (StaticShX list) = listxHead list - -ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh - -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) - -ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) - -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) - --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShX' :: StaticShX sh -> Maybe (IShX sh) -ssxToShX' ZKX = Just ZSX -ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh -ssxToShX' (SUnknown _ :!% _) = Nothing - -ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) -ssxReplicate SZ = ZKX -ssxReplicate (SS (n :: SNat n')) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n' -  = SUnknown () :!% ssxReplicate n - -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh - -ssxFromShape :: IShX sh -> StaticShX sh -ssxFromShape ZSX = ZKX -ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh - -ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) -ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n - - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShX :: [Maybe Nat] -> Constraint -class KnownShX sh where knownShX :: StaticShX sh -instance KnownShX '[] where knownShX = ZKX -instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX -instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX - -withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r -withKnownShX k = withDict @(KnownShX sh) k - - --- * Flattening - -type Flatten sh = Flatten' 1 sh - -type family Flatten' acc sh where -  Flatten' acc '[] = Just acc -  Flatten' acc (Nothing : sh) = Nothing -  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh - --- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) -ssxFlatten = go (SNat @1) -  where -    go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) -    go acc ZKX = SKnown acc -    go _ (SUnknown () :!% _) = SUnknown () -    go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh - -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) -shxFlatten = go (SNat @1) -  where -    go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) -    go acc ZSX = SKnown acc -    go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) -    go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh - -    goUnknown :: Int -> IShX sh -> Int -    goUnknown acc ZSX = acc -    goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh -    goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh - - --- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where -  type Item (ListX sh (Const i)) = i -  fromList = listxFromList (knownShX @sh) -  toList = listxToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where -  type Item (IxX sh i) = i -  fromList = IxX . IsList.fromList -  toList = Foldable.toList - --- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where -  type Item (ShX sh Int) = Int -  fromList = shxFromList (knownShX @sh) -  toList = shxToList diff --git a/src/Data/Array/Mixed/XArray.hs b/src/Data/Array/Mixed/XArray.hs index cb790e1..3e7a498 100644 --- a/src/Data/Array/Mixed/XArray.hs +++ b/src/Data/Array/Mixed/XArray.hs @@ -34,7 +34,7 @@ import GHC.TypeLits  import Data.Array.Mixed.Internal.Arith  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape +import Data.Array.Nested.Mixed.Shape  import Data.Array.Mixed.Types | 
