diff options
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs (renamed from src/Data/Array/Nested/Internal/Convert.hs) | 12 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Lemmas.hs | 4 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs (renamed from src/Data/Array/Nested/Internal/Mixed.hs) | 10 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 611 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs (renamed from src/Data/Array/Nested/Internal/Ranked.hs) | 12 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 363 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs (renamed from src/Data/Array/Nested/Internal/Shaped.hs) | 12 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs (renamed from src/Data/Array/Nested/Internal/Shape.hs) | 325 |
8 files changed, 1001 insertions, 348 deletions
diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Convert.hs index c316161..639f5fd 100644 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ b/src/Data/Array/Nested/Convert.hs @@ -5,20 +5,20 @@ {-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Convert where +module Data.Array.Nested.Convert where import Control.Category import Data.Proxy import Data.Type.Equality import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shape -import Data.Array.Nested.Internal.Shaped +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shaped +import Data.Array.Nested.Shaped.Shape stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs index f894f78..f4bad70 100644 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ b/src/Data/Array/Nested/Internal/Lemmas.hs @@ -11,9 +11,9 @@ import GHC.TypeLits import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index a2f9737..ec19c21 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -16,7 +16,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -module Data.Array.Nested.Internal.Mixed where +module Data.Array.Nested.Mixed where import Prelude hiding (mconcat) @@ -42,13 +42,13 @@ import GHC.Generics (Generic) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Mixed.Internal.Arith +import Data.Array.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X +import Data.Array.Nested.Mixed.Shape import Data.Bag diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs new file mode 100644 index 0000000..5f4775c --- /dev/null +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -0,0 +1,611 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Bifunctor (first) +import Data.Coerce +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Functor.Product +import Data.Kind (Constraint, Type) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Exts (withDict) +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Mixed.Types + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +instance (forall n. Show (f n)) => Show (ListX sh f) where + showsPrec _ = listxShow shows + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where + rnf ZX = () + rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqType ZX ZX = Just Refl +listxEqType (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , Just Refl <- listxEqType sh sh' + = Just Refl +listxEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqual ZX ZX = Just Refl +listxEqual (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listxEqual sh sh' + = Just Refl +listxEqual _ _ = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxRank :: ListX sh f -> SNat (Rank sh) +listxRank ZX = SNat +listxRank (_ ::% l) | SNat <- listxRank l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' f -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) +listxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go ZKX [] = ZX + go (_ :!% sh) (i : is) = Const i ::% go sh is + go _ _ = error $ "listxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxHead :: ListX (mn ': sh) f -> f mn +listxHead (i ::% _) = i + +listxTail :: ListX (n : sh) i -> ListX sh i +listxTail (_ ::% sh) = sh + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX (sh ++ sh') f -> ListX sh g -> ListX sh' f +listxDrop long ZX = long +listxDrop long (_ ::% short) = case long of _ ::% long' -> listxDrop long' short + +listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f +listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh +listxInit (_ ::% ZX) = ZX + +listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) +listxLast (_ ::% sh@(_ ::% _)) = listxLast sh +listxLast (x ::% ZX) = x + +listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) +listxZip ZX ZX = ZX +listxZip (i ::% irest) (j ::% jrest) = + Pair i j ::% listxZip irest jrest + +listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g + -> ListX sh h +listxZipWith _ ZX ZX = ZX +listxZipWith f (i ::% is) (j ::% js) = + f i j ::% listxZipWith f is js + + +-- * Mixed indices + +-- | This is a newtype over 'ListX'. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +type IIxX sh = IxX sh Int + +instance Show i => Show (IxX sh i) where + showsPrec _ (IxX l) = listxShow (\(Const i) -> shows i) l + +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxLength :: IxX sh i -> Int +ixxLength (IxX l) = listxLength l + +ixxRank :: IxX sh i -> SNat (Rank sh) +ixxRank (IxX l) = listxRank l + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i +ixxFromList = coerce (listxFromList @_ @i) + +ixxHead :: IxX (n : sh) i -> i +ixxHead (IxX list) = getConst (listxHead list) + +ixxTail :: IxX (n : sh) i -> IxX sh i +ixxTail (IxX list) = IxX (listxTail list) + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX (sh ++ sh') i -> IxX sh i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i +ixxInit = coerce (listxInit @(Const i)) + +ixxLast :: forall n sh i. IxX (n : sh) i -> i +ixxLast = coerce (listxLast @(Const i)) + +ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) +ixxZip ZIX ZIX = ZIX +ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js + +ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k +ixxZipWith _ ZIX ZIX = ZIX +ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$% sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$% sh) (i :.% ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +instance TestEquality f => TestEquality (SMayNat i f) where + testEquality SUnknown{} SUnknown{} = Just Refl + testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl + testEquality _ _ = Nothing + +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => f m -> r) + -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +instance Show i => Show (ShX sh i) where + showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where + rnf (ShX ZX) = () + rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +-- | This checks only whether the types are equal; unknown dimensions might +-- still differ. This corresponds to 'testEquality', except on the penultimate +-- type parameter. +shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqType ZSX ZSX = Just Refl +shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') + | Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType _ _ = Nothing + +-- | This checks whether all dimensions have the same value. This is more than +-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the +-- @some@ package (except on the penultimate type parameter). +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') + | i == j + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual _ _ = Nothing + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +shxRank :: ShX sh i -> SNat (Rank sh) +shxRank (ShX l) = listxRank l + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxFromList :: StaticShX sh -> [Int] -> ShX sh Int +shxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [Int] -> ShX sh' Int + go ZKX [] = ZSX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn :$% go sh is + | otherwise = error $ "shxFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is + go _ _ = error $ "shxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead (ShX list) = listxHead list + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (ShX list) = ShX (listxTail list) + +shxDropSSX :: forall sh sh' i. ShX (sh ++ sh') i -> StaticShX sh -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. ShX (sh ++ sh') i -> IxX sh j -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX (sh ++ sh') i -> ShX sh i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i +shxInit = coerce (listxInit @(SMayNat i SNat)) + +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i. Proxy sh' -> ShX (sh ++ sh') i -> StaticShX sh -> ShX sh i +shxTakeSSX _ = flip go + where + go :: StaticShX sh1 -> ShX (sh1 ++ sh') i -> ShX sh1 i + go ZKX _ = ZSX + go (_ :!% ssh1) (n :$% sh) = n :$% go ssh1 sh + +shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) + -> ShX sh i -> ShX sh j -> ShX sh k +shxZipWith _ ZSX ZSX = ZSX +shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: Proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] + where + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + +shxCast :: IShX sh -> StaticShX sh' -> Maybe (IShX sh') +shxCast ZSX ZKX = Just ZSX +shxCast (SKnown n :$% sh) (SKnown m :!% ssh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast sh ssh +shxCast (SUnknown n :$% sh) (SKnown m :!% ssh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast sh ssh +shxCast (SKnown n :$% sh) (SUnknown () :!% ssh) = (SUnknown (fromSNat' n) :$%) <$> shxCast sh ssh +shxCast (SUnknown n :$% sh) (SUnknown () :!% ssh) = (SUnknown n :$%) <$> shxCast sh ssh +shxCast _ _ = Nothing + +-- | Partial version of 'shxCast'. +shxCast' :: IShX sh -> StaticShX sh' -> IShX sh' +shxCast' sh ssh = case shxCast sh ssh of + Just sh' -> sh' + Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +instance Show (StaticShX sh) where + showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l + +instance NFData (StaticShX sh) where + rnf (StaticShX ZX) = () + rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) + rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + +instance TestEquality StaticShX where + testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +ssxRank :: StaticShX sh -> SNat (Rank sh) +ssxRank (StaticShX l) = listxRank l + +-- | @ssxEqType = 'testEquality'@. Provided for consistency. +ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxEqType = testEquality + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead (StaticShX list) = listxHead list + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropIx :: forall sh sh' i. StaticShX (sh ++ sh') -> IxX sh i -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) +ssxInit = coerce (listxInit @(SMayNat () SNat)) + +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat () SNat)) + +-- | This may fail if @sh@ has @Nothing@s in it. +ssxToShX' :: StaticShX sh -> Maybe (IShX sh) +ssxToShX' ZKX = Just ZSX +ssxToShX' (SKnown n :!% sh) = (SKnown n :$%) <$> ssxToShX' sh +ssxToShX' (SUnknown _ :!% _) = Nothing + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKX = [] +ssxIotaFrom i (_ :!% ssh) = i : ssxIotaFrom (i+1) ssh + +ssxFromShape :: IShX sh -> StaticShX sh +ssxFromShape ZSX = ZKX +ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh + +ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) +ssxFromSNat SZ = ZKX +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + +withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r +withKnownShX k = withDict @(KnownShX sh) k + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) + where + go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go acc ZKX = SKnown acc + go _ (SUnknown () :!% _) = SUnknown () + go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where + type Item (ListX sh (Const i)) = i + fromList = listxFromList (knownShX @sh) + toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where + type Item (IxX sh i) = i + fromList = IxX . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList = shxFromList (knownShX @sh) + toList = shxToList diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Ranked.hs index daf0374..e2074ac 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Ranked.hs @@ -17,7 +17,7 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Ranked where +module Data.Array.Nested.Ranked where import Prelude hiding (mappend, mconcat) @@ -40,12 +40,12 @@ import GHC.TypeNats qualified as TN import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape import Data.Array.Strided.Arith diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs new file mode 100644 index 0000000..1c0b9eb --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -0,0 +1,363 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Mixed.Types +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Kind (Type) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Mixed.Lemmas +import Data.Array.Nested.Mixed.Shape + + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +instance Show i => Show (ListR n i) where + showsPrec _ = listrShow shows + +instance NFData i => NFData (ListR n i) where + rnf ZR = () + rnf (x ::: l) = rnf x `seq` rnf l + +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) +listrUncons ZR = Nothing + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqRank ZR ZR = Just Refl +listrEqRank (_ ::: sh) (_ ::: sh') + | Just Refl <- listrEqRank sh sh' + = Just Refl +listrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqual ZR ZR = Just Refl +listrEqual (i ::: sh) (j ::: sh') + | Just Refl <- listrEqual sh sh' + , i == j + = Just Refl +listrEqual _ _ = Nothing + +listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS +listrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListR n' i -> ShowS + go _ ZR = id + go prefix (x ::: xs) = showString prefix . f x . go "," xs + +listrLength :: ListR n i -> Int +listrLength = length + +listrRank :: ListR n i -> SNat n +listrRank ZR = SNat +listrRank (_ ::: sh) = snatSucc (listrRank sh) + +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrHead :: ListR (n + 1) i -> i +listrHead (i ::: _) = i +listrHead ZR = error "unreachable" + +listrTail :: ListR (n + 1) i -> ListR n i +listrTail (_ ::: sh) = sh +listrTail ZR = error "unreachable" + +listrInit :: ListR (n + 1) i -> ListR n i +listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh +listrInit (_ ::: ZR) = ZR +listrInit ZR = error "unreachable" + +listrLast :: ListR (n + 1) i -> i +listrLast (_ ::: sh@(_ ::: _)) = listrLast sh +listrLast (n ::: ZR) = n +listrLast ZR = error "unreachable" + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + +listrZip :: ListR n i -> ListR n j -> ListR n (i, j) +listrZip ZR ZR = ZR +listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest +listrZip _ _ = error "listrZip: impossible pattern needlessly required" + +listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k +listrZipWith _ ZR ZR = ZR +listrZipWith f (i ::: irest) (j ::: jrest) = + f i j ::: listrZipWith f irest jrest +listrZipWith _ _ _ = + error "listrZipWith: impossible pattern needlessly required" + +listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrPermutePrefix = \perm sh -> + listrFromList perm $ \sperm -> + case (listrRank sperm, listrRank sh) of + (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + where + listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) + listrSplitAt SZ sh = (ZR, sh) + listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) + listrSplitAt SS{} ZR = error "m' + 1 <= 0" + + applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i + applyPermRFull _ ZR _ = ZR + applyPermRFull sm@SNat (i ::: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> listrIndex si l ::: applyPermRFull sm perm l + EQI -> listrIndex si l ::: applyPermRFull sm perm l + GTI -> error "listrPermutePrefix: Index in permutation out of range" + + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) + where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +type IIxR n = IxR n Int + +instance Show i => Show (IxR n i) where + showsPrec _ (IxR l) = listrShow shows l + +instance NFData i => NFData (IxR sh i) + +ixrLength :: IxR sh i -> Int +ixrLength (IxR l) = listrLength l + +ixrRank :: IxR n i -> SNat n +ixrRank (IxR sh) = listrRank sh + +ixrZero :: SNat n -> IIxR n +ixrZero SZ = ZIR +ixrZero (SS n) = 0 :.: ixrZero n + +ixCvtXR :: IIxX sh -> IIxR (Rank sh) +ixCvtXR ZIX = ZIR +ixCvtXR (n :.% idx) = n :.: ixCvtXR idx + +ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) +ixCvtRX ZIR = ZIX +ixCvtRX (n :.: (idx :: IxR m Int)) = + castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixCvtRX idx) + +ixrHead :: IxR (n + 1) i -> i +ixrHead (IxR list) = listrHead list + +ixrTail :: IxR (n + 1) i -> IxR n i +ixrTail (IxR list) = IxR (listrTail list) + +ixrInit :: IxR (n + 1) i -> IxR n i +ixrInit (IxR list) = IxR (listrInit list) + +ixrLast :: IxR (n + 1) i -> i +ixrLast (IxR list) = listrLast list + +ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i +ixrAppend = coerce (listrAppend @_ @i) + +ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) +ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 + +ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k +ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 + +ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix = coerce (listrPermutePrefix @i) + + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) + where i :$: ShR sh = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + +type IShR n = ShR n Int + +instance Show i => Show (ShR n i) where + showsPrec _ (ShR l) = listrShow shows l + +instance NFData i => NFData (ShR sh i) + +shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n +shCvtXR' ZSX = + castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) + ZSR +shCvtXR' (n :$% (idx :: IShX sh)) + | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = + castWith (subst2 (lem1 @sh Refl)) + (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) + where + lem1 :: forall sh' n' k. + k : sh' :~: Replicate n' Nothing + -> Rank sh' + 1 :~: n' + lem1 Refl = unsafeCoerceRefl + + lem2 :: k : sh :~: Replicate n Nothing + -> sh :~: Replicate (Rank sh) Nothing + lem2 Refl = unsafeCoerceRefl + +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: (idx :: ShR m Int)) = + castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shCvtRX idx) + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' + +-- | This compares the shapes for value equality. +shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' + +shrLength :: ShR sh i -> Int +shrLength (ShR l) = listrLength l + +-- | This function can also be used to conjure up a 'KnownNat' dictionary; +-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern +-- synonym yields 'KnownNat' evidence. +shrRank :: ShR n i -> SNat n +shrRank (ShR sh) = listrRank sh + +-- | The number of elements in an array described by this shape. +shrSize :: IShR n -> Int +shrSize ZSR = 1 +shrSize (n :$: sh) = n * shrSize sh + +shrHead :: ShR (n + 1) i -> i +shrHead (ShR list) = listrHead list + +shrTail :: ShR (n + 1) i -> ShR n i +shrTail (ShR list) = ShR (listrTail list) + +shrInit :: ShR (n + 1) i -> ShR n i +shrInit (ShR list) = ShR (listrInit list) + +shrLast :: ShR (n + 1) i -> i +shrLast (ShR list) = listrLast list + +shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i +shrAppend = coerce (listrAppend @_ @i) + +shrZip :: ShR n i -> ShR n j -> ShR n (i, j) +shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 + +shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k +shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 + +shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i +shrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where + type Item (ListR n i) = i + fromList topl = go (SNat @n) topl + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error $ "IsList(ListR): Mismatched list length (type says " + ++ show (fromSNat (SNat @n)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where + type Item (IxR n i) = i + fromList = IxR . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where + type Item (ShR n i) = i + fromList = ShR . IsList.fromList + toList = Foldable.toList diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Shaped.hs index 372439f..4bccbc4 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Shaped.hs @@ -16,7 +16,7 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shaped where +module Data.Array.Nested.Shaped where import Prelude hiding (mappend, mconcat) @@ -40,13 +40,13 @@ import GHC.TypeLits import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray) -import Data.Array.Mixed.XArray qualified as X +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape import Data.Array.Strided.Arith diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 97b9456..6c43fa7 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -24,7 +24,7 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shape where +module Data.Array.Nested.Shaped.Shape where import Control.DeepSeq (NFData(..)) import Data.Array.Mixed.Types @@ -42,331 +42,10 @@ import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits -import GHC.TypeNats qualified as TN import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape - - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where - showsPrec _ = listrShow shows - -instance NFData i => NFData (ListR n i) where - rnf ZR = () - rnf (x ::: l) = rnf x `seq` rnf l - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqRank ZR ZR = Just Refl -listrEqRank (_ ::: sh) (_ ::: sh') - | Just Refl <- listrEqRank sh sh' - = Just Refl -listrEqRank _ _ = Nothing - --- | This compares the lists for value equality. -listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqual ZR ZR = Just Refl -listrEqual (i ::: sh) (j ::: sh') - | Just Refl <- listrEqual sh sh' - , i == j - = Just Refl -listrEqual _ _ = Nothing - -listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR n' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrLength :: ListR n i -> Int -listrLength = length - -listrRank :: ListR n i -> SNat n -listrRank ZR = SNat -listrRank (_ ::: sh) = snatSucc (listrRank sh) - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrHead :: ListR (n + 1) i -> i -listrHead (i ::: _) = i -listrHead ZR = error "unreachable" - -listrTail :: ListR (n + 1) i -> ListR n i -listrTail (_ ::: sh) = sh -listrTail ZR = error "unreachable" - -listrInit :: ListR (n + 1) i -> ListR n i -listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh -listrInit (_ ::: ZR) = ZR -listrInit ZR = error "unreachable" - -listrLast :: ListR (n + 1) i -> i -listrLast (_ ::: sh@(_ ::: _)) = listrLast sh -listrLast (n ::: ZR) = n -listrLast ZR = error "unreachable" - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrZip :: ListR n i -> ListR n j -> ListR n (i, j) -listrZip ZR ZR = ZR -listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest -listrZip _ _ = error "listrZip: impossible pattern needlessly required" - -listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k -listrZipWith _ ZR ZR = ZR -listrZipWith f (i ::: irest) (j ::: jrest) = - f i j ::: listrZipWith f irest jrest -listrZipWith _ _ _ = - error "listrZipWith: impossible pattern needlessly required" - -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrRank sperm, listrRank sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "listrPermutePrefix: Index in permutation out of range" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) - where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = listrShow shows l - -instance NFData i => NFData (IxR sh i) - -ixrLength :: IxR sh i -> Int -ixrLength (IxR l) = listrLength l - -ixrRank :: IxR n i -> SNat n -ixrRank (IxR sh) = listrRank sh - -ixrZero :: SNat n -> IIxR n -ixrZero SZ = ZIR -ixrZero (SS n) = 0 :.: ixrZero n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) - -ixrHead :: IxR (n + 1) i -> i -ixrHead (IxR list) = listrHead list - -ixrTail :: IxR (n + 1) i -> IxR n i -ixrTail (IxR list) = IxR (listrTail list) - -ixrInit :: IxR (n + 1) i -> IxR n i -ixrInit (IxR list) = IxR (listrInit list) - -ixrLast :: IxR (n + 1) i -> i -ixrLast (IxR list) = listrLast list - -ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @_ @i) - -ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) -ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 - -ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k -ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 - -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l - -instance NFData i => NFData (ShR sh i) - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerceRefl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' - --- | This compares the shapes for value equality. -shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' - -shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l - --- | This function can also be used to conjure up a 'KnownNat' dictionary; --- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern --- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh - --- | The number of elements in an array described by this shape. -shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh - -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list - -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) - -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) - -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list - -shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 - -shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 - -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList topl = go (SNat @n) topl - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error $ "IsList(ListR): Mismatched list length (type says " - ++ show (fromSNat (SNat @n)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where - type Item (IxR n i) = i - fromList = IxR . IsList.fromList - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +import Data.Array.Nested.Mixed.Shape type role ListS nominal representational |