diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-15 02:05:11 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-15 02:05:11 +0100 |
| commit | 20fbbc417952a2740ba2e423581c4c481f61bc54 (patch) | |
| tree | 5fe7bb0d2d388feb5dad314711510dd7464d30cf /src | |
| parent | 38d043ad64c88e1403839bba915651843ab51503 (diff) | |
Inline SMayNat in ListH
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 181 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 36 |
2 files changed, 97 insertions, 120 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index debe5ec..c8c9a7b 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -36,6 +36,7 @@ import Data.Functor.Const import Data.Functor.Product import Data.Kind (Constraint, Type) import Data.Monoid (Sum(..)) +import Data.Proxy import Data.Type.Equality import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict) import GHC.Generics (Generic) @@ -298,36 +299,73 @@ ixxToLinear = \sh i -> go sh i 0 -- * Mixed shape-like lists to be used for ShX and StaticShX +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) + +instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +instance TestEquality (SMayNat i) where + testEquality SUnknown{} SUnknown{} = Just Refl + testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl + testEquality _ _ = Nothing + +{-# INLINE fromSMayNat #-} +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + type role ListH nominal representational -type ListH :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type -data ListH sh f where - ZH :: ListH '[] f - (::#) :: forall n sh {f}. f n -> ListH sh f -> ListH (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListH sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListH sh f) +type ListH :: [Maybe Nat] -> Type -> Type +data ListH sh i where + ZH :: ListH '[] i + (::#) :: forall n sh i. SMayNat i n -> ListH sh i -> ListH (n : sh) i +deriving instance Eq i => Eq (ListH sh i) +deriving instance Ord i => Ord (ListH sh i) infixr 3 ::# #ifdef OXAR_DEFAULT_SHOW_INSTANCES -deriving instance (forall n. Show (f n)) => Show (ListH sh f) +deriving instance Show i => Show (ListH sh i) #else -instance (forall n. Show (f n)) => Show (ListH sh f) where +instance Show i => Show (ListH sh i) where showsPrec _ = listhShow shows #endif -instance (forall n. NFData (f n)) => NFData (ListH sh f) where +instance (forall n. NFData (SMayNat i n)) => NFData (ListH sh i) where rnf ZH = () rnf (x ::# l) = rnf x `seq` rnf l -data UnconsListHRes f sh1 = - forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh f) (f n) -listhUncons :: ListH sh1 f -> Maybe (UnconsListHRes f sh1) +data UnconsListHRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n) +listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1) listhUncons (i ::# shl') = Just (UnconsListHRes shl' i) listhUncons ZH = Nothing -- | This checks only whether the types are equal; if the elements of the list -- are not singletons, their values may still differ. This corresponds to -- 'testEquality', except on the penultimate type parameter. -listhEqType :: TestEquality f => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh') +listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') listhEqType ZH ZH = Just Refl listhEqType (n ::# sh) (m ::# sh') | Just Refl <- testEquality n m @@ -338,7 +376,7 @@ listhEqType _ _ = Nothing -- | This checks whether the two lists actually contain equal values. This is -- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ -- in the @some@ package (except on the penultimate type parameter). -listhEqual :: (TestEquality f, forall n. Eq (f n)) => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh') +listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh') listhEqual ZH ZH = Just Refl listhEqual (n ::# sh) (m ::# sh') | Just Refl <- testEquality n m @@ -348,123 +386,58 @@ listhEqual (n ::# sh) (m ::# sh') listhEqual _ _ = Nothing {-# INLINE listhFmap #-} -listhFmap :: (forall n. f n -> g n) -> ListH sh f -> ListH sh g +listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j listhFmap _ ZH = ZH listhFmap f (x ::# xs) = f x ::# listhFmap f xs {-# INLINE listhFoldMap #-} -listhFoldMap :: Monoid m => (forall n. f n -> m) -> ListH sh f -> m +listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m listhFoldMap _ ZH = mempty listhFoldMap f (x ::# xs) = f x <> listhFoldMap f xs -listhLength :: ListH sh f -> Int +listhLength :: ListH sh i -> Int listhLength = getSum . listhFoldMap (\_ -> Sum 1) -listhRank :: ListH sh f -> SNat (Rank sh) +listhRank :: ListH sh i -> SNat (Rank sh) listhRank ZH = SNat listhRank (_ ::# l) | SNat <- listhRank l = SNat {-# INLINE listhShow #-} -listhShow :: forall sh f. (forall n. f n -> ShowS) -> ListH sh f -> ShowS +listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS listhShow f l = showString "[" . go "" l . showString "]" where - go :: String -> ListH sh' f -> ShowS + go :: String -> ListH sh' i -> ShowS go _ ZH = id go prefix (x ::# xs) = showString prefix . f x . go "," xs -listhFromList :: StaticShX sh -> [i] -> ListH sh (Const i) -listhFromList topssh topl = go topssh topl - where - go :: StaticShX sh' -> [i] -> ListH sh' (Const i) - go ZKX [] = ZH - go (_ :!% sh) (i : is) = Const i ::# go sh is - go _ _ = error $ "listhFromList: Mismatched list length (type says " - ++ show (ssxLength topssh) ++ ", list has length " - ++ show (length topl) ++ ")" - -{-# INLINEABLE listhToList #-} -listhToList :: ListH sh (Const i) -> [i] -listhToList list = build (\(cons :: i -> is -> is) (nil :: is) -> - let go :: ListH sh (Const i) -> is - go ZH = nil - go (Const i ::# is) = i `cons` go is - in go list) - -listhHead :: ListH (mn ': sh) f -> f mn +listhHead :: ListH (mn ': sh) i -> SMayNat i mn listhHead (i ::# _) = i listhTail :: ListH (n : sh) i -> ListH sh i listhTail (_ ::# sh) = sh -listhAppend :: ListH sh f -> ListH sh' f -> ListH (sh ++ sh') f +listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i listhAppend ZH idx' = idx' listhAppend (i ::# idx) idx' = i ::# listhAppend idx idx' -listhDrop :: forall f g sh sh'. ListH sh g -> ListH (sh ++ sh') f -> ListH sh' f +listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i listhDrop ZH long = long listhDrop (_ ::# short) long = case long of _ ::# long' -> listhDrop short long' -listhInit :: forall f n sh. ListH (n : sh) f -> ListH (Init (n : sh)) f +listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i listhInit (i ::# sh@(_ ::# _)) = i ::# listhInit sh listhInit (_ ::# ZH) = ZH -listhLast :: forall f n sh. ListH (n : sh) f -> f (Last (n : sh)) +listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh)) listhLast (_ ::# sh@(_ ::# _)) = listhLast sh listhLast (x ::# ZH) = x -listhZip :: ListH sh f -> ListH sh g -> ListH sh (Product f g) -listhZip ZH ZH = ZH -listhZip (i ::# irest) (j ::# jrest) = Pair i j ::# listhZip irest jrest - -{-# INLINE listhZipWith #-} -listhZipWith :: (forall a. f a -> g a -> h a) -> ListH sh f -> ListH sh g - -> ListH sh h -listhZipWith _ ZH ZH = ZH -listhZipWith f (i ::# is) (j ::# js) = f i j ::# listhZipWith f is js - -- * Mixed shapes -data SMayNat i n where - SUnknown :: i -> SMayNat i Nothing - SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) -deriving instance Show i => Show (SMayNat i n) -deriving instance Eq i => Eq (SMayNat i n) -deriving instance Ord i => Ord (SMayNat i n) - -instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where - rnf (SUnknown i) = rnf i - rnf (SKnown x) = rnf x - -instance TestEquality (SMayNat i) where - testEquality SUnknown{} SUnknown{} = Just Refl - testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl - testEquality _ _ = Nothing - -{-# INLINE fromSMayNat #-} -fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => SNat m -> r) - -> SMayNat i n -> r -fromSMayNat f _ (SUnknown i) = f i -fromSMayNat _ g (SKnown s) = g s - -fromSMayNat' :: SMayNat Int n -> Int -fromSMayNat' = fromSMayNat id fromSNat' - -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - -smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) -smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) -smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) -smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) - - -- | This is a newtype over 'ListH'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListH sh (SMayNat i)) +newtype ShX sh i = ShX (ListH sh i) deriving (Eq, Ord, Generic) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i @@ -575,7 +548,7 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listhAppend @_ @(SMayNat i)) +shxAppend = coerce (listhAppend @_ @i) shxHead :: ShX (n : sh) i -> SMayNat i n shxHead (ShX list) = listhHead list @@ -584,20 +557,20 @@ shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listhTail list) shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listhDrop @(SMayNat i) @(SMayNat ())) +shxDropSSX = coerce (listhDrop @i @()) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i shxDropIx (IxX ZX) long = long shxDropIx (IxX (_ ::% short)) long = case long of ShX (_ ::# long') -> shxDropIx (IxX short) (ShX long') shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listhDrop @(SMayNat i) @(SMayNat i)) +shxDropSh = coerce (listhDrop @i @i) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listhInit @(SMayNat i)) +shxInit = coerce (listhInit @i) shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) -shxLast = coerce (listhLast @(SMayNat i)) +shxLast = coerce (listhLast @i) shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i shxTakeSSX _ ZKX _ = ZSX @@ -654,7 +627,7 @@ shxCast' ssh sh = case shxCast ssh sh of -- | The part of a shape that is statically known. (A newtype over 'ListH'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListH sh (SMayNat ())) +newtype StaticShX sh = StaticShX (ListH sh ()) deriving (Eq, Ord) pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh @@ -705,21 +678,25 @@ ssxHead (StaticShX list) = listhHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh -ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listhDrop @(SMayNat ()) @(SMayNat ())) +ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh +ssxTakeIx _ (IxX ZX) _ = ZKX +ssxTakeIx proxy (IxX (_ ::% long)) short = case short of StaticShX (i ::# short') -> i :!% ssxTakeIx proxy (IxX long) (StaticShX short') ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' ssxDropIx (IxX ZX) long = long ssxDropIx (IxX (_ ::% short)) long = case long of StaticShX (_ ::# long') -> ssxDropIx (IxX short) (StaticShX long') ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listhDrop @(SMayNat ()) @(SMayNat i)) +ssxDropSh = coerce (listhDrop @() @i) + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listhDrop @() @()) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listhInit @(SMayNat ())) +ssxInit = coerce (listhInit @()) ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) -ssxLast = coerce (listhLast @(SMayNat ())) +ssxLast = coerce (listhLast @()) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index e520e0f..f1485b3 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -172,60 +172,60 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs -listhTakeLen :: forall f is sh. Perm is -> ListH sh f -> ListH (TakeLen is sh) f +listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i listhTakeLen PNil _ = ZH listhTakeLen (_ `PCons` is) (n ::# sh) = n ::# listhTakeLen is sh listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape" -listhDropLen :: forall f is sh. Perm is -> ListH sh f -> ListH (DropLen is sh) f +listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i listhDropLen PNil sh = sh listhDropLen (_ `PCons` is) (_ ::# sh) = listhDropLen is sh listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape" -listhPermute :: forall f is sh. Perm is -> ListH sh f -> ListH (Permute is sh) f +listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i listhPermute PNil _ = ZH -listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh f) = +listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) = listhIndex (Proxy @is') (Proxy @sh) i sh ::# listhPermute is sh -listhIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListH sh f -> f (Index i sh) +listhIndex :: forall i is shT k sh. Proxy is -> Proxy shT -> SNat k -> ListH sh i -> SMayNat i (Index k sh) listhIndex _ _ SZ (n ::# _) = n -listhIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::# (sh :: ListH sh' f)) - | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') +listhIndex p pT (SS (i :: SNat k')) ((_ :: SMayNat i n) ::# (sh :: ListH sh' i)) + | Refl <- lemIndexSucc (Proxy @k') (Proxy @n) (Proxy @sh') = listhIndex p pT i sh listhIndex _ _ _ ZH = error "Index into empty shape" -listhPermutePrefix :: forall f is sh. Perm is -> ListH sh f -> ListH (PermutePrefix is sh) f +listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh) ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listhTakeLen @(SMayNat ())) +ssxTakeLen = coerce (listhTakeLen @()) ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listhDropLen @(SMayNat ())) +ssxDropLen = coerce (listhDropLen @()) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listhPermute @(SMayNat ())) +ssxPermute = coerce (listhPermute @()) ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh) -ssxIndex p1 p2 i = coerce (listhIndex @(SMayNat ()) p1 p2 i) +ssxIndex p1 p2 i = coerce (listhIndex @() p1 p2 i) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listhPermutePrefix @(SMayNat ())) +ssxPermutePrefix = coerce (listhPermutePrefix @()) shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) -shxTakeLen = coerce (listhTakeLen @(SMayNat Int)) +shxTakeLen = coerce (listhTakeLen @Int) shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh) -shxDropLen = coerce (listhDropLen @(SMayNat Int)) +shxDropLen = coerce (listhDropLen @Int) shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) -shxPermute = coerce (listhPermute @(SMayNat Int)) +shxPermute = coerce (listhPermute @Int) shxIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> SMayNat Int (Index i sh) -shxIndex p1 p2 i = coerce (listhIndex @(SMayNat Int) p1 p2 i) +shxIndex p1 p2 i = coerce (listhIndex @Int p1 p2 i) shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listhPermutePrefix @(SMayNat Int)) +shxPermutePrefix = coerce (listhPermutePrefix @Int) listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f |
