diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-14 16:55:45 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-14 20:42:00 +0100 |
| commit | b0cc8caff4ccf5df85f3bea743be1f03ddde01c6 (patch) | |
| tree | d1e0489ccf1b28a8e67b834820836af9cf3c6a0e /src | |
| parent | 87e656c5cfebdbd2966494e8ef3f5504d328232a (diff) | |
Fix f in SMayNat to always be SNat and UNPACK it
Diffstat (limited to 'src')
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 73 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 14 |
3 files changed, 45 insertions, 44 deletions
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index 54b2a9f..eb05eaa 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -816,7 +816,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 ssh = ssxFromShX sh - snm :: SMayNat () SNat (AddMaybe n m) + snm :: SMayNat () (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () (SKnown{}, SUnknown{}) -> SUnknown () diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index b1b4f81..3f4ee9a 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -2,6 +2,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImportQualifiedPost #-} @@ -297,30 +298,30 @@ ixxToLinear = \sh i -> go sh i 0 -- * Mixed shapes -data SMayNat i f n where - SUnknown :: i -> SMayNat i f Nothing - SKnown :: f n -> SMayNat i f (Just n) -deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) -deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) -deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) +data SMayNat i n where + SUnknown :: i -> SMayNat i Nothing + SKnown :: {-# UNPACK #-} SNat n -> SMayNat i (Just n) +deriving instance Show i => Show (SMayNat i n) +deriving instance Eq i => Eq (SMayNat i n) +deriving instance Ord i => Ord (SMayNat i n) -instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where +instance (NFData i, forall m. NFData (SNat m)) => NFData (SMayNat i n) where rnf (SUnknown i) = rnf i rnf (SKnown x) = rnf x -instance TestEquality f => TestEquality (SMayNat i f) where +instance TestEquality (SMayNat i) where testEquality SUnknown{} SUnknown{} = Just Refl testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl testEquality _ _ = Nothing {-# INLINE fromSMayNat #-} fromSMayNat :: (n ~ Nothing => i -> r) - -> (forall m. n ~ Just m => f m -> r) - -> SMayNat i f n -> r + -> (forall m. n ~ Just m => SNat m -> r) + -> SMayNat i n -> r fromSMayNat f _ (SUnknown i) = f i fromSMayNat _ g (SKnown s) = g s -fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' :: SMayNat Int n -> Int fromSMayNat' = fromSMayNat id fromSNat' type family AddMaybe n m where @@ -328,7 +329,7 @@ type family AddMaybe n m where AddMaybe (Just _) Nothing = Nothing AddMaybe (Just n) (Just m) = Just (n + m) -smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe :: SMayNat Int n -> SMayNat Int m -> SMayNat Int (AddMaybe n m) smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) @@ -337,7 +338,7 @@ smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) -- | This is a newtype over 'ListX'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) +newtype ShX sh i = ShX (ListX sh (SMayNat i)) deriving (Eq, Ord, Generic) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i @@ -346,7 +347,7 @@ pattern ZSX = ShX ZX pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) - => SMayNat i SNat n -> ShX sh i -> ShX sh1 i + => SMayNat i n -> ShX sh i -> ShX sh1 i pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) where i :$% ShX shl = ShX (i ::% shl) infixr 3 :$% @@ -447,35 +448,35 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) +shxAppend = coerce (listxAppend @_ @(SMayNat i)) -shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead :: ShX (n : sh) i -> SMayNat i n shxHead (ShX list) = listxHead list shxTail :: ShX (n : sh) i -> ShX sh i shxTail (ShX list) = ShX (listxTail list) shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) +shxDropSSX = coerce (listxDrop @(SMayNat i) @(SMayNat ())) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) +shxDropIx = coerce (listxDrop @(SMayNat i) @(Const j)) shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) +shxDropSh = coerce (listxDrop @(SMayNat i) @(SMayNat i)) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i SNat)) +shxInit = coerce (listxInit @(SMayNat i)) -shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i SNat)) +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i)) shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i shxTakeSSX _ ZKX _ = ZSX shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh {-# INLINE shxZipWith #-} -shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) +shxZipWith :: (forall n. SMayNat i n -> SMayNat j n -> SMayNat k n) -> ShX sh i -> ShX sh j -> ShX sh k shxZipWith _ ZSX ZSX = ZSX shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js @@ -525,7 +526,7 @@ shxCast' ssh sh = case shxCast ssh sh of -- | The part of a shape that is statically known. (A newtype over 'ListX'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) +newtype StaticShX sh = StaticShX (ListX sh (SMayNat ())) deriving (Eq, Ord) pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh @@ -534,7 +535,7 @@ pattern ZKX = StaticShX ZX pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) - => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 + => SMayNat () n -> StaticShX sh -> StaticShX sh1 pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) where i :!% StaticShX shl = StaticShX (i ::% shl) infixr 3 :!% @@ -570,26 +571,26 @@ ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' -ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead :: StaticShX (n : sh) -> SMayNat () n ssxHead (StaticShX list) = listxHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) +ssxDropSSX = coerce (listxDrop @(SMayNat ()) @(SMayNat ())) ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) +ssxDropIx = coerce (listxDrop @(SMayNat ()) @(Const i)) ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) +ssxDropSh = coerce (listxDrop @(SMayNat ()) @(SMayNat i)) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat () SNat)) +ssxInit = coerce (listxInit @(SMayNat ())) -ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat () SNat)) +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat ())) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX @@ -632,18 +633,18 @@ type family Flatten' acc sh where Flatten' acc (Just n : sh) = Flatten' (acc * n) sh -- This function is currently unused -ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten :: StaticShX sh -> SMayNat () (Flatten sh) ssxFlatten = go (SNat @1) where - go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go :: SNat acc -> StaticShX sh -> SMayNat () (Flatten' acc sh) go acc ZKX = SKnown acc go _ (SUnknown () :!% _) = SUnknown () go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh -shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten :: IShX sh -> SMayNat Int (Flatten sh) shxFlatten = go (SNat @1) where - go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go :: SNat acc -> IShX sh -> SMayNat Int (Flatten' acc sh) go acc ZSX = SKnown acc go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 065c9fd..6bebcfb 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -201,22 +201,22 @@ ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) +ssxTakeLen = coerce (listxTakeLen @(SMayNat ())) ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) +ssxDropLen = coerce (listxDropLen @(SMayNat ())) ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat () SNat)) +ssxPermute = coerce (listxPermute @(SMayNat ())) -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) -ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat () SNat) p1 p2 i) +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh) +ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat ()) p1 p2 i) ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat ())) shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int)) -- * Operations on permutations |
