diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/ListX.hs | 139 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 707 |
2 files changed, 501 insertions, 345 deletions
diff --git a/src/Data/Array/Nested/Mixed/ListX.hs b/src/Data/Array/Nested/Mixed/ListX.hs new file mode 100644 index 0000000..2c8a9cc --- /dev/null +++ b/src/Data/Array/Nested/Mixed/ListX.hs @@ -0,0 +1,139 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Mixed.ListX (ListX, pattern ZX, pattern (::%), listxShow, lazily, lazilyConcat, lazilyForce, Rank, coerceEqualRankListX) where + +import Control.DeepSeq (NFData(..)) +import Data.Foldable qualified as Foldable +import Data.Kind (Type) +import Data.Type.Equality +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import GHC.TypeLits.Orphans () +#endif + +import Data.Array.Nested.Types + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists implementation + +-- | Data invariant: each element on the list is in WHNF (the spine may be lazy) +-- and the length of the list is the same as of the type-level shape. +type role ListX nominal representational +type ListX :: [Maybe Nat] -> Type -> Type +newtype ListX sh i = ListX [i] + deriving (Eq, Ord, NFData, Foldable) + +{-# INLINE ZX #-} +pattern ZX :: forall sh i. () => sh ~ '[] => ListX sh i +pattern ZX <- (listxNull -> Just Refl) + where ZX = ListX [] + +{-# INLINE listxNull #-} +listxNull :: ListX sh i -> Maybe (sh :~: '[]) +listxNull (ListX []) = Just unsafeCoerceRefl +listxNull (ListX (_ : _)) = Nothing + +{-# INLINE (::%) #-} +pattern (::%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> ListX sh i -> ListX sh1 i +pattern i ::% l <- (listxUncons -> Just (UnconsListXRes i l)) + where !i ::% ListX !l = ListX (i : l) +infixr 3 ::% + +data UnconsListXRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes i (ListX sh i) +{-# INLINE listxUncons #-} +listxUncons :: forall sh1 i. ListX sh1 i -> Maybe (UnconsListXRes i sh1) +listxUncons (ListX (i : l)) = gcastWith (unsafeCoerceRefl :: Head sh1 ': Tail sh1 :~: sh1) $ + Just (UnconsListXRes i (ListX @(Tail sh1) l)) +listxUncons (ListX []) = Nothing + +{-# COMPLETE ZX, (::%) #-} + +-- | This function makes no attempt to stop you from breaking the data invariant for 'ListX'. If you do so, you must later ensure that the invariant is reinstated, for example using @'lazilyForce' 'id'@. +{-# INLINE lazily #-} +lazily :: ([a] -> [b]) -> ListX sh a -> ListX sh b +lazily f (ListX l) = ListX $ f l + +-- | This function makes no attempt to stop you from breaking the data invariant for 'ListX'. If you do so, you must later ensure that the invariant is reinstated, for example using @'lazilyForce' 'id'@. +{-# INLINE lazilyConcat #-} +lazilyConcat :: ([a] -> [b] -> [c]) -> ListX sh a -> ListX sh' b -> ListX (sh ++ sh') c +lazilyConcat f (ListX l) (ListX k) = ListX $ f l k + +-- | This operation forces all elements of the @[b]@ list to restore the strictness part of the data invariant for 'ListX'. Note that ensuring the list has the right length is still the user's responsibility. +{-# INLINE lazilyForce #-} +lazilyForce :: ([a] -> [b]) -> ListX sh a -> ListX sh b +lazilyForce f (ListX l) = let res = f l + in foldr seq () res `seq` ListX res + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListX sh i) +#else +instance Show i => Show (ListX sh i) where + showsPrec _ = listxShow shows +#endif + +{-# INLINE listxShow #-} +listxShow :: forall sh i. (i -> ShowS) -> ListX sh i -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' i -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +-- This can't be derived, becauses the list needs to be fully evaluated, +-- per data invariant. This version is faster than versions defined using +-- (::%) or lazilyForce. +instance Functor (ListX l) where + {-# INLINE fmap #-} + fmap f (ListX l) = + let fmap' [] = [] + fmap' (x : xs) = let y = f x + rest = fmap' xs + in y `seq` rest `seq` (y : rest) + in ListX $ fmap' l + +-- | Very untyped: not even length is checked (at runtime). +instance IsList (ListX sh i) where + type Item (ListX sh i) = i + {-# INLINE fromList #-} + fromList l = foldr seq () l `seq` ListX l + {-# INLINE toList #-} + toList = Foldable.toList + +coerceEqualRankListX :: Rank sh ~ Rank sh' => ListX sh i -> ListX sh' i +coerceEqualRankListX (ListX l) = ListX l diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index 5f4775c..a01e0f3 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -1,12 +1,15 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -16,162 +19,44 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Mixed.Shape where +module Data.Array.Nested.Mixed.Shape ( + module Data.Array.Nested.Mixed.Shape, + Rank, +) where import Control.DeepSeq (NFData(..)) +import Control.Exception (assert) import Data.Bifunctor (first) import Data.Coerce import Data.Foldable qualified as Foldable -import Data.Functor.Const -import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) import Data.Proxy import Data.Type.Equality -import GHC.Exts (withDict) -import GHC.Generics (Generic) +import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits +#if !MIN_VERSION_GLASGOW_HASKELL(9,8,0,0) +import GHC.TypeLits.Orphans () +#endif -import Data.Array.Mixed.Types - - --- | The length of a type-level list. If the argument is a shape, then the --- result is the rank of that shape. -type family Rank sh where - Rank '[] = 0 - Rank (_ : sh) = Rank sh + 1 - - --- * Mixed lists - -type role ListX nominal representational -type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListX sh f where - ZX :: ListX '[] f - (::%) :: f n -> ListX sh f -> ListX (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) -infixr 3 ::% - -instance (forall n. Show (f n)) => Show (ListX sh f) where - showsPrec _ = listxShow shows - -instance (forall n. NFData (f n)) => NFData (ListX sh f) where - rnf ZX = () - rnf (x ::% l) = rnf x `seq` rnf l - -data UnconsListXRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) -listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) -listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) -listxUncons ZX = Nothing - --- | This checks only whether the types are equal; if the elements of the list --- are not singletons, their values may still differ. This corresponds to --- 'testEquality', except on the penultimate type parameter. -listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqType ZX ZX = Just Refl -listxEqType (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , Just Refl <- listxEqType sh sh' - = Just Refl -listxEqType _ _ = Nothing - --- | This checks whether the two lists actually contain equal values. This is --- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ --- in the @some@ package (except on the penultimate type parameter). -listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') -listxEqual ZX ZX = Just Refl -listxEqual (n ::% sh) (m ::% sh') - | Just Refl <- testEquality n m - , n == m - , Just Refl <- listxEqual sh sh' - = Just Refl -listxEqual _ _ = Nothing - -listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g -listxFmap _ ZX = ZX -listxFmap f (x ::% xs) = f x ::% listxFmap f xs - -listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m -listxFold _ ZX = mempty -listxFold f (x ::% xs) = f x <> listxFold f xs - -listxLength :: ListX sh f -> Int -listxLength = getSum . listxFold (\_ -> Sum 1) - -listxRank :: ListX sh f -> SNat (Rank sh) -listxRank ZX = SNat -listxRank (_ ::% l) | SNat <- listxRank l = SNat - -listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS -listxShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListX sh' f -> ShowS - go _ ZX = id - go prefix (x ::% xs) = showString prefix . f x . go "," xs - -listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) -listxFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [i] -> ListX sh' (Const i) - go ZKX [] = ZX - go (_ :!% sh) (i : is) = Const i ::% go sh is - go _ _ = error $ "listxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" - -listxToList :: ListX sh' (Const i) -> [i] -listxToList ZX = [] -listxToList (Const i ::% is) = i : listxToList is - -listxHead :: ListX (mn ': sh) f -> f mn -listxHead (i ::% _) = i - -listxTail :: ListX (n : sh) i -> ListX sh i -listxTail (_ ::% sh) = sh - -listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f -listxAppend ZX idx' = idx' -listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' - -listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f -listxDrop long ZX = long -listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short - -listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f -listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh -listxInit (_ ::% ZX) = ZX - -listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) -listxLast (_ ::% sh@(_ ::% _)) = listxLast sh -listxLast (x ::% ZX) = x - -listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) -listxZip ZX ZX = ZX -listxZip (i ::% irest) (j ::% jrest) = - Pair i j ::% listxZip irest jrest - -listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g - -> ListX sh h -listxZipWith _ ZX ZX = ZX -listxZipWith f (i ::% is) (j ::% js) = - f i j ::% listxZipWith f is js +import Data.Array.Nested.Mixed.ListX +import Data.Array.Nested.Types -- * Mixed indices --- | This is a newtype over 'ListX'. +-- | An index into a mixed-typed array. type role IxX nominal representational type IxX :: [Maybe Nat] -> Type -> Type -newtype IxX sh i = IxX (ListX sh (Const i)) - deriving (Eq, Ord, Generic) +newtype IxX sh i = IxX (ListX sh i) + deriving (Eq, Ord, NFData, Functor, Foldable) pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i pattern ZIX = IxX ZX @@ -180,30 +65,30 @@ pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i -pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) - where i :.% IxX shl = IxX (Const i ::% shl) +pattern i :.% l <- IxX (i ::% (IxX -> l)) + where i :.% IxX l = IxX (i ::% l) infixr 3 :.% {-# COMPLETE ZIX, (:.%) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxX sh = IxX sh Int +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxX sh i) +#else instance Show i => Show (IxX sh i) where - showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l - -instance Functor (IxX sh) where - fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) - -instance Foldable (IxX sh) where - foldMap f (IxX l) = listxFold (f . getConst) l + showsPrec _ (IxX l) = listxShow shows l +#endif -instance NFData i => NFData (IxX sh i) - -ixxLength :: IxX sh i -> Int -ixxLength (IxX l) = listxLength l +{-# INLINE ixxFromList #-} +ixxFromList :: StaticShX sh -> [i] -> IxX sh i +ixxFromList sh l = assert (ssxLength sh == length l) $ IsList.fromList l ixxRank :: IxX sh i -> SNat (Rank sh) -ixxRank (IxX l) = listxRank l +ixxRank ZIX = SNat +ixxRank (_ :.% l) | SNat <- ixxRank l = SNat ixxZero :: StaticShX sh -> IIxX sh ixxZero ZKX = ZIX @@ -213,85 +98,131 @@ ixxZero' :: IShX sh -> IIxX sh ixxZero' ZSX = ZIX ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh -ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i -ixxFromList = coerce (listxFromList @_ @i) - -ixxHead :: IxX (n : sh) i -> i -ixxHead (IxX list) = getConst (listxHead list) +ixxHead :: IxX (mn ': sh) i -> i +ixxHead (i :.% _) = i ixxTail :: IxX (n : sh) i -> IxX sh i -ixxTail (IxX list) = IxX (listxTail list) +ixxTail (_ :.% sh) = sh ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i -ixxAppend = coerce (listxAppend @_ @(Const i)) +ixxAppend (IxX l1) (IxX l2) = IxX $ lazilyConcat (++) l1 l2 -ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i -ixxDrop = coerce (listxDrop @(Const i) @(Const i)) +ixxDrop :: forall i j sh sh'. IxX sh j -> IxX (sh ++ sh') i -> IxX sh' i +ixxDrop ZIX long = long +ixxDrop (_ :.% short) long = case long of _ :.% long' -> ixxDrop short long' -ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i -ixxInit = coerce (listxInit @(Const i)) +ixxInit :: forall i n sh. IxX (n : sh) i -> IxX (Init (n : sh)) i +ixxInit (i :.% sh@(_ :.% _)) = i :.% ixxInit sh +ixxInit (_ :.% ZIX) = ZIX -ixxLast :: forall n sh i. IxX (n : sh) i -> i -ixxLast = coerce (listxLast @(Const i)) +ixxLast :: forall i n sh. IxX (n : sh) i -> i +ixxLast (_ :.% sh@(_ :.% _)) = ixxLast sh +ixxLast (x :.% ZIX) = x ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) ixxZip ZIX ZIX = ZIX ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js +{-# INLINE ixxZipWith #-} ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k ixxZipWith _ ZIX ZIX = ZIX ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js -ixxFromLinear :: IShX sh -> Int -> IIxX sh -ixxFromLinear = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixxToLinear #-} +ixxToLinear :: Num i => IShX sh -> IxX sh i -> i +ixxToLinear = \sh i -> go sh i 0 where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$% sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSMayNat' n - in (locali :.% idx, upi) + go :: Num i => IShX sh -> IxX sh i -> i -> i + go ZSX ZIX !a = a + go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) -ixxToLinear :: IShX sh -> IIxX sh -> Int -ixxToLinear = \sh i -> fst (go sh i) +{-# INLINEABLE ixxFromLinear #-} +ixxFromLinear :: Num i => IShX sh -> Int -> IxX sh i +ixxFromLinear = \sh -> -- give this function arity 1 so that suffixes is shared when it's called many times + let suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + in fromLin0 sh suffixes where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$% sh) (i :.% ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSMayNat' n * sz) + -- Unfold first iteration of fromLin to do the range check. + -- Don't inline this function at first to allow GHC to inline the outer + -- function and realise that 'suffixes' is shared. But then later inline it + -- anyway, to avoid the function call. Removing the pragma makes GHC + -- somehow unable to recognise that 'suffixes' can be shared in a loop. + {-# NOINLINE [0] fromLin0 #-} + fromLin0 :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin0 sh suffixes i = + if i < 0 then outrange sh i else + case (sh, suffixes) of + (ZSX, _) | i > 0 -> outrange sh i + | otherwise -> ZIX + ((fromSMayNat' -> n) :$% sh', suff : suffs) -> + let (q, r) = i `quotRem` suff + in if q >= n then outrange sh i else + fromIntegral q :.% fromLin sh' suffs r + _ -> error "impossible" + + fromLin :: Num i => IShX sh -> [Int] -> Int -> IxX sh i + fromLin ZSX _ !_ = ZIX + fromLin (_ :$% sh') (suff : suffs) i = + let (q, r) = i `quotRem` suff -- suff == shrSize sh' + in fromIntegral q :.% fromLin sh' suffs r + fromLin _ _ _ = error "impossible" + + {-# NOINLINE outrange #-} + outrange :: IShX sh -> Int -> a + outrange sh i = error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = shxEnum' + +{-# INLINABLE shxEnum' #-} -- ensure this can be specialised at use site +shxEnum' :: Num i => IShX sh -> [IxX sh i] +shxEnum' sh = [fromLin sh suffixes li# | I# li# <- [0 .. shxSize sh - 1]] + where + suffixes = drop 1 (scanr (*) 1 (shxToList sh)) + + fromLin :: Num i => IShX sh -> [Int] -> Int# -> IxX sh i + fromLin ZSX _ _ = ZIX + fromLin (_ :$% sh') (I# suff# : suffs) i# = + let !(# q#, r# #) = i# `quotRemInt#` suff# -- suff == shrSize sh' + in fromIntegral (I# q#) :.% fromLin sh' suffs r# + fromLin _ _ _ = error "impossible" -- * Mixed shapes -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +instance NFData i => NFData (SMayNat i n) where rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x + rnf (SKnown SNat) = () -instance TestEquality f => TestEquality (SMayNat i f) where +instance TestEquality (SMayNat i) where testEquality SUnknown{} SUnknown{} = Just Refl testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing +{-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r fromSMayNat f _ (SUnknown i) = f i fromSMayNat _ g (SKnown s) = g s -fromSMayNat' :: SMayNat Int SNat n -> Int +{-# INLINE fromSMayNat' #-} +fromSMayNat' :: SMayNat Int n -> Int fromSMayNat' = fromSMayNat id fromSNat' type family AddMaybe n m where @@ -299,134 +230,226 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) --- | This is a newtype over 'ListX'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) - deriving (Eq, Ord, Generic) - -pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX - -pattern (:$%) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) -infixr 3 :$% - -{-# COMPLETE ZSX, (:$%) #-} +data ShX sh i where + ZSX :: ShX '[] i + ConsUnknown :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i +-- TODO: bring this UNPACK back when GHC no longer crashes: +-- ConsKnown :: forall n sh i. {-# UNPACK #-} SNat n -> ShX sh i -> ShX (Just n : sh) i + ConsKnown :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i +deriving instance Ord i => Ord (ShX sh i) -type IShX sh = ShX sh Int +-- A manually defined instance and this INLINEABLE is needed to specialize +-- mdot1Inner (otherwise GHC warns specialization breaks down here). +instance Eq i => Eq (ShX sh i) where + {-# INLINEABLE (==) #-} + ZSX == ZSX = True + ConsUnknown i1 sh1 == ConsUnknown i2 sh2 = i1 == i2 && sh1 == sh2 + ConsKnown _ sh1 == ConsKnown _ sh2 = sh1 == sh2 +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShX sh i) +#else instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ l = shxShow (fromSMayNat shows (shows . fromSNat)) l +#endif + +instance NFData i => NFData (ShX sh i) where + rnf ZSX = () + rnf (x `ConsUnknown` l) = rnf x `seq` rnf l + rnf (SNat `ConsKnown` l) = rnf l instance Functor (ShX sh) where - fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + {-# INLINE fmap #-} + fmap f l = shxFmap (fromSMayNat (SUnknown . f) SKnown) l -instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) +data UnconsShXRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShXRes (SMayNat i n) (ShX sh i) +shxUncons :: ShX sh1 i -> Maybe (UnconsShXRes i sh1) +shxUncons (i `ConsUnknown` shl') = Just (UnconsShXRes (SUnknown i) shl') +shxUncons (i `ConsKnown` shl') = Just (UnconsShXRes (SKnown i) shl') +shxUncons ZSX = Nothing --- | This checks only whether the types are equal; unknown dimensions might --- still differ. This corresponds to 'testEquality', except on the penultimate --- type parameter. +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') shxEqType ZSX ZSX = Just Refl -shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m - , Just Refl <- shxEqType sh sh' - = Just Refl -shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') +shxEqType (_ `ConsUnknown` sh) (_ `ConsUnknown` sh') | Just Refl <- shxEqType sh sh' = Just Refl +shxEqType (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m + , Just Refl <- shxEqType sh sh' + = Just Refl shxEqType _ _ = Nothing --- | This checks whether all dimensions have the same value. This is more than --- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the --- @some@ package (except on the penultimate type parameter). +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') shxEqual ZSX ZSX = Just Refl -shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') - | Just Refl <- sameNat n m +shxEqual (n `ConsUnknown` sh) (m `ConsUnknown` sh') + | n == m , Just Refl <- shxEqual sh sh' = Just Refl -shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') - | i == j +shxEqual (n `ConsKnown` sh) (m `ConsKnown` sh') + | Just Refl <- testEquality n m , Just Refl <- shxEqual sh sh' = Just Refl shxEqual _ _ = Nothing +{-# INLINE shxFmap #-} +shxFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ShX sh i -> ShX sh j +shxFmap _ ZSX = ZSX +shxFmap f (x `ConsUnknown` xs) = case f (SUnknown x) of + SUnknown y -> y `ConsUnknown` shxFmap f xs +shxFmap f (x `ConsKnown` xs) = case f (SKnown x) of + SKnown y -> y `ConsKnown` shxFmap f xs + +{-# INLINE shxFoldMap #-} +shxFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ShX sh i -> m +shxFoldMap _ ZSX = mempty +shxFoldMap f (x `ConsUnknown` xs) = f (SUnknown x) <> shxFoldMap f xs +shxFoldMap f (x `ConsKnown` xs) = f (SKnown x) <> shxFoldMap f xs + shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l +shxLength = getSum . shxFoldMap (\_ -> Sum 1) shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listxRank l +shxRank ZSX = SNat +shxRank (_ `ConsUnknown` l) | SNat <- shxRank l = SNat +shxRank (_ `ConsKnown` l) | SNat <- shxRank l = SNat + +{-# INLINE shxShow #-} +shxShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ShX sh i -> ShowS +shxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ShX sh' i -> ShowS + go _ ZSX = id + go prefix (x `ConsUnknown` xs) = showString prefix . f (SUnknown x) . go "," xs + go prefix (x `ConsKnown` xs) = showString prefix . f (SKnown x) . go "," xs + +shxHead :: ShX (mn ': sh) i -> SMayNat i mn +shxHead (i `ConsUnknown` _) = SUnknown i +shxHead (i `ConsKnown` _) = SKnown i + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (_ `ConsUnknown` sh) = sh +shxTail (_ `ConsKnown` sh) = sh + +shxAppend :: ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend ZSX idx' = idx' +shxAppend (i `ConsUnknown` idx) idx' = i `ConsUnknown` shxAppend idx idx' +shxAppend (i `ConsKnown` idx) idx' = i `ConsKnown` shxAppend idx idx' + +shxDropSh :: forall sh sh' i j. ShX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh ZSX long = long +shxDropSh (_ `ConsUnknown` short) long = case long of + _ `ConsUnknown` long' -> shxDropSh short long' +shxDropSh (_ `ConsKnown` short) long = case long of + _ `ConsKnown` long' -> shxDropSh short long' + +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSSX = coerce (shxDropSh @_ @_ @i @()) + +shxInit :: forall i n sh. ShX (n : sh) i -> ShX (Init (n : sh)) i +shxInit (i `ConsUnknown` sh@(_ `ConsUnknown` _)) = i `ConsUnknown` shxInit sh +shxInit (i `ConsUnknown` sh@(_ `ConsKnown` _)) = i `ConsUnknown` shxInit sh +shxInit (_ `ConsUnknown` ZSX) = ZSX +shxInit (i `ConsKnown` sh@(_ `ConsUnknown` _)) = i `ConsKnown` shxInit sh +shxInit (i `ConsKnown` sh@(_ `ConsKnown` _)) = i `ConsKnown` shxInit sh +shxInit (_ `ConsKnown` ZSX) = ZSX + +shxLast :: forall i n sh. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast (_ `ConsUnknown` sh@(_ `ConsUnknown` _)) = shxLast sh +shxLast (_ `ConsUnknown` sh@(_ `ConsKnown` _)) = shxLast sh +shxLast (x `ConsUnknown` ZSX) = SUnknown x +shxLast (_ `ConsKnown` sh@(_ `ConsUnknown` _)) = shxLast sh +shxLast (_ `ConsKnown` sh@(_ `ConsKnown` _)) = shxLast sh +shxLast (x `ConsKnown` ZSX) = SKnown x + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- (shxUncons -> Just (UnconsShXRes i shl)) + where i :$% shl = case i of; SUnknown x -> x `ConsUnknown` shl; SKnown x -> x `ConsKnown` shl +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int -- | The number of elements in an array described by this shape. shxSize :: IShX sh -> Int shxSize ZSX = 1 shxSize (n :$% sh) = fromSMayNat' n * shxSize sh -shxFromList :: StaticShX sh -> [Int] -> ShX sh Int -shxFromList topssh topl = go topssh topl +-- We don't report the size of the list in case of errors in order not to retain the list. +{-# INLINEABLE shxFromList #-} +shxFromList :: StaticShX sh -> [Int] -> IShX sh +shxFromList (StaticShX topssh) topl = go topssh topl where - go :: StaticShX sh' -> [Int] -> ShX sh' Int - go ZKX [] = ZSX - go (SKnown sn :!% sh) (i : is) - | i == fromSNat' sn = SKnown sn :$% go sh is - | otherwise = error $ "shxFromList: Value does not match typing (type says " - ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" - go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is - go _ _ = error $ "shxFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" + go :: ShX sh' () -> [Int] -> ShX sh' Int + go ZSX [] = ZSX + go ZSX _ = error $ "shxFromList: List too long (type says " ++ show (shxLength topssh) ++ ")" + go (ConsKnown sn sh) (i : is) + | i == fromSNat' sn = ConsKnown sn (go sh is) + | otherwise = error "shxFromList: Value does not match typing" + go (ConsUnknown () sh) (i : is) = ConsUnknown i (go sh is) + go _ _ = error $ "shxFromList: List too short (type says " ++ show (shxLength topssh) ++ ")" +{-# INLINEABLE shxToList #-} shxToList :: IShX sh -> [Int] -shxToList ZSX = [] -shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh - -shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) - -shxHead :: ShX (n : sh) i -> SMayNat i SNat n -shxHead (ShX list) = listxHead list - -shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listxTail list) +shxToList l = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil + go (ConsUnknown i rest) = i `cons` go rest + go (ConsKnown sn rest) = fromSNat' sn `cons` go rest + in go l) -shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) +-- If it ever matters for performance, this is unsafeCoercible. +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" -shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) +-- | This may fail if @sh@ has @Nothing@s in it. +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing -shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh -shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) +shxTakeSh :: forall sh sh' i proxy. proxy sh' -> ShX sh i -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSh _ ZSX _ = ZSX +shxTakeSh p (_ :$% ssh1) (n :$% sh) = n :$% shxTakeSh p ssh1 sh -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) +{-# INLINEABLE shxTakeIx #-} +shxTakeIx :: forall sh sh' i j. Proxy sh' -> IxX sh j -> ShX (sh ++ sh') i -> ShX sh i +shxTakeIx _ (IxX ZX) _ = ZSX +shxTakeIx proxy (IxX (_ ::% long)) short = case short of i :$% short' -> i :$% shxTakeIx proxy (IxX long) short' -shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i -shxTakeSSX _ = flip go - where - go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i - go ZKX _ = ZSX - go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh +{-# INLINEABLE shxDropIx #-} +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropIx ZIX long = long +shxDropIx (_ :.% short) long = case long of _ :$% long' -> shxDropIx short long' -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) +{-# INLINE shxZipWith #-} +shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js @@ -437,115 +460,115 @@ shxCompleteZeros ZKX = ZSX shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh -shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) shxSplitApp _ ZKX idx = (ZSX, idx) shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) -shxEnum :: IShX sh -> [IIxX sh] -shxEnum = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] - -shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') -shxCast ZSX ZKX = Just ZSX -shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh -shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh -shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh shxCast _ _ = Nothing -- | Partial version of 'shxCast'. -shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' -shxCast' sh ssh = case shxCast sh ssh of +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of Just sh' -> sh' Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" -- * Static mixed shapes --- | The part of a shape that is statically known. (A newtype over 'ListX'.) +-- | The part of a shape that is statically known. (A newtype over 'ShX'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) - deriving (Eq, Ord) +newtype StaticShX sh = StaticShX (ShX sh ()) + deriving (NFData) + +instance Eq (StaticShX sh) where _ == _ = True +instance Ord (StaticShX sh) where compare _ _ = EQ pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX +pattern ZKX = StaticShX ZSX pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) + => SMayNat () n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (shxUncons -> Just (UnconsShXRes i (StaticShX -> shl))) + where i :!% StaticShX shl = case i of; SUnknown () -> StaticShX (() `ConsUnknown` shl); SKnown x -> StaticShX (x `ConsKnown` shl) + infixr 3 :!% {-# COMPLETE ZKX, (:!%) #-} +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (StaticShX sh) +#else instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l - -instance NFData (StaticShX sh) where - rnf (StaticShX ZX) = () - rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) - rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + showsPrec _ (StaticShX l) = shxShow (fromSMayNat shows (shows . fromSNat)) l +#endif instance TestEquality StaticShX where - testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + testEquality (StaticShX l1) (StaticShX l2) = shxEqType l1 l2 ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listxLength l +ssxLength (StaticShX l) = shxLength l ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listxRank l +ssxRank (StaticShX l) = shxRank l -- | @ssxEqType = 'testEquality'@. Provided for consistency. ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') ssxEqType = testEquality ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKX sh' = sh' -ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' +ssxAppend = coerce (shxAppend @_ @()) -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n -ssxHead (StaticShX list) = listxHead list +ssxHead :: StaticShX (n : sh) -> SMayNat () n +ssxHead (StaticShX list) = shxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh -ssxTail (_ :!% ssh) = ssh +ssxTail (StaticShX list) = StaticShX (shxTail list) -ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh +ssxTakeIx _ (IxX ZX) _ = ZKX +ssxTakeIx proxy (IxX (_ ::% long)) short = case short of i :!% short' -> i :!% ssxTakeIx proxy (IxX long) short' -ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropIx (IxX ZX) long = long +ssxDropIx (IxX (_ ::% short)) long = case long of _ :!% long' -> ssxDropIx (IxX short) long' -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (shxDropSh @_ @_ @() @i) --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShX' :: StaticShX sh -> Maybe (IShX sh) -ssxToShX' ZKX = Just ZSX -ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh -ssxToShX' (SUnknown _ :!% _) = Nothing +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (shxDropSh @_ @_ @() @()) + +ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) +ssxInit = coerce (shxInit @()) + +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) +ssxLast = coerce (shxLast @()) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX ssxReplicate (SS (n :: SNat n')) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxReplicate n -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKX = [] -ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i + 1) -ssxFromShape :: IShX sh -> StaticShX sh -ssxFromShape ZSX = ZKX -ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh +ssxFromShX :: ShX sh i -> StaticShX sh +ssxFromShX ZSX = ZKX +ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) n = SUnknown () :!% ssxFromSNat n -- | Evidence for the static part of a shape. This pops up only when you are @@ -557,7 +580,7 @@ instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SK instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r -withKnownShX k = withDict @(KnownShX sh) k +withKnownShX = withDict @(KnownShX sh) -- * Flattening @@ -570,18 +593,18 @@ type family Flatten' acc sh where Flatten' acc (Just n : sh) = Flatten' (acc * n) sh -- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh) ssxFlatten = go (SNat @1) where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh) go acc ZKX = SKnown acc go _ (SUnknown () :!% _) = SUnknown () go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten :: IShX sh -> SMayNat Int (Flatten sh) shxFlatten = go (SNat @1) where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh) go acc ZSX = SKnown acc go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh @@ -592,20 +615,14 @@ shxFlatten = go (SNat @1) goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh --- | Very untyped: only length is checked (at runtime). -instance KnownShX sh => IsList (ListX sh (Const i)) where - type Item (ListX sh (Const i)) = i - fromList = listxFromList (knownShX @sh) - toList = listxToList - -- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShX sh => IsList (IxX sh i) where +instance IsList (IxX sh i) where type Item (IxX sh i) = i fromList = IxX . IsList.fromList toList = Foldable.toList -- | Untyped: length and known dimensions are checked (at runtime). -instance KnownShX sh => IsList (ShX sh Int) where - type Item (ShX sh Int) = Int +instance KnownShX sh => IsList (IShX sh) where + type Item (IShX sh) = Int fromList = shxFromList (knownShX @sh) toList = shxToList |
