{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed.Shape where

import Control.DeepSeq (NFData(..))
import Data.Bifunctor (first)
import Data.Coerce
import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (withDict)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits

import Data.Array.Mixed.Types


-- | The length of a type-level list. If the argument is a shape, then the
-- result is the rank of that shape.
type family Rank sh where
  Rank '[] = 0
  Rank (_ : sh) = Rank sh + 1


-- * Mixed lists

-- | A list indexed by a type-level list of optional sizes: the element at
-- each position @n@ of @sh@ has type @f n@.
type role ListX nominal representational
type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type
data ListX sh f where
  ZX :: ListX '[] f
  (::%) :: f n -> ListX sh f -> ListX (n : sh) f
deriving instance (forall n. Eq (f n)) => Eq (ListX sh f)
deriving instance (forall n. Ord (f n)) => Ord (ListX sh f)
infixr 3 ::%

instance (forall n. Show (f n)) => Show (ListX sh f) where
  showsPrec _ = listxShow shows

instance (forall n. NFData (f n)) => NFData (ListX sh f) where
  rnf ZX = ()
  rnf (x ::% l) = rnf x `seq` rnf l

-- | Result of 'listxUncons': the tail and the head of a non-empty 'ListX',
-- packaged with evidence that the index @sh1@ is a cons.
data UnconsListXRes f sh1 =
  forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n)

-- | Split a 'ListX' into tail and head; 'Nothing' for the empty list. Used to
-- implement the bidirectional pattern synonyms of the newtypes below.
listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
listxUncons ZX = Nothing

-- | Decide whether two lists have the same index, using the elements'
-- 'TestEquality' instance position by position.
listxEq :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
listxEq ZX ZX = Just Refl
listxEq (n ::% sh) (m ::% sh')
  | Just Refl <- testEquality n m
  , Just Refl <- listxEq sh sh'
  = Just Refl
listxEq _ _ = Nothing

-- | Apply an index-preserving function to every element.
listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
listxFmap _ ZX = ZX
listxFmap f (x ::% xs) = f x ::% listxFmap f xs

-- | Map every element into a monoid and combine the results left to right.
listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m
listxFold _ ZX = mempty
listxFold f (x ::% xs) = f x <> listxFold f xs

-- | Number of elements in the list.
listxLength :: ListX sh f -> Int
listxLength = getSum . listxFold (\_ -> Sum 1)

-- | The length of the list as a singleton natural.
listxRank :: ListX sh f -> SNat (Rank sh)
listxRank ZX = SNat
listxRank (_ ::% l) | SNat <- listxRank l = SNat

-- | Render a list as @[a,b,c]@ using the given element printer.
listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS
listxShow f l = showString "[" . go "" l . showString "]"
  where
    go :: String -> ListX sh' f -> ShowS
    go _ ZX = id
    go prefix (x ::% xs) = showString prefix . f x . go "," xs

-- | Convert an untyped list into a 'ListX' whose length is dictated by the
-- static shape. Calls 'error' when the lengths do not match.
listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i)
listxFromList topssh topl = go topssh topl
  where
    go :: StaticShX sh' -> [i] -> ListX sh' (Const i)
    go ZKX [] = ZX
    go (_ :!% sh) (i : is) = Const i ::% go sh is
    go _ _ = error $ "listxFromList: Mismatched list length (type says "
                       ++ show (ssxLength topssh) ++ ", list has length "
                       ++ show (length topl) ++ ")"

-- | Forget the type-level index of a list of 'Const' elements.
listxToList :: ListX sh' (Const i) -> [i]
listxToList ZX = []
listxToList (Const i ::% is) = i : listxToList is

-- | Drop the first element of a non-empty list.
listxTail :: ListX (n : sh) i -> ListX sh i
listxTail (_ ::% sh) = sh

-- | Concatenate two lists; the index is the type-level append.
listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f
listxAppend ZX idx' = idx'
listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx'

-- | Drop as many leading elements of @long@ as @short@ has elements.
listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f
listxDrop long ZX = long
listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short

