aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-14 16:55:45 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-14 20:42:00 +0100
commitb0cc8caff4ccf5df85f3bea743be1f03ddde01c6 (patch)
treed1e0489ccf1b28a8e67b834820836af9cf3c6a0e /src
parent87e656c5cfebdbd2966494e8ef3f5504d328232a (diff)
Fix f in SMayNat to always be SNat and UNPACK it
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Nested/Mixed.hs2
-rw-r--r--src/Data/Array/Nested/Mixed/Shape.hs73
-rw-r--r--src/Data/Array/Nested/Permutation.hs14
3 files changed, 45 insertions, 44 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
index 54b2a9f..eb05eaa 100644
--- a/src/Data/Array/Nested/Mixed.hs
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -816,7 +816,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
sn :$% sh = mshape arr1
sm :$% _ = mshape arr2
ssh = ssxFromShX sh
- snm :: SMayNat () SNat (AddMaybe n m)
+ snm :: SMayNat () (AddMaybe n m)
snm = case (sn, sm) of
(SUnknown{}, _) -> SUnknown ()
(SKnown{}, SUnknown{}) -> SUnknown ()
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs
index b1b4f81..3f4ee9a 100644
--- a/src/Data/Array/Nested/Mixed/Shape.hs
+++ b/src/Data/Array/Nested/Mixed/Shape.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
@@ -297,30 +298,30 @@ ixxToLinear = \sh i -> go sh i 0
-- * Mixed shapes
-data SMayNat i f n where
- SUnknown :: i -> SMayNat i f Nothing
- SKnown :: f n -> SMayNat i f (Just n)
-deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
-deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
-deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)
+data SMayNat i n where
+ SUnknown :: i -> SMayNat i Nothing
+ SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n)
+deriving instance Show i => Show (SMayNat i n)
+deriving instance Eq i => Eq (SMayNat i n)
+deriving instance Ord i => Ord (SMayNat i n)
-instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
+instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where
rnf (SUnknown i) = rnf i
rnf (SKnown x) = rnf x
-instance TestEquality f => TestEquality (SMayNat i f) where
+instance TestEquality (SMayNat i) where
testEquality SUnknown{} SUnknown{} = Just Refl
testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
testEquality _ _ = Nothing
{-# INLINE fromSMayNat #-}
fromSMayNat :: (n ~ Nothing => i -> r)
- -> (forall m. n ~ Just m => f m -> r)
- -> SMayNat i f n -> r
+ -> (forall m. n ~ Just m => SNat m -> r)
+ -> SMayNat i n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s
-fromSMayNat' :: SMayNat Int SNat n -> Int
+fromSMayNat' :: SMayNat Int n -> Int
fromSMayNat' = fromSMayNat id fromSNat'
type family AddMaybe n m where
@@ -328,7 +329,7 @@ type family AddMaybe n m where
AddMaybe (Just _) Nothing = Nothing
AddMaybe (Just n) (Just m) = Just (n + m)
-smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
+smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
@@ -337,7 +338,7 @@ smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)
-- | This is a newtype over 'ListX'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
-newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
+newtype ShX sh i = ShX (ListX sh (SMayNat i))
deriving (Eq, Ord, Generic)
pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
@@ -346,7 +347,7 @@ pattern ZSX = ShX ZX
pattern (:$%)
:: forall {sh1} {i}.
forall n sh. (n : sh ~ sh1)
- => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
+ => SMayNat i n -> ShX sh i -> ShX sh1 i
pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
where i :$% ShX shl = ShX (i ::% shl)
infixr 3 :$%
@@ -447,35 +448,35 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh
shxFromSSX2 (SUnknown _ :!% _) = Nothing
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
-shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))
+shxAppend = coerce (listxAppend @_ @(SMayNat i))
-shxHead :: ShX (n : sh) i -> SMayNat i SNat n
+shxHead :: ShX (n : sh) i -> SMayNat i n
shxHead (ShX list) = listxHead list
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (ShX list) = ShX (listxTail list)
shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))
+shxDropSSX = coerce (listxDrop @(SMayNat i) @(SMayNat ()))
shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))
+shxDropIx = coerce (listxDrop @(SMayNat i) @(Const j))
shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i
-shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))
+shxDropSh = coerce (listxDrop @(SMayNat i) @(SMayNat i))
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
-shxInit = coerce (listxInit @(SMayNat i SNat))
+shxInit = coerce (listxInit @(SMayNat i))
-shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
-shxLast = coerce (listxLast @(SMayNat i SNat))
+shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh))
+shxLast = coerce (listxLast @(SMayNat i))
shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i
shxTakeSSX _ ZKX _ = ZSX
shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh
{-# INLINE shxZipWith #-}
-shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n)
+shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n)
-> ShX sh i -> ShX sh j -> ShX sh k
shxZipWith _ ZSX ZSX = ZSX
shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js
@@ -525,7 +526,7 @@ shxCast' ssh sh = case shxCast ssh sh of
-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
type StaticShX :: [Maybe Nat] -> Type
-newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
+newtype StaticShX sh = StaticShX (ListX sh (SMayNat ()))
deriving (Eq, Ord)
pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
@@ -534,7 +535,7 @@ pattern ZKX = StaticShX ZX
pattern (:!%)
:: forall {sh1}.
forall n sh. (n : sh ~ sh1)
- => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
+ => SMayNat () n -> StaticShX sh -> StaticShX sh1
pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
where i :!% StaticShX shl = StaticShX (i ::% shl)
infixr 3 :!%
@@ -570,26 +571,26 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh'
-ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n
+ssxHead :: StaticShX (n : sh) -> SMayNat () n
ssxHead (StaticShX list) = listxHead list
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (_ :!% ssh) = ssh
ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat))
+ssxDropSSX = coerce (listxDrop @(SMayNat ()) @(SMayNat ()))
ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))
+ssxDropIx = coerce (listxDrop @(SMayNat ()) @(Const i))
ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh'
-ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat))
+ssxDropSh = coerce (listxDrop @(SMayNat ()) @(SMayNat i))
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
-ssxInit = coerce (listxInit @(SMayNat () SNat))
+ssxInit = coerce (listxInit @(SMayNat ()))
-ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
-ssxLast = coerce (listxLast @(SMayNat () SNat))
+ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh))
+ssxLast = coerce (listxLast @(SMayNat ()))
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
@@ -632,18 +633,18 @@ type family Flatten' acc sh where
Flatten' acc (Just n : sh) = Flatten' (acc * n) sh
-- This function is currently unused
-ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
+ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh)
ssxFlatten = go (SNat @1)
where
- go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
+ go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh)
go acc ZKX = SKnown acc
go _ (SUnknown () :!% _) = SUnknown ()
go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh
-shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
+shxFlatten :: IShX sh -> SMayNat Int (Flatten sh)
shxFlatten = go (SNat @1)
where
- go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
+ go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh)
go acc ZSX = SKnown acc
go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh)
go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh
diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs
index 065c9fd..6bebcfb 100644
--- a/src/Data/Array/Nested/Permutation.hs
+++ b/src/Data/Array/Nested/Permutation.hs
@@ -201,22 +201,22 @@ ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is
ixxPermutePrefix = coerce (listxPermutePrefix @(Const i))
ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh)
-ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat))
+ssxTakeLen = coerce (listxTakeLen @(SMayNat ()))
ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh)
-ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
+ssxDropLen = coerce (listxDropLen @(SMayNat ()))
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
-ssxPermute = coerce (listxPermute @(SMayNat () SNat))
+ssxPermute = coerce (listxPermute @(SMayNat ()))
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
-ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i)
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh)
+ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat ()) p1 p2 i)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
-ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat))
+ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat ()))
shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh)
-shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat))
+shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int))
-- * Operations on permutations