diff options
| -rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 206 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 86 | ||||
| -rw-r--r-- | src/Data/Array/XArray.hs | 2 |
3 files changed, 224 insertions, 70 deletions
diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs index f08b8be..debe5ec 100644 --- a/src/Data/Array/Nested/Mixed/Shape.hs +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -57,7 +57,7 @@ type family Rank sh where Rank (_ : sh) = Rank sh + 1 --- * Mixed lists +-- * Mixed lists to be used IxX and shaped and ranked lists and indexes type role ListX nominal representational type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type @@ -296,6 +296,132 @@ ixxToLinear = \sh i -> go sh i 0 go (n :$% sh) (i :.% ix) a = go sh ix (fromIntegral (fromSMayNat' n) * a + i) +-- * Mixed shape-like lists to be used for ShX and StaticShX + +type role ListH nominal representational +type ListH :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListH sh f where + ZH :: ListH '[] f + (::#) :: forall n sh {f}. f n -> ListH sh f -> ListH (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListH sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListH sh f) +infixr 3 ::# + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListH sh f) +#else +instance (forall n. Show (f n)) => Show (ListH sh f) where + showsPrec _ = listhShow shows +#endif + +instance (forall n. NFData (f n)) => NFData (ListH sh f) where + rnf ZH = () + rnf (x ::# l) = rnf x `seq` rnf l + +data UnconsListHRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListHRes (ListH sh f) (f n) +listhUncons :: ListH sh1 f -> Maybe (UnconsListHRes f sh1) +listhUncons (i ::# shl') = Just (UnconsListHRes shl' i) +listhUncons ZH = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listhEqType :: TestEquality f => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh') +listhEqType ZH ZH = Just Refl +listhEqType (n ::# sh) (m ::# sh') + | Just Refl <- testEquality n m + , Just Refl <- listhEqType sh sh' + = Just Refl +listhEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listhEqual :: (TestEquality f, forall n. Eq (f n)) => ListH sh f -> ListH sh' f -> Maybe (sh :~: sh') +listhEqual ZH ZH = Just Refl +listhEqual (n ::# sh) (m ::# sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listhEqual sh sh' + = Just Refl +listhEqual _ _ = Nothing + +{-# INLINE listhFmap #-} +listhFmap :: (forall n. f n -> g n) -> ListH sh f -> ListH sh g +listhFmap _ ZH = ZH +listhFmap f (x ::# xs) = f x ::# listhFmap f xs + +{-# INLINE listhFoldMap #-} +listhFoldMap :: Monoid m => (forall n. f n -> m) -> ListH sh f -> m +listhFoldMap _ ZH = mempty +listhFoldMap f (x ::# xs) = f x <> listhFoldMap f xs + +listhLength :: ListH sh f -> Int +listhLength = getSum . listhFoldMap (\_ -> Sum 1) + +listhRank :: ListH sh f -> SNat (Rank sh) +listhRank ZH = SNat +listhRank (_ ::# l) | SNat <- listhRank l = SNat + +{-# INLINE listhShow #-} +listhShow :: forall sh f. (forall n. f n -> ShowS) -> ListH sh f -> ShowS +listhShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListH sh' f -> ShowS + go _ ZH = id + go prefix (x ::# xs) = showString prefix . f x . go "," xs + +listhFromList :: StaticShX sh -> [i] -> ListH sh (Const i) +listhFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [i] -> ListH sh' (Const i) + go ZKX [] = ZH + go (_ :!% sh) (i : is) = Const i ::# go sh is + go _ _ = error $ "listhFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +{-# INLINEABLE listhToList #-} +listhToList :: ListH sh (Const i) -> [i] +listhToList list = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ListH sh (Const i) -> is + go ZH = nil + go (Const i ::# is) = i `cons` go is + in go list) + +listhHead :: ListH (mn ': sh) f -> f mn +listhHead (i ::# _) = i + +listhTail :: ListH (n : sh) i -> ListH sh i +listhTail (_ ::# sh) = sh + +listhAppend :: ListH sh f -> ListH sh' f -> ListH (sh ++ sh') f +listhAppend ZH idx' = idx' +listhAppend (i ::# idx) idx' = i ::# listhAppend idx idx' + +listhDrop :: forall f g sh sh'. ListH sh g -> ListH (sh ++ sh') f -> ListH sh' f +listhDrop ZH long = long +listhDrop (_ ::# short) long = case long of _ ::# long' -> listhDrop short long' + +listhInit :: forall f n sh. ListH (n : sh) f -> ListH (Init (n : sh)) f +listhInit (i ::# sh@(_ ::# _)) = i ::# listhInit sh +listhInit (_ ::# ZH) = ZH + +listhLast :: forall f n sh. ListH (n : sh) f -> f (Last (n : sh)) +listhLast (_ ::# sh@(_ ::# _)) = listhLast sh +listhLast (x ::# ZH) = x + +listhZip :: ListH sh f -> ListH sh g -> ListH sh (Product f g) +listhZip ZH ZH = ZH +listhZip (i ::# irest) (j ::# jrest) = Pair i j ::# listhZip irest jrest + +{-# INLINE listhZipWith #-} +listhZipWith :: (forall a. f a -> g a -> h a) -> ListH sh f -> ListH sh g + -> ListH sh h +listhZipWith _ ZH ZH = ZH +listhZipWith f (i ::# is) (j ::# js) = f i j ::# listhZipWith f is js + -- * Mixed shapes data SMayNat i n where @@ -335,21 +461,21 @@ smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) --- | This is a newtype over 'ListX'. +-- | This is a newtype over 'ListH'. type role ShX nominal representational type ShX :: [Maybe Nat] -> Type -> Type -newtype ShX sh i = ShX (ListX sh (SMayNat i)) +newtype ShX sh i = ShX (ListH sh (SMayNat i)) deriving (Eq, Ord, Generic) pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i -pattern ZSX = ShX ZX +pattern ZSX = ShX ZH pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => SMayNat i n -> ShX sh i -> ShX sh1 i -pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) - where i :$% ShX shl = ShX (i ::% shl) +pattern i :$% shl <- ShX (listhUncons -> Just (UnconsListHRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::# shl) infixr 3 :$% {-# COMPLETE ZSX, (:$%) #-} @@ -360,17 +486,17 @@ type IShX sh = ShX sh Int deriving instance Show i => Show (ShX sh i) #else instance Show i => Show (ShX sh i) where - showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (ShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l #endif instance Functor (ShX sh) where {-# INLINE fmap #-} - fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + fmap f (ShX l) = ShX (listhFmap (fromSMayNat (SUnknown . f) SKnown) l) instance NFData i => NFData (ShX sh i) where - rnf (ShX ZX) = () - rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) - rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + rnf (ShX ZH) = () + rnf (ShX (SUnknown i ::# l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::# l)) = rnf (ShX l) -- | This checks only whether the types are equal; unknown dimensions might -- still differ. This corresponds to 'testEquality', except on the penultimate @@ -402,10 +528,10 @@ shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') shxEqual _ _ = Nothing shxLength :: ShX sh i -> Int -shxLength (ShX l) = listxLength l +shxLength (ShX l) = listhLength l shxRank :: ShX sh i -> SNat (Rank sh) -shxRank (ShX l) = listxRank l +shxRank (ShX l) = listhRank l -- | The number of elements in an array described by this shape. shxSize :: IShX sh -> Int @@ -449,28 +575,29 @@ shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh shxFromSSX2 (SUnknown _ :!% _) = Nothing shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i -shxAppend = coerce (listxAppend @_ @(SMayNat i)) +shxAppend = coerce (listhAppend @_ @(SMayNat i)) shxHead :: ShX (n : sh) i -> SMayNat i n -shxHead (ShX list) = listxHead list +shxHead (ShX list) = listhHead list shxTail :: ShX (n : sh) i -> ShX sh i -shxTail (ShX list) = ShX (listxTail list) +shxTail (ShX list) = ShX (listhTail list) shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSSX = coerce (listxDrop @(SMayNat i) @(SMayNat ())) +shxDropSSX = coerce (listhDrop @(SMayNat i) @(SMayNat ())) shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i -shxDropIx = coerce (listxDrop @(SMayNat i) @(Const j)) +shxDropIx (IxX ZX) long = long +shxDropIx (IxX (_ ::% short)) long = case long of ShX (_ ::# long') -> shxDropIx (IxX short) (ShX long') shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i -shxDropSh = coerce (listxDrop @(SMayNat i) @(SMayNat i)) +shxDropSh = coerce (listhDrop @(SMayNat i) @(SMayNat i)) shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i -shxInit = coerce (listxInit @(SMayNat i)) +shxInit = coerce (listhInit @(SMayNat i)) shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i (Last (n : sh)) -shxLast = coerce (listxLast @(SMayNat i)) +shxLast = coerce (listhLast @(SMayNat i)) shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i shxTakeSSX _ ZKX _ = ZSX @@ -525,20 +652,20 @@ shxCast' ssh sh = case shxCast ssh sh of -- * Static mixed shapes --- | The part of a shape that is statically known. (A newtype over 'ListX'.) +-- | The part of a shape that is statically known. (A newtype over 'ListH'.) type StaticShX :: [Maybe Nat] -> Type -newtype StaticShX sh = StaticShX (ListX sh (SMayNat ())) +newtype StaticShX sh = StaticShX (ListH sh (SMayNat ())) deriving (Eq, Ord) pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh -pattern ZKX = StaticShX ZX +pattern ZKX = StaticShX ZH pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SMayNat () n -> StaticShX sh -> StaticShX sh1 -pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) - where i :!% StaticShX shl = StaticShX (i ::% shl) +pattern i :!% shl <- StaticShX (listhUncons -> Just (UnconsListHRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::# shl) infixr 3 :!% {-# COMPLETE ZKX, (:!%) #-} @@ -547,22 +674,22 @@ infixr 3 :!% deriving instance Show (StaticShX sh) #else instance Show (StaticShX sh) where - showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + showsPrec _ (StaticShX l) = listhShow (fromSMayNat shows (shows . fromSNat)) l #endif instance NFData (StaticShX sh) where - rnf (StaticShX ZX) = () - rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) - rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + rnf (StaticShX ZH) = () + rnf (StaticShX (SUnknown () ::# l)) = rnf (StaticShX l) + rnf (StaticShX (SKnown SNat ::# l)) = rnf (StaticShX l) instance TestEquality StaticShX where - testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + testEquality (StaticShX l1) (StaticShX l2) = listhEqType l1 l2 ssxLength :: StaticShX sh -> Int -ssxLength (StaticShX l) = listxLength l +ssxLength (StaticShX l) = listhLength l ssxRank :: StaticShX sh -> SNat (Rank sh) -ssxRank (StaticShX l) = listxRank l +ssxRank (StaticShX l) = listhRank l -- | @ssxEqType = 'testEquality'@. Provided for consistency. ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') @@ -573,25 +700,26 @@ ssxAppend ZKX sh' = sh' ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' ssxHead :: StaticShX (n : sh) -> SMayNat () n -ssxHead (StaticShX list) = listxHead list +ssxHead (StaticShX list) = listhHead list ssxTail :: StaticShX (n : sh) -> StaticShX sh ssxTail (_ :!% ssh) = ssh ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSSX = coerce (listxDrop @(SMayNat ()) @(SMayNat ())) +ssxDropSSX = coerce (listhDrop @(SMayNat ()) @(SMayNat ())) ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropIx = coerce (listxDrop @(SMayNat ()) @(Const i)) +ssxDropIx (IxX ZX) long = long +ssxDropIx (IxX (_ ::% short)) long = case long of StaticShX (_ ::# long') -> ssxDropIx (IxX short) (StaticShX long') ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' -ssxDropSh = coerce (listxDrop @(SMayNat ()) @(SMayNat i)) +ssxDropSh = coerce (listhDrop @(SMayNat ()) @(SMayNat i)) ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) -ssxInit = coerce (listxInit @(SMayNat ())) +ssxInit = coerce (listhInit @(SMayNat ())) ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () (Last (n : sh)) -ssxLast = coerce (listxLast @(SMayNat ())) +ssxLast = coerce (listhLast @(SMayNat ())) ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) ssxReplicate SZ = ZKX diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs index 8b46d81..e520e0f 100644 --- a/src/Data/Array/Nested/Permutation.hs +++ b/src/Data/Array/Nested/Permutation.hs @@ -172,6 +172,62 @@ type family DropLen ref l where DropLen '[] l = l DropLen (_ : ref) (_ : xs) = DropLen ref xs +listhTakeLen :: forall f is sh. Perm is -> ListH sh f -> ListH (TakeLen is sh) f +listhTakeLen PNil _ = ZH +listhTakeLen (_ `PCons` is) (n ::# sh) = n ::# listhTakeLen is sh +listhTakeLen (_ `PCons` _) ZH = error "Permutation longer than shape" + +listhDropLen :: forall f is sh. Perm is -> ListH sh f -> ListH (DropLen is sh) f +listhDropLen PNil sh = sh +listhDropLen (_ `PCons` is) (_ ::# sh) = listhDropLen is sh +listhDropLen (_ `PCons` _) ZH = error "Permutation longer than shape" + +listhPermute :: forall f is sh. Perm is -> ListH sh f -> ListH (Permute is sh) f +listhPermute PNil _ = ZH +listhPermute (i `PCons` (is :: Perm is')) (sh :: ListH sh f) = + listhIndex (Proxy @is') (Proxy @sh) i sh ::# listhPermute is sh + +listhIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListH sh f -> f (Index i sh) +listhIndex _ _ SZ (n ::# _) = n +listhIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::# (sh :: ListH sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listhIndex p pT i sh +listhIndex _ _ _ ZH = error "Index into empty shape" + +listhPermutePrefix :: forall f is sh. Perm is -> ListH sh f -> ListH (PermutePrefix is sh) f +listhPermutePrefix perm sh = listhAppend (listhPermute perm (listhTakeLen perm sh)) (listhDropLen perm sh) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listhTakeLen @(SMayNat ())) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listhDropLen @(SMayNat ())) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listhPermute @(SMayNat ())) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh) +ssxIndex p1 p2 i = coerce (listhIndex @(SMayNat ()) p1 p2 i) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listhPermutePrefix @(SMayNat ())) + +shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) +shxTakeLen = coerce (listhTakeLen @(SMayNat Int)) + +shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh) +shxDropLen = coerce (listhDropLen @(SMayNat Int)) + +shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) +shxPermute = coerce (listhPermute @(SMayNat Int)) + +shxIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> SMayNat Int (Index i sh) +shxIndex p1 p2 i = coerce (listhIndex @(SMayNat Int) p1 p2 i) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listhPermutePrefix @(SMayNat Int)) + + listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f listxTakeLen PNil _ = ZX listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh @@ -200,36 +256,6 @@ listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm s ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) -ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) -ssxTakeLen = coerce (listxTakeLen @(SMayNat ())) - -ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) -ssxDropLen = coerce (listxDropLen @(SMayNat ())) - -ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) -ssxPermute = coerce (listxPermute @(SMayNat ())) - -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () (Index i sh) -ssxIndex p1 p2 i = coerce (listxIndex @(SMayNat ()) p1 p2 i) - -ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) -ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat ())) - -shxTakeLen :: forall is sh. Perm is -> IShX sh -> IShX (TakeLen is sh) -shxTakeLen = coerce (listxTakeLen @(SMayNat Int)) - -shxDropLen :: Perm is -> IShX sh -> IShX (DropLen is sh) -shxDropLen = coerce (listxDropLen @(SMayNat Int)) - -shxPermute :: Perm is -> IShX sh -> IShX (Permute is sh) -shxPermute = coerce (listxPermute @(SMayNat Int)) - -shxIndex :: Proxy is -> Proxy shT -> SNat i -> IShX sh -> SMayNat Int (Index i sh) -shxIndex p1 p2 i = coerce (listxIndex @(SMayNat Int) p1 p2 i) - -shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) -shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int)) - -- * Operations on permutations diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index 6389e67..a9fc14c 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -285,7 +285,7 @@ sumInner ssh ssh' arr go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a go (XArray arr') | Refl <- lemRankApp ssh ssh'F - , let sn = listxRank (let StaticShX l = ssh in l) + , let sn = ssxRank ssh = XArray (liftO1 (numEltSum1Inner sn) arr') in go $ |