-- | All but the last element of a non-empty list.
listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f
listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh
listxInit (_ ::% ZX) = ZX

-- | The last element of a non-empty list.
listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh))
listxLast (_ ::% sh@(_ ::% _)) = listxLast sh
listxLast (x ::% ZX) = x


-- * Mixed indices

-- | This is a newtype over 'ListX'.
type role IxX nominal representational
type IxX :: [Maybe Nat] -> Type -> Type
newtype IxX sh i = IxX (ListX sh (Const i))
  deriving (Eq, Ord, Generic)

pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i
pattern ZIX = IxX ZX

pattern (:.%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => i -> IxX sh i -> IxX sh1 i
pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i)))
  where i :.% IxX shl = IxX (Const i ::% shl)
infixr 3 :.%

{-# COMPLETE ZIX, (:.%) #-}

type IIxX sh = IxX sh Int

instance Show i => Show (IxX sh i) where
  showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l

instance Functor (IxX sh) where
  fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l)

instance Foldable (IxX sh) where
  foldMap f (IxX l) = listxFold (f . getConst) l

instance NFData i => NFData (IxX sh i)

-- | The all-zeros index for a static shape.
ixxZero :: StaticShX sh -> IIxX sh
ixxZero ZKX = ZIX
ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh

-- | The all-zeros index for a shape.
ixxZero' :: IShX sh -> IIxX sh
ixxZero' ZSX = ZIX
ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh

-- | See 'listxFromList'.
ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i
ixxFromList = coerce (listxFromList @_ @i)

ixxTail :: IxX (n : sh) i -> IxX sh i
ixxTail (IxX list) = IxX (listxTail list)

ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i
ixxAppend = coerce (listxAppend @_ @(Const i))

ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i
ixxDrop = coerce (listxDrop @(Const i) @(Const i))

ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i
ixxInit = coerce (listxInit @(Const i))

ixxLast :: forall n sh i. IxX (n : sh) i -> i
ixxLast = coerce (listxLast @(Const i))

-- | Convert a linear (row-major, last dimension varying fastest) index into a
-- multi-dimensional index into an array of the given shape.
--
-- Calls 'error' when the linear index is out of range, i.e. negative or not
-- smaller than the number of elements described by the shape.
ixxFromLinear :: IShX sh -> Int -> IIxX sh
ixxFromLinear = \sh i ->
  let outOfRange = error $ "ixxFromLinear: out of range (" ++ show i
                             ++ " in array of shape " ++ show sh ++ ")"
  -- Checking the range up front is necessary: 'quotRem' turns a negative
  -- index into a negative component with a zero quotient (which would pass
  -- the remainder check below), and a zero dimension would otherwise make
  -- 'quotRem' throw a division-by-zero instead of this error.
  in if i < 0 || i >= shxSize sh
       then outOfRange
       else case go sh i of
              (idx, 0) -> idx
              _ -> outOfRange
  where
    -- returns (index in subarray, remaining index in enclosing array)
    go :: IShX sh -> Int -> (IIxX sh, Int)
    go ZSX i = (ZIX, i)
    go (n :$% sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` fromSMayNat' n
      in (locali :.% idx, upi)

-- | Convert a multi-dimensional index into a linear (row-major) index into an
-- array of the given shape. The index components are not bounds-checked.
ixxToLinear :: IShX sh -> IIxX sh -> Int
ixxToLinear = \sh i -> fst (go sh i)
  where
    -- returns (index in subarray, size of subarray)
    go :: IShX sh -> IIxX sh -> (Int, Int)
    go ZSX ZIX = (0, 1)
    go (n :$% sh) (i :.% ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, fromSMayNat' n * sz)


-- * Mixed shapes

-- | A size that is either only known at runtime ('SUnknown', carrying an @i@)
-- or statically known ('SKnown', carrying an @f n@).
data SMayNat i f n where
  SUnknown :: i -> SMayNat i f Nothing
  SKnown :: f n -> SMayNat i f (Just n)
deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n)
deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n)
deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n)

instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where
  rnf (SUnknown i) = rnf i
  rnf (SKnown x) = rnf x

instance TestEquality f => TestEquality (SMayNat i f) where
  testEquality SUnknown{} SUnknown{} = Just Refl
  testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl
  testEquality _ _ = Nothing

-- | Eliminate an 'SMayNat' by providing a handler for each constructor.
fromSMayNat :: (n ~ Nothing => i -> r)
            -> (forall m. n ~ Just m => f m -> r)
            -> SMayNat i f n -> r
fromSMayNat f _ (SUnknown i) = f i
fromSMayNat _ g (SKnown s) = g s

-- | The size as an 'Int', whichever constructor holds it.
fromSMayNat' :: SMayNat Int SNat n -> Int
fromSMayNat' = fromSMayNat id fromSNat'

-- | Type-level sum of two optional sizes: known only if both are known.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- | Add two sizes; the result is statically known only if both inputs are.
smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m)
smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m)
smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m)
smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m)

-- | This is a newtype over 'ListX'.
type role ShX nominal representational
type ShX :: [Maybe Nat] -> Type -> Type
newtype ShX sh i = ShX (ListX sh (SMayNat i SNat))
  deriving (Eq, Ord, Generic)

pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i
pattern ZSX = ShX ZX

pattern (:$%) :: forall {sh1} {i}. forall n sh. (n : sh ~ sh1) => SMayNat i SNat n -> ShX sh i -> ShX sh1 i
pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i))
  where i :$% ShX shl = ShX (i ::% shl)
infixr 3 :$%

{-# COMPLETE ZSX, (:$%) #-}

type IShX sh = ShX sh Int

instance Show i => Show (ShX sh i) where
  showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l

instance Functor (ShX sh) where
  fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l)

instance NFData i => NFData (ShX sh i) where
  rnf (ShX ZX) = ()
  rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l)
  rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l)

-- | Number of dimensions of the shape.
shxLength :: ShX sh i -> Int
shxLength (ShX l) = listxLength l

-- | Number of dimensions of the shape, as a singleton natural.
shxRank :: ShX sh f -> SNat (Rank sh)
shxRank (ShX list) = listxRank list

-- | This is more than @geq@: it also checks that the integers (the unknown
-- dimensions) are the same.
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
-- Both shapes must have the same length, agree on which dimensions are
-- statically known, and agree on the size of every dimension.
shxEqual ZSX ZSX = Just Refl
shxEqual (SKnown a@SNat :$% rest) (SKnown b@SNat :$% rest') = do
  Refl <- sameNat a b
  Refl <- shxEqual rest rest'
  Just Refl
shxEqual (SUnknown x :$% rest) (SUnknown y :$% rest')
  | x == y = do
      Refl <- shxEqual rest rest'
      Just Refl
shxEqual _ _ = Nothing

-- | The number of elements in an array described by this shape: the product
-- of all dimensions (1 for the rank-0 shape).
shxSize :: IShX sh -> Int
shxSize sh = product (shxToList sh)

-- | Build a shape from a list of dimension sizes, guided by a static shape.
-- Calls 'error' if the list length differs from the rank, or if a list entry
-- disagrees with a statically known dimension.
shxFromList :: StaticShX sh -> [Int] -> ShX sh Int
shxFromList topssh topl = build topssh topl
  where
    build :: StaticShX sh' -> [Int] -> ShX sh' Int
    build ZKX [] = ZSX
    build (SUnknown () :!% ssh) (d : ds) = SUnknown d :$% build ssh ds
    build (SKnown sn :!% ssh) (d : ds)
      | d /= fromSNat' sn =
          error $ "shxFromList: Value does not match typing (type says "
                    ++ show (fromSNat' sn) ++ ", list contains " ++ show d ++ ")"
      | otherwise = SKnown sn :$% build ssh ds
    build _ _ =
      error $ "shxFromList: Mismatched list length (type says "
                ++ show (ssxLength topssh) ++ ", list has length "
                ++ show (length topl) ++ ")"

-- | All dimension sizes as a plain list, outermost first.
shxToList :: IShX sh -> [Int]
shxToList (ShX l) = listxFold (\n -> [fromSMayNat' n]) l

-- | Concatenate two shapes.
shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i
shxAppend = coerce (listxAppend @_ @(SMayNat i SNat))

-- | Drop the outermost dimension.
shxTail :: ShX (n : sh) i -> ShX sh i
shxTail (_ :$% rest) = rest

-- | Drop as many leading dimensions as the static shape has.
shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i
shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat))

-- | Drop as many leading dimensions as the index has components.
shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i
shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j))

-- | Drop as many leading dimensions as the second shape has.
shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i
shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat))

-- | All but the innermost dimension.
shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i
shxInit = coerce (listxInit @(SMayNat i SNat))

-- | The innermost dimension.
shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh))
shxLast = coerce (listxLast @(SMayNat i SNat))

-- | Keep only as many leading dimensions as the static shape has.
shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i
shxTakeSSX _ full ssh = takeDims ssh full
  where
    takeDims :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i
    takeDims ZKX _ = ZSX
    takeDims (_ :!% ssh1) (d :$% rest) = d :$% takeDims ssh1 rest

-- This is a weird operation, so it has a long name
-- | Turn a static shape into a full shape, using 0 for every unknown
-- dimension.
shxCompleteZeros :: StaticShX sh -> IShX sh
shxCompleteZeros ZKX = ZSX
shxCompleteZeros (d :!% ssh) = zeroIfUnknown d :$% shxCompleteZeros ssh
  where
    zeroIfUnknown :: SMayNat () SNat n -> SMayNat Int SNat n
    zeroIfUnknown (SUnknown ()) = SUnknown 0
    zeroIfUnknown (SKnown n) = SKnown n

-- | Split a shape into the prefix described by the static shape and the rest.
shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i)
shxSplitApp _ ZKX idx = (ZSX, idx)
shxSplitApp p (_ :!% ssh) (d :$% rest) =
  let (front, back) = shxSplitApp p ssh rest
  in (d :$% front, back)

-- | Every valid index into an array of this shape, in row-major order.
shxEnum :: IShX sh -> [IIxX sh]
shxEnum = \sh -> go sh id []
  where
    -- Difference-list builder: prepends every index under @sh@, each wrapped
    -- by @mk@, to the given tail.
    go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
    go ZSX mk = (mk ZIX :)
    go (n :$% sh) mk =
      foldr (\k rest -> go sh (mk . (k :.%)) . rest) id [0 .. fromSMayNat' n - 1]


-- * Static mixed shapes

-- | The part of a shape that is statically known. (A newtype over 'ListX'.)
type StaticShX :: [Maybe Nat] -> Type
newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat))
  deriving (Eq, Ord)

pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh
pattern ZKX = StaticShX ZX

pattern (:!%) :: forall {sh1}. forall n sh. (n : sh ~ sh1) => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1
pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i))
  where i :!% StaticShX shl = StaticShX (i ::% shl)
infixr 3 :!%

{-# COMPLETE ZKX, (:!%) #-}

instance Show (StaticShX sh) where
  showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l

instance TestEquality StaticShX where
  testEquality (StaticShX l1) (StaticShX l2) = listxEq l1 l2

-- | Number of dimensions of the static shape.
ssxLength :: StaticShX sh -> Int
ssxLength (StaticShX l) = listxLength l

-- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
-- class of the @some@ package.
ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
ssxGeq ZKX ZKX = Just Refl
ssxGeq (SKnown a@SNat :!% rest) (SKnown b@SNat :!% rest') = do
  Refl <- sameNat a b
  Refl <- ssxGeq rest rest'
  Just Refl
ssxGeq (SUnknown () :!% rest) (SUnknown () :!% rest') = do
  Refl <- ssxGeq rest rest'
  Just Refl
ssxGeq _ _ = Nothing

-- | Concatenate two static shapes.
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend (StaticShX front) (StaticShX back) = StaticShX (listxAppend front back)

-- | Drop the outermost dimension.
ssxTail :: StaticShX (n : sh) -> StaticShX sh
ssxTail (StaticShX l) = StaticShX (listxTail l)

-- | Drop as many leading dimensions as the index has components.
ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh'
ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i))

-- | All but the innermost dimension.
ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh))
ssxInit = coerce (listxInit @(SMayNat () SNat))

-- | The innermost dimension.
ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh))
ssxLast = coerce (listxLast @(SMayNat () SNat))

-- | Convert a static shape into a full shape. This fails (returns 'Nothing')
-- if @sh@ has @Nothing@s in it, since those dimensions have no static size.
ssxToShX' :: StaticShX sh -> Maybe (IShX sh)
ssxToShX' ZKX = Just ZSX
ssxToShX' (SUnknown _ :!% _) = Nothing
ssxToShX' (SKnown n :!% ssh) = do
  rest <- ssxToShX' ssh
  Just (SKnown n :$% rest)

-- | A static shape of rank @n@ in which every dimension is unknown.
ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
ssxReplicate SZ = ZKX
ssxReplicate (SS (n :: SNat n')) =
  case lemReplicateSucc @(Nothing @Nat) @n' of
    Refl -> SUnknown () :!% ssxReplicate n

-- | Consecutive integers starting at the given one, one per dimension of the
-- static shape.
ssxIotaFrom :: Int -> StaticShX sh -> [Int]
ssxIotaFrom start ssh = take (ssxLength ssh) (iterate (+ 1) start)

-- | The static part of a shape: unknown dimensions forget their runtime size.
ssxFromShape :: IShX sh -> StaticShX sh
ssxFromShape (ShX l) = StaticShX (listxFmap (fromSMayNat (\_ -> SUnknown ()) SKnown) l)

-- | Same as 'ssxReplicate'.
ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
ssxFromSNat = ssxReplicate

-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
type KnownShX :: [Maybe Nat] -> Constraint
class KnownShX sh where knownShX :: StaticShX sh
instance KnownShX '[] where knownShX = ZKX
instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX
instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX

-- | Bring a 'KnownShX' instance for @sh@ into scope from a runtime
-- 'StaticShX' value.
withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r
withKnownShX sh = withDict @(KnownShX sh) sh


-- * Flattening

-- | The total number of elements of a shape at the type level: @Just@ the
-- product of the dimensions if all are known, 'Nothing' otherwise.
type Flatten sh = Flatten' 1 sh

type family Flatten' acc sh where
  Flatten' acc '[] = Just acc
  Flatten' acc (Nothing : sh) = Nothing
  Flatten' acc (Just n : sh) = Flatten' (acc * n) sh

-- This function is currently unused
ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh)
ssxFlatten = loop (SNat @1)
  where
    loop :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh)
    loop acc ZKX = SKnown acc
    loop _ (SUnknown () :!% _) = SUnknown ()
    loop acc (SKnown d :!% rest) = loop (snatMul acc d) rest

-- | The total number of elements of a shape, statically known when every
-- dimension is.
shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh)
shxFlatten = loop (SNat @1)
  where
    loop :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh)
    loop acc ZSX = SKnown acc
    loop acc (SUnknown n :$% rest) =
      -- One unknown dimension makes the total unknown: multiply the static
      -- prefix, this dimension and all remaining dimensions at runtime.
      SUnknown (fromSNat' acc * n * shxSize rest)
    loop acc (SKnown d :$% rest) = loop (snatMul acc d) rest

-- | Very untyped: only length is checked (at runtime).
instance KnownShX sh => IsList (ListX sh (Const i)) where
  type Item (ListX sh (Const i)) = i
  fromList = listxFromList (knownShX @sh)
  toList = listxToList

-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
instance KnownShX sh => IsList (IxX sh i) where
  type Item (IxX sh i) = i
  fromList = IxX . IsList.fromList
  toList = Foldable.toList

-- | Untyped: length and known dimensions are checked (at runtime).
instance KnownShX sh => IsList (ShX sh Int) where
  type Item (ShX sh Int) = Int
  fromList = shxFromList (knownShX @sh)
  toList = shxToList