diff options
Diffstat (limited to 'src/Data/Array/Nested/Mixed/Shape.hs')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 181 |
1 files changed, 79 insertions, 102 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index debe5ec..c8c9a7b 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -36,6 +36,7 @@ import Data.Functor.Const import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) +import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) @@ -298,36 +299,73 @@ ixxToLinear = \sh i -> go sh i 0 -- * Mixed shape-like lists to be used for ShX and StaticShX +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) + +instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +instance TestEquality (SMayNat i) where + testEquality SUnknown{} SUnknown{} = Just Refl + testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl + testEquality _ _ = Nothing + +{-# INLINE fromSMayNat #-} +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + type role ListH nominal representational -type ListH :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListH sh f where - ZH :: ListH '[] f - (::#) :: forall n sh {f}. f n -> ListH sh f -> ListH (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListH sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListH sh f) +type ListH :: [Maybe Nat] -> Type -> Type +data ListH sh i where + ZH :: ListH '[] i + (::#) :: forall n sh i. SMayNat i n -> ListH sh i -> ListH (n : sh) i +deriving instance Eq i => Eq (ListH sh i) +deriving instance Ord i => Ord (ListH sh i) infixr 3 ::# #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListH sh f) +deriving instance Show i => Show (ListH sh i) #else -instance (forall n. Show (f n)) => Show (ListH sh f) where +instance Show i => Show (ListH sh i) where showsPrec _ = listhShow shows #endif -instance (forall n. NFData (f n)) => NFData (ListH sh f) where +instance (forall n. NFData (SMayNat i n)) => NFData (ListH sh i) where rnf ZH = () rnf (x ::# l) = rnf x `seq` rnf l -data UnconsListHRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh f) (f n) -listhUncons :: ListH sh1 f -> Maybe (UnconsListHRes f sh1) +data UnconsListHRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n) +listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1) listhUncons (i ::# shl') = Just (UnconsListHRes shl' i) listhUncons ZH = Nothing -- | This checks only whether the types are equal; if the elements of the list -- are not singletons, their values may still differ. This corresponds to -- 'testEquality', except on the penultimate type parameter. -listhEqType :: TestEquality f => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh') +listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') listhEqType ZH ZH = Just Refl listhEqType (n ::# sh) (m ::# sh') | Just Refl <- testEquality n m @@ -338,7 +376,7 @@ listhEqType _ _ = Nothing -- | This checks whether the two lists actually contain equal values. This is -- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ -- in the @some@ package (except on the penultimate type parameter). -listhEqual :: (TestEquality f, forall n. Eq (f n)) => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh') +listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') listhEqual ZH ZH = Just Refl listhEqual (n ::# sh) (m ::# sh') | Just Refl <- testEquality n m @@ -348,123 +386,58 @@ listhEqual (n ::# sh) (m ::# sh') listhEqual _ _ = Nothing {-# INLINE listhFmap #-} -listhFmap :: (forall n. f n -> g n) -> ListH sh f -> ListH sh g +listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j listhFmap _ ZH = ZH listhFmap f (x ::# xs) = f x ::# listhFmap f xs {-# INLINE listhFoldMap #-} -listhFoldMap :: Monoid m => (forall n. f n -> m) -> ListH sh f -> m +listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m listhFoldMap _ ZH = mempty listhFoldMap f (x ::# xs) = f x <> listhFoldMap f xs -listhLength :: ListH sh f -> Int +listhLength :: ListH sh i -> Int listhLength = getSum . listhFoldMap (\_ -> Sum 1) -listhRank :: ListH sh f -> SNat (Rank sh) +listhRank :: ListH sh i -> SNat (Rank sh) listhRank ZH = SNat listhRank (_ ::# l) | SNat <- listhRank l = SNat {-# INLINE listhShow #-} -listhShow :: forall sh f. (forall n. f n -> ShowS) -> ListH sh f -> ShowS +listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS listhShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListH sh' f -> ShowS + go :: String -> ListH sh' i -> ShowS go _ ZH = id go prefix (x ::# xs) = showString prefix . f x . go "," xs -listhFromList :: StaticShX sh -> [i] -> ListH sh (Const i) -listhFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [i] -> ListH sh' (Const i) - go ZKX [] = ZH - go (_ :!% sh) (i : is) = Const i ::# go sh is - go _ _ = error $ "listhFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" - -{-# INLINEABLE listhToList #-} -listhToList :: ListH sh (Const i) -> [i] -listhToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListH sh (Const i) -> is - go ZH = nil - go (Const i ::# is) = i `cons` go is - in go list) - -listhHead :: ListH (mn ': sh) f -> f mn +listhHead :: ListH (mn ': sh) i -> SMayNat i mn listhHead (i ::# _) = i listhTail :: ListH (n : sh) i -> ListH sh i listhTail (_ ::# sh) = sh -listhAppend :: ListH sh f -> ListH sh' f -> ListH (sh ++ sh') f +listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i listhAppend ZH idx' = idx' listhAppend (i ::# idx) idx' = i ::# listhAppend idx idx' -listhDrop :: forall f g sh sh'. ListH sh g -> ListH (sh ++ sh') f -> ListH sh' f +listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i listhDrop ZH long = long listhDrop (_ ::# short) long = case long of _ ::# long' -> listhDrop short long' -listhInit :: forall f n sh. ListH (n : sh) f -> ListH (Init (n : sh)) f +listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i listhInit (i ::# sh@(_ ::# _)) = i ::# listhInit sh listhInit (_ ::# ZH) = ZH -listhLast :: forall f n sh. ListH (n : sh) f -> f (Last (n : sh)) +listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh)) listhLast (_ ::# sh@(_ ::# _)) = listhLast sh listhLast (x ::# ZH) = x -listhZip :: ListH sh f -> ListH sh g -> ListH sh (Product f g) -listhZip ZH ZH = ZH -listhZip (i ::# irest) (j ::# jrest) = Pair i j ::# listhZip irest jrest - -{-# INLINE listhZipWith #-} -listhZipWith :: (forall a. f a -> g a -> h a) -> ListH sh f -> ListH sh g - -> ListH sh h -listhZipWith _ ZH ZH = ZH -listhZipWith f (i ::# is) (j ::# js) = f i j ::# listhZipWith f is js - -- * Mixed shapes -data SMayNat i n where - SUnknown :: i -> SMayNat i Nothing - SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) -deriving instance Show i => Show (SMayNat i n) -deriving instance Eq i => Eq (SMayNat i n) -deriving instance Ord i => Ord (SMayNat i n) - -instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where - rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x - -instance TestEquality (SMayNat i) where - testEquality SUnknown{} SUnknown{} = Just Refl - testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl - testEquality _ _ = Nothing - -{-# INLINE fromSMayNat #-} -fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => SNat m -> r) - -> SMayNat i n -> r -fromSMayNat f _ (SUnknown i) = f i -fromSMayNat _ g (SKnown s) = g s - -fromSMayNat' :: SMayNat Int n -> Int -fromSMayNat' = fromSMayNat id fromSNat' - -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - -smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) -smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) -smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) - - -- | This is a newtype over 'ListH'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListH sh (SMayNat i)) +newtype ShX sh i = ShX (ListH sh i) deriving (Eq, Ord, Generic) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i @@ -575,7 +548,7 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listhAppend @_ @(SMayNat i)) +shxAppend = coerce (listhAppend @_ @i) shxHead :: ShX (n : sh) i -> SMayNat i n shxHead (ShX list) = listhHead list @@ -584,20 +557,20 @@ shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listhTail list) shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listhDrop @(SMayNat i) @(SMayNat ())) +shxDropSSX = coerce (listhDrop @i @()) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx (IxX ZX) long = long shxDropIx (IxX (_ ::% short)) long = case long of ShX (_ ::# long') -> shxDropIx (IxX short) (ShX long') shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listhDrop @(SMayNat i) @(SMayNat i)) +shxDropSh = coerce (listhDrop @i @i) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listhInit @(SMayNat i)) +shxInit = coerce (listhInit @i) shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) -shxLast = coerce (listhLast @(SMayNat i)) +shxLast = coerce (listhLast @i) shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i shxTakeSSX _ ZKX _ = ZSX @@ -654,7 +627,7 @@ shxCast' ssh sh = case shxCast ssh sh of -- | The part of a shape that is statically known. (A newtype over 'ListH'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListH sh (SMayNat ())) +newtype StaticShX sh = StaticShX (ListH sh ()) deriving (Eq, Ord) pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh @@ -705,21 +678,25 @@ ssxHead (StaticShX list) = listhHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh -ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listhDrop @(SMayNat ()) @(SMayNat ())) +ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh +ssxTakeIx _ (IxX ZX) _ = ZKX +ssxTakeIx proxy (IxX (_ ::% long)) short = case short of StaticShX (i ::# short') -> i :!% ssxTakeIx proxy (IxX long) (StaticShX short') ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropIx (IxX ZX) long = long ssxDropIx (IxX (_ ::% short)) long = case long of StaticShX (_ ::# long') -> ssxDropIx (IxX short) (StaticShX long') ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listhDrop @(SMayNat ()) @(SMayNat i)) +ssxDropSh = coerce (listhDrop @() @i) + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listhDrop @() @()) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listhInit @(SMayNat ())) +ssxInit = coerce (listhInit @()) ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) -ssxLast = coerce (listhLast @(SMayNat ())) +ssxLast = coerce (listhLast @()) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX |
