{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Nested.Ranked.Shape where

import Control.DeepSeq (NFData(..))
import Data.Coerce (coerce)
import Data.Foldable qualified as Foldable
import Data.Kind (Type)
import Data.Proxy
import Data.Type.Equality
import GHC.Exts (Int(..), Int#, quotRemInt#)
import GHC.Generics (Generic)
import GHC.IsList (IsList)
import GHC.IsList qualified as IsList
import GHC.TypeLits
import GHC.TypeNats qualified as TN

import Data.Array.Nested.Lemmas
import Data.Array.Nested.Types


-- * Ranked lists

-- | A list whose length is tracked by the type-level 'Nat' index.
type role ListR nominal representational
type ListR :: Nat -> Type -> Type
data ListR n i where
  ZR :: ListR 0 i
  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
deriving instance Functor (ListR n)
deriving instance Foldable (ListR n)
infixr 3 :::

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ListR n i)
#else
-- Renders like an ordinary list, e.g. @[1,2,3]@.
instance Show i => Show (ListR n i) where
  showsPrec _ = listrShow shows
#endif

instance NFData i => NFData (ListR n i) where
  rnf ZR = ()
  rnf (x ::: l) = rnf x `seq` rnf l

-- | Result of 'listrUncons': the tail and the head of a non-empty list,
-- together with evidence that the original rank @n1@ equals @n + 1@.
data UnconsListRRes i n1 =
  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i

-- | Split a non-empty list into its tail and head; 'Nothing' for 'ZR'.
-- Used by the view patterns of the index and shape pattern synonyms.
listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
listrUncons ZR = Nothing

-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqRank ZR ZR = Just Refl
listrEqRank (_ ::: sh) (_ ::: sh')
  | Just Refl <- listrEqRank sh sh'
  = Just Refl
listrEqRank _ _ = Nothing

-- | This compares the lists for value equality.
listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
listrEqual ZR ZR = Just Refl
listrEqual (i ::: sh) (j ::: sh')
  | Just Refl <- listrEqual sh sh'
  , i == j
  = Just Refl
listrEqual _ _ = Nothing

-- | Render the list as comma-separated elements between square brackets,
-- using the given printer for each element.
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow f l = showString "[" . go "" l . showString "]"
  where
    go :: String -> ListR n' i -> ShowS
    go _ ZR = id
    go prefix (x ::: xs) = showString prefix . f x . go "," xs

-- | Number of elements, counted at runtime via 'length'.
listrLength :: ListR n i -> Int
listrLength = length

-- | Recover the type-level rank as a singleton by walking the list.
listrRank :: ListR n i -> SNat n
listrRank ZR = SNat
listrRank (_ ::: sh) = snatSucc (listrRank sh)

-- | Concatenate two lists; their ranks add up.
listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
listrAppend ZR sh = sh
listrAppend (x ::: xs) sh = x ::: listrAppend xs sh

-- | Convert an ordinary list to a 'ListR' whose rank is not known
-- statically. The result is passed to a continuation that must work for
-- any rank.
listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
listrFromList [] k = k ZR
listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)

-- | First element of a non-empty list.
listrHead :: ListR (n + 1) i -> i
listrHead (i ::: _) = i

-- | All elements but the first.
listrTail :: ListR (n + 1) i -> ListR n i
listrTail (_ ::: sh) = sh

-- | All elements but the last.
listrInit :: ListR (n + 1) i -> ListR n i
listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh
listrInit (_ ::: ZR) = ZR

-- | Last element of a non-empty list.
listrLast :: ListR (n + 1) i -> i
listrLast (_ ::: sh@(_ ::: _)) = listrLast sh
listrLast (n ::: ZR) = n

-- | Performs a runtime check that the lengths are identical.
listrCast :: SNat n' -> ListR n i -> ListR n' i
listrCast = listrCastWithName "listrCast"

-- | Look up the element at the 0-based type-level position @k@.
--
-- The @k + 1 <= n@ constraint rules out out-of-range positions, so the
-- final equation cannot be reached by a well-typed call.
listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
listrIndex SZ (x ::: _) = x
listrIndex (SS i) (_ ::: xs)
  | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n)
  = listrIndex i xs
listrIndex _ ZR = error "k + 1 <= 0"

-- | Pair up the elements of two same-rank lists position-wise.
--
-- The catch-all equation is never reached because both lists have rank
-- @n@; it exists only to satisfy the pattern-match coverage checker.
listrZip :: ListR n i -> ListR n j -> ListR n (i, j)
listrZip ZR ZR = ZR
listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest
listrZip _ _ = error "listrZip: impossible pattern needlessly required"

-- | Combine two same-rank lists position-wise with the given function.
listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k
listrZipWith _ ZR ZR = ZR
listrZipWith f (i ::: irest) (j ::: jrest) =
  f i j ::: listrZipWith f irest jrest
listrZipWith _ _ _ =
  error "listrZipWith: impossible pattern needlessly required"

-- | Permute a prefix of the list.
--
-- The permutation is a list of positions. If it has length @k@, the first
-- @k@ elements of the input (call them @prefix@) are replaced by
-- @[prefix !! p | p <- perm]@. The remaining elements are kept unchanged.
--
-- Calls 'error' in two cases:
--
-- * the permutation is longer than the list;
-- * an entry lies outside @[0, k)@, including negative entries.
--
-- Entries are not checked for being distinct.
listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
listrPermutePrefix = \perm sh ->
  listrFromList perm $ \sperm ->
    case (listrRank sperm, listrRank sh) of
      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
        -- LTI and EQI are identical at the term level. They are kept
        -- separate because each constructor brings its own type-level
        -- evidence into scope for the @m <= n'@ constraint of listrSplitAt.
        LTI ->
          let (pre, post) = listrSplitAt permlen sh
          in listrAppend (applyPermRFull permlen sperm pre) post
        EQI ->
          let (pre, post) = listrSplitAt permlen sh
          in listrAppend (applyPermRFull permlen sperm pre) post
        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
  where
    -- Split off the first @m@ elements.
    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
    listrSplitAt SZ sh = (ZR, sh)
    listrSplitAt (SS m) (n ::: sh) =
      (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
    listrSplitAt SS{} ZR = error "m' + 1 <= 0"

    -- Gather: for every position in the permutation, pick that element of
    -- the length-@m@ prefix.
    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
    applyPermRFull _ ZR _ = ZR
    applyPermRFull sm@SNat (i ::: perm) l
      -- Reject negative entries first. Otherwise 'fromIntegral' to
      -- 'Natural' throws an arithmetic underflow exception instead of the
      -- out-of-range error below.
      | i < 0 = error "listrPermutePrefix: Index in permutation out of range"
      | otherwise =
          TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
            case cmpNat (SNat @(idx + 1)) sm of
              LTI -> listrIndex si l ::: applyPermRFull sm perm l
              EQI -> listrIndex si l ::: applyPermRFull sm perm l
              GTI -> error "listrPermutePrefix: Index in permutation out of range"


-- * Ranked indices

-- | An index into a rank-typed array.
type role IxR nominal representational
type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
  deriving (Eq, Ord, Generic)
  deriving newtype (Functor, Foldable)

pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR

pattern (:.:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> IxR n i -> IxR n1 i
pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
  where i :.: IxR sh = IxR (i ::: sh)
infixr 3 :.:

{-# COMPLETE ZIR, (:.:) #-}

-- For convenience, this contains regular 'Int's instead of bounded integers
-- (traditionally called \"@Fin@\").
type IIxR n = IxR n Int

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (IxR n i)
#else
-- Renders the index like an ordinary list, e.g. @[1,2,3]@.
instance Show i => Show (IxR n i) where
  showsPrec _ (IxR elems) = listrShow shows elems
#endif

-- No methods: relies on the 'Generic'-based default 'rnf'.
instance NFData i => NFData (IxR n i)

-- | Number of components, counted at runtime.
ixrLength :: IxR n i -> Int
ixrLength (IxR elems) = listrLength elems

-- | The rank of the index as a singleton.
ixrRank :: IxR n i -> SNat n
ixrRank (IxR elems) = listrRank elems

-- | The index of the given rank whose components are all zero.
ixrZero :: SNat n -> IIxR n
ixrZero SZ = ZIR
ixrZero (SS k) = 0 :.: ixrZero k

-- | First component.
ixrHead :: IxR (n + 1) i -> i
ixrHead (IxR elems) = listrHead elems

-- | All components but the first.
ixrTail :: IxR (n + 1) i -> IxR n i
ixrTail (IxR elems) = IxR (listrTail elems)

-- | All components but the last.
ixrInit :: IxR (n + 1) i -> IxR n i
ixrInit (IxR elems) = IxR (listrInit elems)

-- | Last component.
ixrLast :: IxR (n + 1) i -> i
ixrLast (IxR elems) = listrLast elems

-- | Performs a runtime check that the lengths are identical.
ixrCast :: SNat n' -> IxR n i -> IxR n' i
ixrCast target (IxR elems) = IxR (listrCastWithName "ixrCast" target elems)

-- | Concatenate two indices; their ranks add up.
ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i
ixrAppend = coerce (listrAppend @_ @i)

-- | Pair up components position-wise.
ixrZip :: IxR n i -> IxR n j -> IxR n (i, j)
ixrZip (IxR xs) (IxR ys) = IxR (listrZip xs ys)

-- | Combine components position-wise with the given function.
ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k
ixrZipWith combine (IxR xs) (IxR ys) = IxR (listrZipWith combine xs ys)

-- | Index version of 'listrPermutePrefix'.
ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
ixrPermutePrefix = coerce (listrPermutePrefix @i)


-- * Ranked shapes

-- | The shape of a rank-typed array: one extent per dimension.
type role ShR nominal representational
type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
  deriving (Eq, Ord, Generic)
  deriving newtype (Functor, Foldable)

pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR

pattern (:$:) :: forall {n1} {i}. forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i
pattern x :$: rest <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> rest) x))
  where x :$: ShR rest = ShR (x ::: rest)
infixr 3 :$:

{-# COMPLETE ZSR, (:$:) #-}

type IShR n = ShR n Int

#ifdef OXAR_DEFAULT_SHOW_INSTANCES
deriving instance Show i => Show (ShR n i)
#else
-- Renders the shape like an ordinary list, e.g. @[2,3]@.
instance Show i => Show (ShR n i) where
  showsPrec _ (ShR dims) = listrShow shows dims
#endif

-- No methods: relies on the 'Generic'-based default 'rnf'.
instance NFData i => NFData (ShR n i)

-- | This checks only whether the ranks are equal, not whether the actual
-- values are.
shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqRank (ShR xs) (ShR ys) = listrEqRank xs ys

-- | This compares the shapes for value equality.
shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
shrEqual (ShR xs) (ShR ys) = listrEqual xs ys

-- | Number of dimensions, counted at runtime.
shrLength :: ShR n i -> Int
shrLength (ShR dims) = listrLength dims

-- | This function can also be used to conjure up a 'KnownNat' dictionary;
-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern
-- synonym yields 'KnownNat' evidence.
shrRank :: ShR n i -> SNat n
shrRank (ShR dims) = listrRank dims

-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
shrSize sh = case sh of
  ZSR -> 1
  extent :$: rest -> extent * shrSize rest

-- | Extent of the first dimension.
shrHead :: ShR (n + 1) i -> i
shrHead (ShR dims) = listrHead dims

-- | All dimensions but the first.
shrTail :: ShR (n + 1) i -> ShR n i
shrTail (ShR dims) = ShR (listrTail dims)

-- | All dimensions but the last.
shrInit :: ShR (n + 1) i -> ShR n i
shrInit (ShR dims) = ShR (listrInit dims)

-- | Extent of the last dimension.
shrLast :: ShR (n + 1) i -> i
shrLast (ShR dims) = listrLast dims

-- | Performs a runtime check that the lengths are identical.
shrCast :: SNat n' -> ShR n i -> ShR n' i
shrCast target (ShR dims) = ShR (listrCastWithName "shrCast" target dims)

-- | Concatenate two shapes; their ranks add up.
shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i
shrAppend = coerce (listrAppend @_ @i)

-- | Pair up extents position-wise.
shrZip :: ShR n i -> ShR n j -> ShR n (i, j)
shrZip (ShR xs) (ShR ys) = ShR (listrZip xs ys)

-- | Combine extents position-wise with the given function.
shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k
shrZipWith combine (ShR xs) (ShR ys) = ShR (listrZipWith combine xs ys)

-- | Shape version of 'listrPermutePrefix'.
shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
shrPermutePrefix = coerce (listrPermutePrefix @i)

-- | All valid indices into an array of the given shape, in row-major
-- order (the last component varies fastest).
shrEnum :: IShR n -> [IIxR n]
shrEnum = shrEnum'

{-# INLINABLE shrEnum' #-}  -- ensure this can be specialised at use site
-- | Like 'shrEnum', but the index components are converted to any 'Num'.
shrEnum' :: Num i => IShR n -> [IxR n i]
shrEnum' sh = [unlinearise sh strides lin# | I# lin# <- [0 .. shrSize sh - 1]]
  where
    -- Element d is the product of all extents after dimension d.
    strides = drop 1 (scanr (*) 1 (Foldable.toList sh))

    -- Turn a linear offset into a multi-dimensional index.
    unlinearise :: Num j => IShR m -> [Int] -> Int# -> IxR m j
    unlinearise ZSR _ _ = ZIR
    unlinearise (_ :$: rest) (I# stride# : strides') i# =
      let !(# q#, r# #) = i# `quotRemInt#` stride#  -- stride == shrSize rest
      in fromIntegral (I# q#) :.: unlinearise rest strides' r#
    unlinearise _ _ _ = error "impossible"

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ListR n i) where
  type Item (ListR n i) = i
  fromList input = build (SNat @n) input
    where
      build :: SNat n' -> [i] -> ListR n' i
      build SZ [] = ZR
      build (SS k) (x : xs) = x ::: build k xs
      build _ _ = error $
        "IsList(ListR): Mismatched list length (type says "
          ++ show (fromSNat (SNat @n))
          ++ ", list has length "
          ++ show (length input)
          ++ ")"
  toList = Foldable.toList

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (IxR n i) where
  type Item (IxR n i) = i
  fromList xs = IxR (IsList.fromList xs)
  toList = Foldable.toList

-- | Untyped: length is checked at runtime.
instance KnownNat n => IsList (ShR n i) where
  type Item (ShR n i) = i
  fromList xs = ShR (IsList.fromList xs)
  toList = Foldable.toList


-- * Internal helper functions

-- | Re-index a list at the rank given by the 'SNat', walking both in
-- lockstep. If the list length differs from the requested rank, calls
-- 'error' with a message prefixed by the given caller name.
listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i
listrCastWithName _ SZ ZR = ZR
listrCastWithName caller (SS k) (x ::: rest) = x ::: listrCastWithName caller k rest
listrCastWithName caller _ _ = error $ caller ++ ": ranks don't match"