aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Mixed/Shape.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Mixed/Shape.hs')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs181
1 files changed, 79 insertions, 102 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index debe5ec..c8c9a7b 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -36,6 +36,7 @@ import Data.Functor.Const
import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
+import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.Generics (Generic)
@@ -298,36 +299,73 @@ ixxToLinear = \sh i -> go sh i 0
-- * Mixed shape-like lists to be used for ShX and StaticShX
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n)
+deriving instance Show i => Show (SMayNat i n)
+deriving instance Eq i => Eq (SMayNat i n)
+deriving instance Ord i => Ord (SMayNat i n)
+
+instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
+ rnf (SUnknown i) = rnf i
+ rnf (SKnown x) = rnf x
+
+instance TestEquality (SMayNat i) where
+ testEquality SUnknown{} SUnknown{} = Just Refl
+ testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
+ testEquality _ _ = Nothing
+
+{-# INLINE fromSMayNat #-}
+fromSMayNat :: (n ~ Nothing => i -> r)
+ -> (forall m. n ~ Just m => SNat m -> r)
+ -> SMayNat i n -> r
+fromSMayNat f _ (SUnknown i) = f i
+fromSMayNat _ g (SKnown s) = g s
+
+fromSMayNat' :: SMayNat Int n -> Int
+fromSMayNat' = fromSMayNat id fromSNat'
+
+type family AddMaybe n m where
+ AddMaybe Nothing _ = Nothing
+ AddMaybe (Just _) Nothing = Nothing
+ AddMaybe (Just n) (Just m) = Just (n + m)
+
+smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
+smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
+smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+
+
type role ListH nominal representational
-type ListH :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
-data ListH sh f where
- ZH :: ListH '[] f
- (::#) :: forall n sh {f}. f n -> ListH sh f -> ListH (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListH sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListH sh f)
+type ListH :: [Maybe Nat] -> Type -> Type
+data ListH sh i where
+ ZH :: ListH '[] i
+ (::#) :: forall n sh i. SMayNat i n -> ListH sh i -> ListH (n : sh) i
+deriving instance Eq i => Eq (ListH sh i)
+deriving instance Ord i => Ord (ListH sh i)
infixr 3 ::#
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance (forall n. Show (f n)) => Show (ListH sh f)
+deriving instance Show i => Show (ListH sh i)
#else
-instance (forall n. Show (f n)) => Show (ListH sh f) where
+instance Show i => Show (ListH sh i) where
showsPrec _ = listhShow shows
#endif
-instance (forall n. NFData (f n)) => NFData (ListH sh f) where
+instance (forall n. NFData (SMayNat i n)) => NFData (ListH sh i) where
rnf ZH = ()
rnf (x ::# l) = rnf x `seq` rnf l
-data UnconsListHRes f sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh f) (f n)
-listhUncons :: ListH sh1 f -> Maybe (UnconsListHRes f sh1)
+data UnconsListHRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n)
+listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1)
listhUncons (i ::# shl') = Just (UnconsListHRes shl' i)
listhUncons ZH = Nothing
-- | This checks only whether the types are equal; if the elements of the list
-- are not singletons, their values may still differ. This corresponds to
-- 'testEquality', except on the penultimate type parameter.
-listhEqType :: TestEquality f => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh')
+listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
listhEqType ZH ZH = Just Refl
listhEqType (n ::# sh) (m ::# sh')
| Just Refl <- testEquality n m
@@ -338,7 +376,7 @@ listhEqType _ _ = Nothing
-- | This checks whether the two lists actually contain equal values. This is
-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
-- in the @some@ package (except on the penultimate type parameter).
-listhEqual :: (TestEquality f, forall n. Eq (f n)) => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh')
+listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
listhEqual ZH ZH = Just Refl
listhEqual (n ::# sh) (m ::# sh')
| Just Refl <- testEquality n m
@@ -348,123 +386,58 @@ listhEqual (n ::# sh) (m ::# sh')
listhEqual _ _ = Nothing
{-# INLINE listhFmap #-}
-listhFmap :: (forall n. f n -> g n) -> ListH sh f -> ListH sh g
+listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j
listhFmap _ ZH = ZH
listhFmap f (x ::# xs) = f x ::# listhFmap f xs
{-# INLINE listhFoldMap #-}
-listhFoldMap :: Monoid m => (forall n. f n -> m) -> ListH sh f -> m
+listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m
listhFoldMap _ ZH = mempty
listhFoldMap f (x ::# xs) = f x <> listhFoldMap f xs
-listhLength :: ListH sh f -> Int
+listhLength :: ListH sh i -> Int
listhLength = getSum . listhFoldMap (\_ -> Sum 1)
-listhRank :: ListH sh f -> SNat (Rank sh)
+listhRank :: ListH sh i -> SNat (Rank sh)
listhRank ZH = SNat
listhRank (_ ::# l) | SNat <- listhRank l = SNat
{-# INLINE listhShow #-}
-listhShow :: forall sh f. (forall n. f n -> ShowS) -> ListH sh f -> ShowS
+listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS
listhShow f l = showString "[" . go "" l . showString "]"
where
- go :: String -> ListH sh' f -> ShowS
+ go :: String -> ListH sh' i -> ShowS
go _ ZH = id
go prefix (x ::# xs) = showString prefix . f x . go "," xs
-listhFromList :: StaticShX sh -> [i] -> ListH sh (Const i)
-listhFromList topssh topl = go topssh topl
- where
- go :: StaticShX sh' -> [i] -> ListH sh' (Const i)
- go ZKX [] = ZH
- go (_ :!% sh) (i : is) = Const i ::# go sh is
- go _ _ = error $ "listhFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
-
-{-# INLINEABLE listhToList #-}
-listhToList :: ListH sh (Const i) -> [i]
-listhToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
- let go :: ListH sh (Const i) -> is
- go ZH = nil
- go (Const i ::# is) = i `cons` go is
- in go list)
-
-listhHead :: ListH (mn ': sh) f -> f mn
+listhHead :: ListH (mn ': sh) i -> SMayNat i mn
listhHead (i ::# _) = i
listhTail :: ListH (n : sh) i -> ListH sh i
listhTail (_ ::# sh) = sh
-listhAppend :: ListH sh f -> ListH sh' f -> ListH (sh ++ sh') f
+listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i
listhAppend ZH idx' = idx'
listhAppend (i ::# idx) idx' = i ::# listhAppend idx idx'
-listhDrop :: forall f g sh sh'. ListH sh g -> ListH (sh ++ sh') f -> ListH sh' f
+listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i
listhDrop ZH long = long
listhDrop (_ ::# short) long = case long of _ ::# long' -> listhDrop short long'
-listhInit :: forall f n sh. ListH (n : sh) f -> ListH (Init (n : sh)) f
+listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i
listhInit (i ::# sh@(_ ::# _)) = i ::# listhInit sh
listhInit (_ ::# ZH) = ZH
-listhLast :: forall f n sh. ListH (n : sh) f -> f (Last (n : sh))
+listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh))
listhLast (_ ::# sh@(_ ::# _)) = listhLast sh
listhLast (x ::# ZH) = x
-listhZip :: ListH sh f -> ListH sh g -> ListH sh (Product f g)
-listhZip ZH ZH = ZH
-listhZip (i ::# irest) (j ::# jrest) = Pair i j ::# listhZip irest jrest
-
-{-# INLINE listhZipWith #-}
-listhZipWith :: (forall a. f a -> g a -> h a) -> ListH sh f -> ListH sh g
- -> ListH sh h
-listhZipWith _ ZH ZH = ZH
-listhZipWith f (i ::# is) (j ::# js) = f i j ::# listhZipWith f is js
-
-- * Mixed shapes
-data SMayNat i n where
- SUnknown :: i -> SMayNat i Nothing
- SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n)
-deriving instance Show i => Show (SMayNat i n)
-deriving instance Eq i => Eq (SMayNat i n)
-deriving instance Ord i => Ord (SMayNat i n)
-
-instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
- rnf (SUnknown i) = rnf i
- rnf (SKnown x) = rnf x
-
-instance TestEquality (SMayNat i) where
- testEquality SUnknown{} SUnknown{} = Just Refl
- testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
- testEquality _ _ = Nothing
-
-{-# INLINE fromSMayNat #-}
-fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => SNat m -> r)
- -> SMayNat i n -> r
-fromSMayNat f _ (SUnknown i) = f i
-fromSMayNat _ g (SKnown s) = g s
-
-fromSMayNat' :: SMayNat Int n -> Int
-fromSMayNat' = fromSMayNat id fromSNat'
-
-type family AddMaybe n m where
- AddMaybe Nothing _ = Nothing
- AddMaybe (Just _) Nothing = Nothing
- AddMaybe (Just n) (Just m) = Just (n + m)
-
-smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
-smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
-smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
-smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
-
-
-- | This is a newtype over 'ListH'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListH sh (SMayNat i))
+newtype ShX sh i = ShX (ListH sh i)
deriving (Eq, Ord, Generic)
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
@@ -575,7 +548,7 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listhAppend @_ @(SMayNat i))
+shxAppend = coerce (listhAppend @_ @i)
shxHead :: ShX (n : sh) i -> SMayNat i n
shxHead (ShX list) = listhHead list
@@ -584,20 +557,20 @@ shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listhTail list)
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSSX = coerce (listhDrop @(SMayNat i) @(SMayNat ()))
+shxDropSSX = coerce (listhDrop @i @())
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx (IxX ZX) long = long
shxDropIx (IxX (_ ::% short)) long = case long of ShX (_ ::# long') -> shxDropIx (IxX short) (ShX long')
shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSh = coerce (listhDrop @(SMayNat i) @(SMayNat i))
+shxDropSh = coerce (listhDrop @i @i)
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listhInit @(SMayNat i))
+shxInit = coerce (listhInit @i)
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh))
-shxLast = coerce (listhLast @(SMayNat i))
+shxLast = coerce (listhLast @i)
shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
shxTakeSSX _ ZKX _ = ZSX
@@ -654,7 +627,7 @@ shxCast' ssh sh = case shxCast ssh sh of
-- | The part of a shape that is statically known. (A newtype over 'ListH'.)
type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListH sh (SMayNat ()))
+newtype StaticShX sh = StaticShX (ListH sh ())
deriving (Eq, Ord)
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
@@ -705,21 +678,25 @@ ssxHead (StaticShX list) = listhHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
-ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSSX = coerce (listhDrop @(SMayNat ()) @(SMayNat ()))
+ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
+ssxTakeIx _ (IxX ZX) _ = ZKX
+ssxTakeIx proxy (IxX (_ ::% long)) short = case short of StaticShX (i ::# short') -> i :!% ssxTakeIx proxy (IxX long) (StaticShX short')
ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropIx (IxX ZX) long = long
ssxDropIx (IxX (_ ::% short)) long = case long of StaticShX (_ ::# long') -> ssxDropIx (IxX short) (StaticShX long')
ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSh = coerce (listhDrop @(SMayNat ()) @(SMayNat i))
+ssxDropSh = coerce (listhDrop @() @i)
+
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listhDrop @() @())
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listhInit @(SMayNat ()))
+ssxInit = coerce (listhInit @())
ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
-ssxLast = coerce (listhLast @(SMayNat ()))
+ssxLast = coerce (listhLast @())
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX