diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-30 11:58:40 +0200 |
commit | a65306ba5d80891b20ac86fa3a3242f9497751e6 (patch) | |
tree | 834af370556a46bbeca807a92c31bef098b47a89 /src/Data/Array/Mixed/Shape.hs | |
parent | d8e2fcf4ea979fe272db48fc2889f4c2636c50d7 (diff) |
Refactor Mixed (modules, regular function names)
Diffstat (limited to 'src/Data/Array/Mixed/Shape.hs')
-rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 455 |
1 files changed, 455 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs new file mode 100644 index 0000000..a16da76 --- /dev/null +++ b/src/Data/Array/Mixed/Shape.hs @@ -0,0 +1,455 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import qualified Data.Foldable as Foldable +import Data.Functor.Const +import Data.Kind (Type, Constraint) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import qualified GHC.IsList as IsList +import GHC.TypeLits + +import Data.Array.Mixed.Types +import Data.Coerce +import Data.Bifunctor (first) + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +instance (forall n. Show (f n)) => Show (ListX sh f) where + showsPrec _ = listxShow shows + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where + rnf ZX = () + rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxLengthSNat :: ListX sh f -> SNat (Rank sh) +listxLengthSNat ZX = SNat +listxLengthSNat (_ ::% l) | SNat <- listxLengthSNat l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' f -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f +listxDrop long ZX = long +listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short + + +-- * Mixed indices + +-- | This is a newtype over 'ListX'. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +type IIxX sh = IxX sh Int + +instance Show i => Show (IxX sh i) where + showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l + +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$% sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$% sh) (i :.% ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => f m -> r) + -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +instance Show i => Show (ShX sh i) where + showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where + rnf (ShX ZX) = () + rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +-- | This is more than @geq@: it also checks that the integers (the unknown +-- dimensions) are the same. +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') + | i == j + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual _ _ = Nothing + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (_ :$% sh) = sh + +shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i +shxTakeSSX _ = flip go + where + go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i + go ZKX _ = ZSX + go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] + where + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +instance Show (StaticShX sh) where + showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@ +-- class of the @some@ package. +ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxGeq ZKX ZKX = Just Refl +ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh') + | Just Refl <- sameNat n m + , Just Refl <- ssxGeq sh sh' + = Just Refl +ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh') + | Just Refl <- ssxGeq sh sh' + = Just Refl +ssxGeq _ _ = Nothing + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShX' :: StaticShX sh -> Maybe (IShX sh) +ssxToShX' ZKX = Just ZSX +ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh +ssxToShX' (SUnknown _ :!% _) = Nothing + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKX = [] +ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh + +ssxFromShape :: IShX sh -> StaticShX sh +ssxFromShape ZSX = ZKX +ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) + where + go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go acc ZKX = SKnown acc + go _ (SUnknown () :!% _) = SUnknown () + go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where + type Item (ListX sh (Const i)) = i + fromList topl = go (knownShX @sh) topl + where + go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go ZKX [] = ZX + go (_ :!% sh) (i : is) = Const i ::% go sh is + go _ _ = error $ "IsList(ListX): Mismatched list length (type says " + ++ show (ssxLength (knownShX @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where + type Item (IxX sh i) = i + fromList = IxX . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList topl = ShX (go (knownShX @sh) topl) + where + go :: StaticShX sh' -> [Int] -> ListX sh' (SMayNat Int SNat) + go ZKX [] = ZX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn ::% go sh is + | otherwise = error $ "IsList(ShX): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i ::% go sh is + go _ _ = error $ "IsList(ShX): Mismatched list length (type says " + ++ show (ssxLength (knownShX @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = shxToList |