aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-15 02:05:11 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-15 02:05:11 +0100
commit20fbbc417952a2740ba2e423581c4c481f61bc54 (patch)
tree5fe7bb0d2d388feb5dad314711510dd7464d30cf /src
parent38d043ad64c88e1403839bba915651843ab51503 (diff)
Inline SMayNat in ListH
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs181
-rw-r--r--src/Data/Array/Nested/Permutation.hs36
2 files changed, 97 insertions, 120 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index debe5ec..c8c9a7b 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -36,6 +36,7 @@ import Data.Functor.Const
import Data.Functor.Product
import Data.Kind (Constraint, Type)
import Data.Monoid (Sum(..))
+import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, build, quotRemInt#, withDict)
import GHC.Generics (Generic)
@@ -298,36 +299,73 @@ ixxToLinear = \sh i -> go sh i 0
-- * Mixed shape-like lists to be used for ShX and StaticShX
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n)
+deriving instance Show i => Show (SMayNat i n)
+deriving instance Eq i => Eq (SMayNat i n)
+deriving instance Ord i => Ord (SMayNat i n)
+
+instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
+ rnf (SUnknown i) = rnf i
+ rnf (SKnown x) = rnf x
+
+instance TestEquality (SMayNat i) where
+ testEquality SUnknown{} SUnknown{} = Just Refl
+ testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
+ testEquality _ _ = Nothing
+
+{-# INLINE fromSMayNat #-}
+fromSMayNat :: (n ~ Nothing => i -> r)
+ -> (forall m. n ~ Just m => SNat m -> r)
+ -> SMayNat i n -> r
+fromSMayNat f _ (SUnknown i) = f i
+fromSMayNat _ g (SKnown s) = g s
+
+fromSMayNat' :: SMayNat Int n -> Int
+fromSMayNat' = fromSMayNat id fromSNat'
+
+type family AddMaybe n m where
+ AddMaybe Nothing _ = Nothing
+ AddMaybe (Just _) Nothing = Nothing
+ AddMaybe (Just n) (Just m) = Just (n + m)
+
+smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
+smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
+smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
+smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
+
+
type role ListH nominal representational
-type ListH :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
-data ListH sh f where
- ZH :: ListH '[] f
- (::#) :: forall n sh {f}. f n -> ListH sh f -> ListH (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListH sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListH sh f)
+type ListH :: [Maybe Nat] -> Type -> Type
+data ListH sh i where
+ ZH :: ListH '[] i
+ (::#) :: forall n sh i. SMayNat i n -> ListH sh i -> ListH (n : sh) i
+deriving instance Eq i => Eq (ListH sh i)
+deriving instance Ord i => Ord (ListH sh i)
infixr 3 ::#
#ifdef OXAR_DEFAULT_SHOW_INSTANCES
-deriving instance (forall n. Show (f n)) => Show (ListH sh f)
+deriving instance Show i => Show (ListH sh i)
#else
-instance (forall n. Show (f n)) => Show (ListH sh f) where
+instance Show i => Show (ListH sh i) where
showsPrec _ = listhShow shows
#endif
-instance (forall n. NFData (f n)) => NFData (ListH sh f) where
+instance (forall n. NFData (SMayNat i n)) => NFData (ListH sh i) where
rnf ZH = ()
rnf (x ::# l) = rnf x `seq` rnf l
-data UnconsListHRes f sh1 =
- forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh f) (f n)
-listhUncons :: ListH sh1 f -> Maybe (UnconsListHRes f sh1)
+data UnconsListHRes i sh1 =
+ forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh i) (SMayNat i n)
+listhUncons :: ListH sh1 i -> Maybe (UnconsListHRes i sh1)
listhUncons (i ::# shl') = Just (UnconsListHRes shl' i)
listhUncons ZH = Nothing
-- | This checks only whether the types are equal; if the elements of the list
-- are not singletons, their values may still differ. This corresponds to
-- 'testEquality', except on the penultimate type parameter.
-listhEqType :: TestEquality f => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh')
+listhEqType :: ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
listhEqType ZH ZH = Just Refl
listhEqType (n ::# sh) (m ::# sh')
| Just Refl <- testEquality n m
@@ -338,7 +376,7 @@ listhEqType _ _ = Nothing
-- | This checks whether the two lists actually contain equal values. This is
-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
-- in the @some@ package (except on the penultimate type parameter).
-listhEqual :: (TestEquality f, forall n. Eq (f n)) => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh')
+listhEqual :: Eq i => ListH sh i -> ListH sh' i -> Maybe (sh :~: sh')
listhEqual ZH ZH = Just Refl
listhEqual (n ::# sh) (m ::# sh')
| Just Refl <- testEquality n m
@@ -348,123 +386,58 @@ listhEqual (n ::# sh) (m ::# sh')
listhEqual _ _ = Nothing
{-# INLINE listhFmap #-}
-listhFmap :: (forall n. f n -> g n) -> ListH sh f -> ListH sh g
+listhFmap :: (forall n. SMayNat i n -> SMayNat j n) -> ListH sh i -> ListH sh j
listhFmap _ ZH = ZH
listhFmap f (x ::# xs) = f x ::# listhFmap f xs
{-# INLINE listhFoldMap #-}
-listhFoldMap :: Monoid m => (forall n. f n -> m) -> ListH sh f -> m
+listhFoldMap :: Monoid m => (forall n. SMayNat i n -> m) -> ListH sh i -> m
listhFoldMap _ ZH = mempty
listhFoldMap f (x ::# xs) = f x <> listhFoldMap f xs
-listhLength :: ListH sh f -> Int
+listhLength :: ListH sh i -> Int
listhLength = getSum . listhFoldMap (\_ -> Sum 1)
-listhRank :: ListH sh f -> SNat (Rank sh)
+listhRank :: ListH sh i -> SNat (Rank sh)
listhRank ZH = SNat
listhRank (_ ::# l) | SNat <- listhRank l = SNat
{-# INLINE listhShow #-}
-listhShow :: forall sh f. (forall n. f n -> ShowS) -> ListH sh f -> ShowS
+listhShow :: forall sh i. (forall n. SMayNat i n -> ShowS) -> ListH sh i -> ShowS
listhShow f l = showString "[" . go "" l . showString "]"
where
- go :: String -> ListH sh' f -> ShowS
+ go :: String -> ListH sh' i -> ShowS
go _ ZH = id
go prefix (x ::# xs) = showString prefix . f x . go "," xs
-listhFromList :: StaticShX sh -> [i] -> ListH sh (Const i)
-listhFromList topssh topl = go topssh topl
- where
- go :: StaticShX sh' -> [i] -> ListH sh' (Const i)
- go ZKX [] = ZH
- go (_ :!% sh) (i : is) = Const i ::# go sh is
- go _ _ = error $ "listhFromList: Mismatched list length (type says "
- ++ show (ssxLength topssh) ++ ", list has length "
- ++ show (length topl) ++ ")"
-
-{-# INLINEABLE listhToList #-}
-listhToList :: ListH sh (Const i) -> [i]
-listhToList list = build (\(cons :: i -> is -> is) (nil :: is) ->
- let go :: ListH sh (Const i) -> is
- go ZH = nil
- go (Const i ::# is) = i `cons` go is
- in go list)
-
-listhHead :: ListH (mn ': sh) f -> f mn
+listhHead :: ListH (mn ': sh) i -> SMayNat i mn
listhHead (i ::# _) = i
listhTail :: ListH (n : sh) i -> ListH sh i
listhTail (_ ::# sh) = sh
-listhAppend :: ListH sh f -> ListH sh' f -> ListH (sh ++ sh') f
+listhAppend :: ListH sh i -> ListH sh' i -> ListH (sh ++ sh') i
listhAppend ZH idx' = idx'
listhAppend (i ::# idx) idx' = i ::# listhAppend idx idx'
-listhDrop :: forall f g sh sh'. ListH sh g -> ListH (sh ++ sh') f -> ListH sh' f
+listhDrop :: forall i j sh sh'. ListH sh j -> ListH (sh ++ sh') i -> ListH sh' i
listhDrop ZH long = long
listhDrop (_ ::# short) long = case long of _ ::# long' -> listhDrop short long'
-listhInit :: forall f n sh. ListH (n : sh) f -> ListH (Init (n : sh)) f
+listhInit :: forall i n sh. ListH (n : sh) i -> ListH (Init (n : sh)) i
listhInit (i ::# sh@(_ ::# _)) = i ::# listhInit sh
listhInit (_ ::# ZH) = ZH
-listhLast :: forall f n sh. ListH (n : sh) f -> f (Last (n : sh))
+listhLast :: forall i n sh. ListH (n : sh) i -> SMayNat i (Last (n : sh))
listhLast (_ ::# sh@(_ ::# _)) = listhLast sh
listhLast (x ::# ZH) = x
-listhZip :: ListH sh f -> ListH sh g -> ListH sh (Product f g)
-listhZip ZH ZH = ZH
-listhZip (i ::# irest) (j ::# jrest) = Pair i j ::# listhZip irest jrest
-
-{-# INLINE listhZipWith #-}
-listhZipWith :: (forall a. f a -> g a -> h a) -> ListH sh f -> ListH sh g
- -> ListH sh h
-listhZipWith _ ZH ZH = ZH
-listhZipWith f (i ::# is) (j ::# js) = f i j ::# listhZipWith f is js
-
-- * Mixed shapes
-data SMayNat i n where
- SUnknown :: i -> SMayNat i Nothing
- SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n)
-deriving instance Show i => Show (SMayNat i n)
-deriving instance Eq i => Eq (SMayNat i n)
-deriving instance Ord i => Ord (SMayNat i n)
-
-instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
- rnf (SUnknown i) = rnf i
- rnf (SKnown x) = rnf x
-
-instance TestEquality (SMayNat i) where
- testEquality SUnknown{} SUnknown{} = Just Refl
- testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
- testEquality _ _ = Nothing
-
-{-# INLINE fromSMayNat #-}
-fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => SNat m -> r)
- -> SMayNat i n -> r
-fromSMayNat f _ (SUnknown i) = f i
-fromSMayNat _ g (SKnown s) = g s
-
-fromSMayNat' :: SMayNat Int n -> Int
-fromSMayNat' = fromSMayNat id fromSNat'
-
-type family AddMaybe n m where
- AddMaybe Nothing _ = Nothing
- AddMaybe (Just _) Nothing = Nothing
- AddMaybe (Just n) (Just m) = Just (n + m)
-
-smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
-smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
-smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
-smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
-
-
-- | This is a newtype over 'ListH'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListH sh (SMayNat i))
+newtype ShX sh i = ShX (ListH sh i)
deriving (Eq, Ord, Generic)
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
@@ -575,7 +548,7 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listhAppend @_ @(SMayNat i))
+shxAppend = coerce (listhAppend @_ @i)
shxHead :: ShX (n : sh) i -> SMayNat i n
shxHead (ShX list) = listhHead list
@@ -584,20 +557,20 @@ shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listhTail list)
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSSX = coerce (listhDrop @(SMayNat i) @(SMayNat ()))
+shxDropSSX = coerce (listhDrop @i @())
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
shxDropIx (IxX ZX) long = long
shxDropIx (IxX (_ ::% short)) long = case long of ShX (_ ::# long') -> shxDropIx (IxX short) (ShX long')
shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSh = coerce (listhDrop @(SMayNat i) @(SMayNat i))
+shxDropSh = coerce (listhDrop @i @i)
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listhInit @(SMayNat i))
+shxInit = coerce (listhInit @i)
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh))
-shxLast = coerce (listhLast @(SMayNat i))
+shxLast = coerce (listhLast @i)
shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
shxTakeSSX _ ZKX _ = ZSX
@@ -654,7 +627,7 @@ shxCast' ssh sh = case shxCast ssh sh of
-- | The part of a shape that is statically known. (A newtype over 'ListH'.)
type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListH sh (SMayNat ()))
+newtype StaticShX sh = StaticShX (ListH sh ())
deriving (Eq, Ord)
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
@@ -705,21 +678,25 @@ ssxHead (StaticShX list) = listhHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
-ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSSX = coerce (listhDrop @(SMayNat ()) @(SMayNat ()))
+ssxTakeIx :: forall sh sh' i. Proxy sh' -> IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh
+ssxTakeIx _ (IxX ZX) _ = ZKX
+ssxTakeIx proxy (IxX (_ ::% long)) short = case short of StaticShX (i ::# short') -> i :!% ssxTakeIx proxy (IxX long) (StaticShX short')
ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
ssxDropIx (IxX ZX) long = long
ssxDropIx (IxX (_ ::% short)) long = case long of StaticShX (_ ::# long') -> ssxDropIx (IxX short) (StaticShX long')
ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSh = coerce (listhDrop @(SMayNat ()) @(SMayNat i))
+ssxDropSh = coerce (listhDrop @() @i)
+
+ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
+ssxDropSSX = coerce (listhDrop @() @())
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listhInit @(SMayNat ()))
+ssxInit = coerce (listhInit @())
ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
-ssxLast = coerce (listhLast @(SMayNat ()))
+ssxLast = coerce (listhLast @())
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index e520e0f..f1485b3 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -172,60 +172,60 @@ type family DropLen ref l where
DropLen '[] l = l
DropLen (_ : ref) (_ : xs) = DropLen ref xs
-listhTakeLen :: forall f is sh. Perm is -> ListH sh f -> ListH (TakeLen is sh) f
+listhTakeLen :: forall i is sh. Perm is -> ListH sh i -> ListH (TakeLen is sh) i
listhTakeLen PNil _ = ZH
listhTakeLen (_ `PCons` is) (n ::# sh) = n ::# listhTakeLen is sh
listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape"
-listhDropLen :: forall f is sh. Perm is -> ListH sh f -> ListH (DropLen is sh) f
+listhDropLen :: forall i is sh. Perm is -> ListH sh i -> ListH (DropLen is sh) i
listhDropLen PNil sh = sh
listhDropLen (_ `PCons` is) (_ ::# sh) = listhDropLen is sh
listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape"
-listhPermute :: forall f is sh. Perm is -> ListH sh f -> ListH (Permute is sh) f
+listhPermute :: forall i is sh. Perm is -> ListH sh i -> ListH (Permute is sh) i
listhPermute PNil _ = ZH
-listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh f) =
+listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh i) =
listhIndex (Proxy @is') (Proxy @sh) i sh ::# listhPermute is sh
-listhIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListH sh f -> f (Index i sh)
+listhIndex :: forall i is shT k sh. Proxy is -> Proxy shT -> SNat k -> ListH sh i -> SMayNat i (Index k sh)
listhIndex _ _ SZ (n ::# _) = n
-listhIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::# (sh :: ListH sh' f))
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+listhIndex p pT (SS (i :: SNat k')) ((_ :: SMayNat i n) ::# (sh :: ListH sh' i))
+ | Refl <- lemIndexSucc (Proxy @k') (Proxy @n) (Proxy @sh')
= listhIndex p pT i sh
listhIndex _ _ _ ZH = error "Index into empty shape"
-listhPermutePrefix :: forall f is sh. Perm is -> ListH sh f -> ListH (PermutePrefix is sh) f
+listhPermutePrefix :: forall i is sh. Perm is -> ListH sh i -> ListH (PermutePrefix is sh) i
listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh)
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listhTakeLen @(SMayNat ()))
+ssxTakeLen = coerce (listhTakeLen @())
ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listhDropLen @(SMayNat ()))
+ssxDropLen = coerce (listhDropLen @())
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listhPermute @(SMayNat ()))
+ssxPermute = coerce (listhPermute @())
ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh)
-ssxIndex p1 p2 i = coerce (listhIndex @(SMayNat ()) p1 p2 i)
+ssxIndex p1 p2 i = coerce (listhIndex @() p1 p2 i)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listhPermutePrefix @(SMayNat ()))
+ssxPermutePrefix = coerce (listhPermutePrefix @())
shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh)
-shxTakeLen = coerce (listhTakeLen @(SMayNat Int))
+shxTakeLen = coerce (listhTakeLen @Int)
shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh)
-shxDropLen = coerce (listhDropLen @(SMayNat Int))
+shxDropLen = coerce (listhDropLen @Int)
shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh)
-shxPermute = coerce (listhPermute @(SMayNat Int))
+shxPermute = coerce (listhPermute @Int)
shxIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> SMayNat Int (Index i sh)
-shxIndex p1 p2 i = coerce (listhIndex @(SMayNat Int) p1 p2 i)
+shxIndex p1 p2 i = coerce (listhIndex @Int p1 p2 i)
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listhPermutePrefix @(SMayNat Int))
+shxPermutePrefix = coerce (listhPermutePrefix @Int)
listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f