diff options
Diffstat (limited to 'src/Data/Array/Nested')
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 333 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Convert.hs | 86 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Lemmas.hs | 59 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 559 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 495 | ||||
-rw-r--r-- | src/Data/Array/Nested/Lemmas.hs | 162 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs (renamed from src/Data/Array/Nested/Internal/Mixed.hs) | 185 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 644 | ||||
-rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 283 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 323 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 268 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 369 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 272 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 255 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs (renamed from src/Data/Array/Nested/Internal/Shape.hs) | 402 | ||||
-rw-r--r-- | src/Data/Array/Nested/Trace.hs | 8 | ||||
-rw-r--r-- | src/Data/Array/Nested/Types.hs | 152 |
17 files changed, 3194 insertions, 1661 deletions
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..2438f68 --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,333 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +module Data.Array.Nested.Convert ( + -- * Shape\/index\/list casting functions + -- ** To ranked + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + listrCast, ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed + ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', + + -- * Array conversions + convert, + Conversion(..), + + -- * Special cases of array conversions + -- + -- | These functions can all be implemented using 'convert' in some way, + -- but some have fewer constraints. + rtoMixed, rcastToMixed, rcastToShaped, + stoMixed, scastToMixed, stoRanked, + mcast, mcastToShaped, mtoRanked, +) where + +import Control.Category +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped.Base +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + +-- * Shape or index or list casting functions + +-- * To ranked + +ixrFromIxS :: IxS sh i -> IxR (Rank sh) i +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported +-- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported + +-- * To shaped + +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR _ _ = error "unreachable" + +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + +-- | Produce an existential 'ShS' from an 'IShR'. +withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r +withShsFromShR ZSR k = k ZSS +withShsFromShR (n :$: sh) k = + withShsFromShR sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" + +-- shsFromShX re-exported + +-- | Produce an existential 'ShS' from an 'IShX'. If you already know that +-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. +withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r +withShsFromShX ZSX k = k ZSS +withShsFromShX (SKnown sn@SNat :$% sh) k = + withShsFromShX sh $ \sh' -> + k (sn :$$ sh') +withShsFromShX (SUnknown n :$% sh) k = + withShsFromShX sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" + +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh + +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + + +-- * Array conversions + +-- | The constructors that perform runtime shape checking are marked with a +-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types +-- ensure that the shapes are already compatible. To convert between 'Ranked' +-- and 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Conversion' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'XArray's. This leads to the inclusion of some operations that do +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +-- +-- /Note/: Haddock gleefully renames type variables in constructors so that +-- they match the data type head as much as possible. See the source for a more +-- readable presentation of this data type. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Conversion (Mixed sh a) (Shaped sh' a) + + ConvXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the conversion happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) + x)) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go ConvX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go ConvZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go ConvUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) + + lemRankAppRankEq :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ sh') + lemRankAppRankEq _ _ _ = unsafeCoerceRefl + + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl + + +-- * Special cases of array conversions + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked = convert ConvXR + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank rarr) + = mcast sshx arr + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => ShS sh' -> Mixed sh a -> Shaped sh' a +mcastToShaped targetsh = convert (ConvXS' targetsh) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped targetsh arr diff --git a/src/Data/Array/Nested/Internal/Convert.hs b/src/Data/Array/Nested/Internal/Convert.hs deleted file mode 100644 index c316161..0000000 --- a/src/Data/Array/Nested/Internal/Convert.hs +++ /dev/null @@ -1,86 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Convert where - -import Control.Category -import Data.Proxy -import Data.Type.Equality - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Ranked -import Data.Array.Nested.Internal.Shape -import Data.Array.Nested.Internal.Shaped - - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mtoRanked arr - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh - | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) - , Refl <- lemRankMapJust targetsh - = mcastToShaped arr targetsh - --- | The only constructor that performs runtime shape checking is 'CastXS''. --- For the other construtors, the types ensure that the shapes are already --- compatible. To convert between 'Ranked' and 'Shaped', go via 'Mixed'. -data Castable a b where - CastId :: Castable a a - CastCmp :: Castable b c -> Castable a b -> Castable a c - - CastRX :: Castable a b -> Castable (Ranked n a) (Mixed (Replicate n Nothing) b) - CastSX :: Castable a b -> Castable (Shaped sh a) (Mixed (MapJust sh) b) - - CastXR :: Castable a b -> Castable (Mixed sh a) (Ranked (Rank sh) b) - CastXS :: Castable a b -> Castable (Mixed (MapJust sh) a) (Shaped sh b) - CastXS' :: (Rank sh ~ Rank sh', Elt b) => ShS sh' - -> Castable a b -> Castable (Mixed sh a) (Shaped sh' b) - - CastRR :: Castable a b -> Castable (Ranked n a) (Ranked n b) - CastSS :: Castable a b -> Castable (Shaped sh a) (Shaped sh b) - CastXX :: Castable a b -> Castable (Mixed sh a) (Mixed sh b) - -instance Category Castable where - id = CastId - (.) = CastCmp - -castCastable :: (Elt a, Elt b) => Castable a b -> a -> b -castCastable = \c x -> munScalar (go c (mscalar x)) - where - -- The 'esh' is the extension shape: the casting happens under a whole - -- bunch of additional dimensions that it does not touch. These dimensions - -- are 'esh'. - -- The strategy is to unwind step-by-step to a large Mixed array, and to - -- perform the required checks and castings when re-nesting back up. - go :: Castable a b -> Mixed esh a -> Mixed esh b - go CastId x = x - go (CastCmp c1 c2) x = go c1 (go c2 x) - go (CastRX c) (M_Ranked (M_Nest esh x)) = M_Nest esh (go c x) - go (CastSX c) (M_Shaped (M_Nest esh x)) = M_Nest esh (go c x) - go (CastXR @_ @_ @sh c) (M_Nest @esh esh x) = - M_Ranked (M_Nest esh (mcastSafe @(MCastApp esh sh esh (Replicate (Rank sh) Nothing) MCastId MCastForget) Proxy - (go c x))) - go (CastXS c) (M_Nest esh x) = M_Shaped (M_Nest esh (go c x)) - go (CastXS' @sh @sh' sh' c) (M_Nest @esh esh x) - | Refl <- lemRankAppMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') - = M_Shaped (M_Nest esh (mcast (ssxFromShape (shxAppend esh (shCvtSX sh'))) - (go c x))) - go (CastRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) - go (CastSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) - go (CastXX c) (M_Nest esh x) = M_Nest esh (go c x) - - lemRankAppMapJust :: Rank sh ~ Rank sh' - => Proxy esh -> Proxy sh -> Proxy sh' - -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') - lemRankAppMapJust _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Internal/Lemmas.hs b/src/Data/Array/Nested/Internal/Lemmas.hs deleted file mode 100644 index f894f78..0000000 --- a/src/Data/Array/Nested/Internal/Lemmas.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Data.Array.Nested.Internal.Lemmas where - -import Data.Proxy -import Data.Type.Equality -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Nested.Internal.Shape - - -lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemMapJustApp :: ShS sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustApp ZSS _ = Refl -lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl - -lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemTakeLenMapJust PNil _ = Refl -lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl -lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemDropLenMapJust PNil _ = Refl -lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl -lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" - -lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemIndexMapJust SZ (_ :$$ _) = Refl -lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) - | Refl <- lemIndexMapJust i sh - , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) - , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') - = Refl -lemIndexMapJust _ ZSS = error "Index of empty" - -lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemPermuteMapJust PNil _ = Refl -lemPermuteMapJust (i `PCons` is) sh - | Refl <- lemPermuteMapJust is sh - , Refl <- lemIndexMapJust i sh - = Refl - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKX - go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs deleted file mode 100644 index daf0374..0000000 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ /dev/null @@ -1,559 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Ranked where - -import Prelude hiding (mappend, mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST -import Data.Array.RankedS qualified as S -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Foldable (toList) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape -import Data.Array.Strided.Arith - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) - -instance (Show a, Elt a) => Show (Ranked n a) where - showsPrec d arr@(Ranked marr) = - let sh = show (toList (rshape arr)) - in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr - -instance Elt a => NFData (Ranked n a) where - rnf (Ranked arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where - mshape (M_Ranked arr) = mshape arr - mindex (M_Ranked arr) i = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i = - coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - - mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) - mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - - mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoListOuter (M_Ranked arr) = - coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift ssh2 f (M_Ranked arr) = - coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = - coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) - @(NonEmpty (Mixed sh2 (Ranked n a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - - mconcat l = M_Ranked (mconcat (coerce l)) - - mrnf (M_Ranked arr) = mrnf arr - - type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - - mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Ranked arr) = marrayStrides arr - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. - IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where - memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArrayUnsafe i - | Dict <- lemKnownReplicate (SNat @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (SNat @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - -liftRanked1 :: forall n a b. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) - -> Ranked n a -> Ranked n b -liftRanked1 = coerce - -liftRanked2 :: forall n a b c. - (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) - -> Ranked n a -> Ranked n b -> Ranked n c -liftRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where - (+) = liftRanked2 (+) - (-) = liftRanked2 (-) - (*) = liftRanked2 (*) - negate = liftRanked1 negate - abs = liftRanked1 abs - signum = liftRanked1 signum - fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where - fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" - recip = liftRanked1 recip - (/) = liftRanked2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where - pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" - exp = liftRanked1 exp - log = liftRanked1 log - sqrt = liftRanked1 sqrt - (**) = liftRanked2 (**) - logBase = liftRanked2 logBase - sin = liftRanked1 sin - cos = liftRanked1 cos - tan = liftRanked1 tan - asin = liftRanked1 asin - acos = liftRanked1 acos - atan = liftRanked1 atan - sinh = liftRanked1 sinh - cosh = liftRanked1 cosh - tanh = liftRanked1 tanh - asinh = liftRanked1 asinh - acosh = liftRanked1 acosh - atanh = liftRanked1 atanh - log1p = liftRanked1 GHC.Float.log1p - expm1 = liftRanked1 GHC.Float.expm1 - log1pexp = liftRanked1 GHC.Float.log1pexp - log1mexp = liftRanked1 GHC.Float.log1mexp - -rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -rquotArray = liftRanked2 mquotArray -rremArray = liftRanked2 mremArray - -ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a -ratan2Array = liftRanked2 matan2Array - - -remptyArray :: KnownElt a => Ranked 1 a -remptyArray = mtoRanked (memptyArray ZSX) - -rshape :: Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rrank :: Elt a => Ranked n a -> SNat n -rrank = shrRank . rshape - --- | The total number of elements in the array. -rsize :: Elt a => Ranked n a -> Int -rsize = shrSize . rshape - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) - (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f - | sn@SNat <- shrRank sh - , Dict <- lemKnownReplicate sn - , Refl <- lemRankReplicate sn - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. Elt a - => SNat n2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) - --- | See the documentation of 'mlift2'. -rlift2 :: forall n1 n2 n3 a. Elt a - => SNat n3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a -rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) - -rsumOuter1P :: forall n a. - (Storable a, NumElt a) - => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (msumOuter1P arr) - -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) - => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive - -rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a -rsumAllPrim (Ranked arr) = msumAllPrim arr - -rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a -rtranspose perm arr - | sn@SNat <- rrank arr - , Dict <- lemKnownReplicate sn - , length perm <= fromIntegral (natVal (Proxy @n)) - = rlift sn - (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) - arr - | otherwise - = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" - -rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a -rconcat - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce mconcat - -rappend :: forall n a. Elt a - => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a -rappend arr1 arr2 - | sn@SNat <- rrank arr1 - , Dict <- lemKnownReplicate sn - , Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) - arr1 arr2 - -rscalar :: Elt a => a -> Ranked 0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v - | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = rfromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: PrimElt a => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromListOuter l - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) - -rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) - -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) - -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) - -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) - -rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a -rfromListLinear sh l = rreshape sh (rfromList1 l) - -rtoListLinear :: Elt a => Ranked n a -> [a] -rtoListLinear (Ranked arr) = mtoListLinear arr - -rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a -rfromOrthotope sn arr - | Refl <- lemRankReplicate sn - = let xarr = XArray arr - in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) - -rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a -rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) - | Refl <- lemRankReplicate (shrRank $ shCvtXR' sh) - = arr - -runScalar :: Elt a => Ranked 0 a -> a -runScalar arr = rindex arr ZIR - -rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) -rnest n arr - | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) - = coerce (mnest (ssxFromSNat n) (coerce arr)) - -runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a -runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) - | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked arr - -rzip :: Ranked n a -> Ranked n b -> Ranked n (a, b) -rzip = coerce mzip - -runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) -runzip = coerce munzip - -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) - => SNat n -> IShR n2 - -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) - -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) - | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) - , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) - = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) - (\a -> let Ranked r = f (Ranked a) in r) - arr) - --- | If there is a zero-sized dimension in the @n@-prefix of the shape of the --- input array, then there is no way to deduce the full shape of the output --- array (more precisely, the @n2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the @n2@ part of the output shape with zeros. --- --- For example, if: --- --- @ --- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] --- f :: Ranked 2 Int -> Ranked 3 Float --- @ --- --- then: --- --- @ --- rrerank _ _ _ f arr :: Ranked 5 Float --- @ --- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) - => SNat n -> IShR n2 - -> (Ranked n1 a -> Ranked n2 b) - -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank sn sh2 f (rtoPrimitive -> arr) = - rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr - -rreplicate :: forall n m a. Elt a - => IShR n -> Ranked m a -> Ranked (n + m) a -rreplicate sh (Ranked arr) - | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) - = Ranked (mreplicate (shCvtRX sh) arr) - -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x - | Dict <- lemKnownReplicate (shrRank sh) - = Ranked (mreplicateScalP (shCvtRX sh) x) - -rreplicateScal :: forall n a. PrimElt a - => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) - -rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr - | Refl <- lemReplicateSucc @(Nothing @Nat) @n - = rlift (rrank arr) - (\_ -> X.sliceU i n) - arr - -rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = - rlift (rrank arr) - (\(_ :: StaticShX sh') -> - case lemReplicateSucc @(Nothing @Nat) @n of - Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) - arr - -rreshape :: forall n n' a. Elt a - => IShR n' -> Ranked n a -> Ranked n' a -rreshape sh' rarr@(Ranked arr) - | Dict <- lemKnownReplicate (rrank rarr) - , Dict <- lemKnownReplicate (shrRank sh') - = Ranked (mreshape (shCvtRX sh') arr) - -rflatten :: Elt a => Ranked n a -> Ranked 1 a -rflatten (Ranked arr) = mtoRanked (mflatten arr) - -riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a -riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota - --- | Throws if the array is empty. -rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n -rminIndexPrim rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mminIndexPrim arr) - --- | Throws if the array is empty. -rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n -rmaxIndexPrim rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) - = ixCvtXR (mmaxIndexPrim arr) - -rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a -rdot1Inner arr1 arr2 - | SNat <- rrank arr1 - , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) - = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'rdot1Inner' if applicable. -rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a -rdot = coerce mdot - -rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrimP (Ranked arr) = first shCvtXR' (mtoXArrayPrimP arr) - -rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rtoXArrayPrim (Ranked arr) = first shCvtXR' (mtoXArrayPrim arr) - -rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a -rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) - -rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) -rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr - | Refl <- lemRankReplicate (shxRank (mshape arr)) - = Ranked (mcast (ssxFromShape (convSh (mshape arr))) arr) - where - convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) - convSh ZSX = ZSX - convSh (smn :$% (sh :: IShX sh'T)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) - = SUnknown (fromSMayNat' smn) :$% convSh sh - -rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a -rtoMixed (Ranked arr) = arr - --- | A more weakly-typed version of 'rtoMixed' that does a runtime shape --- compatibility check. -rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a -rcastToMixed sshx rarr@(Ranked arr) - | Refl <- lemRankReplicate (rrank rarr) - = mcast sshx arr diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs deleted file mode 100644 index 372439f..0000000 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ /dev/null @@ -1,495 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE ImportQualifiedPost #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shaped where - -import Prelude hiding (mappend, mconcat) - -import Control.DeepSeq (NFData(..)) -import Control.Monad.ST -import Data.Array.Internal.RankedG qualified as RG -import Data.Array.Internal.RankedS qualified as RS -import Data.Array.Internal.ShapedG qualified as SG -import Data.Array.Internal.ShapedS qualified as SS -import Data.Bifunctor (first) -import Data.Coerce (coerce) -import Data.Kind (Type) -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import Data.Vector.Storable qualified as VS -import Foreign.Storable (Storable) -import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) -import GHC.Generics (Generic) -import GHC.TypeLits - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray) -import Data.Array.Mixed.XArray qualified as X -import Data.Array.Nested.Internal.Lemmas -import Data.Array.Nested.Internal.Mixed -import Data.Array.Nested.Internal.Shape -import Data.Array.Strided.Arith - - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) - -instance (Show a, Elt a) => Show (Shaped n a) where - showsPrec d arr@(Shaped marr) = - let sh = show (shsToList (sshape arr)) - in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr - -instance Elt a => NFData (Shaped sh a) where - rnf (Shaped arr) = rnf arr - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) - deriving (Generic) - -deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) - -instance Elt a => Elt (Shaped sh a) where - mshape (M_Shaped arr) = mshape arr - mindex (M_Shaped arr) i = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - - mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) - mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - - mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoListOuter (M_Shaped arr) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - - mlift :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift ssh2 f (M_Shaped arr) = - coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift ssh2 f arr - - mlift2 :: forall sh1 sh2 sh3. - StaticShX sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = - coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 ssh3 f arr1 arr2 - - mliftL :: forall sh1 sh2. - StaticShX sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) - -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) - mliftL ssh2 f l = - coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) - @(NonEmpty (Mixed sh2 (Shaped sh a))) $ - mliftL ssh2 f (coerce l) - - mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) - - mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - - mconcat l = M_Shaped (mconcat (coerce l)) - - mrnf (M_Shaped arr) = mrnf arr - - type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - marrayStrides (M_Shaped arr) = marrayStrides arr - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs = - mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. - IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs = - mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs = - coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where - memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArrayUnsafe i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArrayUnsafe i - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - -liftShaped1 :: forall sh a b. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) - -> Shaped sh a -> Shaped sh b -liftShaped1 = coerce - -liftShaped2 :: forall sh a b c. - (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) - -> Shaped sh a -> Shaped sh b -> Shaped sh c -liftShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where - (+) = liftShaped2 (+) - (-) = liftShaped2 (-) - (*) = liftShaped2 (*) - negate = liftShaped1 negate - abs = liftShaped1 abs - signum = liftShaped1 signum - fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where - fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" - recip = liftShaped1 recip - (/) = liftShaped2 (/) - -instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where - pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" - exp = liftShaped1 exp - log = liftShaped1 log - sqrt = liftShaped1 sqrt - (**) = liftShaped2 (**) - logBase = liftShaped2 logBase - sin = liftShaped1 sin - cos = liftShaped1 cos - tan = liftShaped1 tan - asin = liftShaped1 asin - acos = liftShaped1 acos - atan = liftShaped1 atan - sinh = liftShaped1 sinh - cosh = liftShaped1 cosh - tanh = liftShaped1 tanh - asinh = liftShaped1 asinh - acosh = liftShaped1 acosh - atanh = liftShaped1 atanh - log1p = liftShaped1 GHC.Float.log1p - expm1 = liftShaped1 GHC.Float.expm1 - log1pexp = liftShaped1 GHC.Float.log1pexp - log1mexp = liftShaped1 GHC.Float.log1mexp - -squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -squotArray = liftShaped2 mquotArray -sremArray = liftShaped2 mremArray - -satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a -satan2Array = liftShaped2 matan2Array - - -semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a -semptyArray sh = Shaped (memptyArray (shCvtSX sh)) - -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - -srank :: Elt a => Shaped sh a -> SNat (Rank sh) -srank = shsRank . sshape - --- | The total number of elements in the array. -ssize :: Elt a => Shaped sh a -> Int -ssize = shsSize . sshape - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - -sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial sarr@(Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) - (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. Elt a - => ShS sh2 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) - --- | See the documentation of 'mlift'. -slift2 :: forall sh1 sh2 sh3 a. Elt a - => ShS sh3 - -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) - -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) - -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive - -ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a -ssumAllPrim (Shaped arr) = msumAllPrim arr - -stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) - => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a -stranspose perm sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - , Refl <- lemTakeLenMapJust perm (sshape sarr) - , Refl <- lemDropLenMapJust perm (sshape sarr) - , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) - , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) - = Shaped (mtranspose perm arr) - -sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) - -sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a -sfromVector sh v = sfromPrimitive (sfromVectorP sh v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: PrimElt a => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) - -sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 - -sfromList1Prim :: PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter - -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l - | Refl <- lemAppNil @'[Just n] - = let ssh = SUnknown () :!% ZKX - xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) - in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr - -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = - let M_Primitive _ xarr = toPrimitive (mfromListPrim l) - in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) - -sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a -sfromListLinear sh l = Shaped (mfromListLinear (shCvtSX sh) l) - -stoListLinear :: Elt a => Shaped sh a -> [a] -stoListLinear (Shaped arr) = mtoListLinear arr - -sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a -sfromOrthotope sh (SS.A (SG.A arr)) = - Shaped (fromPrimitive (M_Primitive (shCvtSX sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) - -stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a -stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) -snest sh arr - | Refl <- lemMapJustApp sh (Proxy @sh') - = coerce (mnest (ssxFromShape (shCvtSX sh)) (coerce arr)) - -sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a -sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) - | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') - = Shaped arr - -szip :: Shaped sh a -> Shaped sh b -> Shaped sh (a, b) -szip = coerce mzip - -sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) -sunzip = coerce munzip - -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) - -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh1) - , Refl <- lemMapJustApp sh (Proxy @sh2) - = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) - (shCvtSX sh2) - (\a -> let Shaped r = f (Shaped a) in r) - arr) - -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) - => ShS sh -> ShS sh2 - -> (Shaped sh1 a -> Shaped sh2 b) - -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = - sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr - -sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a -sreplicate sh (Shaped arr) - | Refl <- lemMapJustApp sh (Proxy @sh') - = Shaped (mreplicate (shCvtSX sh) arr) - -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) - -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) - -sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = - let _ :$$ sh = sshape arr - in slift (n :$$ sh) (\_ -> X.slice i n) arr - -srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr - -sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) - -sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a -sflatten arr = - case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff - n@SNat -> sreshape (n :$$ ZSS) arr - -siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a -siota sn = Shaped (miota sn) - --- | Throws if the array is empty. -sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -sminIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mminIndexPrim arr) - --- | Throws if the array is empty. -smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh -smaxIndexPrim sarr@(Shaped arr) = ixCvtXS (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) - -sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) - => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a -sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) - | Refl <- lemInitApp (Proxy @sh) (Proxy @n) - , Refl <- lemLastApp (Proxy @sh) (Proxy @n) - = case sshape sarr1 of - _ :$$ _ - | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) - -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) - _ -> error "unreachable" - --- | This has a temporary, suboptimal implementation in terms of 'mflatten'. --- Prefer 'sdot1Inner' if applicable. -sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a -sdot = coerce mdot - -stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrimP (Shaped arr) = first shCvtXS' (mtoXArrayPrimP arr) - -stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -stoXArrayPrim (Shaped arr) = first shCvtXS' (mtoXArrayPrim arr) - -sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) - -sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) - -sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a -sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) - -stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) -stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh - | Refl <- lemRankMapJust targetsh - = Shaped (mcast (ssxFromShape (shCvtSX targetsh)) arr) - -stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a -stoMixed (Shaped arr) = arr - --- | A more weakly-typed version of 'stoMixed' that does a runtime shape --- compatibility check. -scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') - => StaticShX sh' -> Shaped sh a -> Mixed sh' a -scastToMixed sshx sarr@(Shaped arr) - | Refl <- lemRankMapJust (sshape sarr) - = mcast sshx arr diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs new file mode 100644 index 0000000..8cac298 --- /dev/null +++ b/src/Data/Array/Nested/Lemmas.hs @@ -0,0 +1,162 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + + +-- * Lemmas about numbers and lists + +-- ** Nat + +lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: True +lemLeqSuccSucc _ _ = unsafeCoerceRefl + +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + +-- ** Append + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +-- ** Simple type families + +lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp sn _ _ = go sn + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS (n :: SNat n'm1)) + | Refl <- lemReplicateSucc @a @n'm1 + , Refl <- go n + = sym (lemReplicateSucc @a @(n'm1 + m)) + +lemDropLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerceRefl + +lemTakeLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerceRefl + +lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l +lemInitApp _ _ = unsafeCoerceRefl + +lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x +lemLastApp _ _ = unsafeCoerceRefl + + +-- ** KnownNat + +lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1) +lemKnownNatSucc = Dict + +lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + + +-- * Lemmas about shapes + +-- ** Known shapes + +lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) +lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) + +lemKnownShX :: StaticShX sh -> Dict KnownShX sh +lemKnownShX ZKX = Dict +lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict +lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index a2f9737..144230e 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE DeriveGeneric #-} @@ -16,7 +17,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} -module Data.Array.Nested.Internal.Mixed where +module Data.Array.Nested.Mixed where import Prelude hiding (mconcat) @@ -28,7 +29,7 @@ import Data.Bifunctor (bimap) import Data.Coerce import Data.Foldable (toList) import Data.Int -import Data.Kind (Constraint, Type) +import Data.Kind (Type) import Data.List.NonEmpty (NonEmpty(..)) import Data.List.NonEmpty qualified as NE import Data.Proxy @@ -40,15 +41,14 @@ import Foreign.Storable (Storable) import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) import GHC.Generics (Generic) import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Types -import Data.Array.Mixed.XArray (XArray(..)) -import Data.Array.Mixed.XArray qualified as X + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X import Data.Bag @@ -140,27 +140,39 @@ data family Mixed sh a -- sizes of the elements of an empty array, which is information that should -- ostensibly not exist; the full array is still empty. +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +#define ANDSHOW , Show +#else +#define ANDSHOW +#endif + data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) - deriving (Eq, Ord, Generic) + deriving (Eq, Ord, Generic ANDSHOW) -- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic) -- no content, orthotope optimises this (via Vector) +newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic ANDSHOW) -- no content, orthotope optimises this (via Vector) -- etc. data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) +#endif -- etc., larger tuples (perhaps use generics to allow arbitrary product types) deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b)) deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b)) data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (Show (Mixed (sh1 ++ sh2) a)) => Show (Mixed sh1 (Mixed sh2 a)) +#endif deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a)) deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a)) @@ -204,10 +216,12 @@ showsMixedArray fromlistPrefix replicatePrefix d arr = _ -> showString fromlistPrefix . showString " " . shows (mtoListLinear arr) +#ifndef OXAR_DEFAULT_SHOW_INSTANCES instance (Show a, Elt a) => Show (Mixed sh a) where showsPrec d arr = let sh = show (shxToList (mshape arr)) in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr +#endif instance Elt a => NFData (Mixed sh a) where rnf = mrnf @@ -366,9 +380,7 @@ class Elt a where -- of this class with those of 'Elt': some instances have an additional -- "known-shape" constraint. -- --- This class is (currently) only required for 'mgenerate', --- 'Data.Array.Nested.Ranked.rgenerate' and --- 'Data.Array.Nested.Shaped.sgenerate'. +-- This class is (currently) only required for `memptyArray` and 'mgenerate'. class Elt a => KnownElt a where -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArrayUnsafe :: IShX sh -> Mixed sh a @@ -384,11 +396,11 @@ class Elt a => KnownElt a where instance Storable a => Elt (Primitive a) where mshape (M_Primitive sh _) = sh mindex (M_Primitive _ a) i = Primitive (X.index a i) - mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) mfromListOuter l@(arr1 :| _) = let sh = SUnknown (length l) :$% mshape arr1 - in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) + in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) mlift :: forall sh1 sh2. @@ -426,17 +438,17 @@ instance Storable a => Elt (Primitive a) where => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' - sh2 = shxCast' sh1 ssh2 - in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) + sh2 = shxCast' ssh2 sh1 + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) mtranspose perm (M_Primitive sh arr) = M_Primitive (shxPermutePrefix perm sh) - (X.transpose (ssxFromShape sh) perm arr) + (X.transpose (ssxFromShX sh) perm arr) mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) mconcat l@(M_Primitive (_ :$% sh) _ :| _) = - let result = X.concat (ssxFromShape sh) (fmap (\(M_Primitive _ arr) -> arr) l) - in M_Primitive (X.shape (SUnknown () :!% ssxFromShape sh) result) result + let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) + in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result mrnf (M_Primitive sh a) = rnf sh `seq` rnf a @@ -453,7 +465,7 @@ instance Storable a => Elt (Primitive a) where :: forall sh' sh s. IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do - let arrsh = X.shape (ssxFromShape sh') arr + let arrsh = X.shape (ssxFromShX sh') arr offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) @@ -540,16 +552,16 @@ instance Elt a => Elt (Mixed sh' a) where -- moverlongShape method, a prefix of which is mshape. mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest sh arr) - = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) + = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a - mindex (M_Nest _ arr) i = mindexPartial arr i + mindex (M_Nest _ arr) = mindexPartial arr mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) mindexPartial (M_Nest sh arr) i | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) mscalar = M_Nest ZSX @@ -569,7 +581,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) in M_Nest sh2 result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b f' sshT @@ -586,7 +598,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) in M_Nest sh3 result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b f' sshT @@ -604,7 +616,7 @@ instance Elt a => Elt (Mixed sh' a) where (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) in fmap (M_Nest sh2) result where - ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) f' sshT @@ -618,15 +630,15 @@ instance Elt a => Elt (Mixed sh' a) where | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T - sh2 = shxCast' sh1 ssh2 + sh2 = shxCast' ssh2 sh1 in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) => Perm is -> Mixed sh (Mixed sh' a) -> Mixed (PermutePrefix is sh) (Mixed sh' a) mtranspose perm (M_Nest sh arr) - | let sh' = shxDropSh @sh @sh' (mshape arr) sh - , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') + | let sh' = shxDropSh @sh @sh' sh (mshape arr) + , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') @@ -637,14 +649,14 @@ instance Elt a => Elt (Mixed sh' a) where mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) mconcat l@(M_Nest sh1 _ :| _) = let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) - in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape result))) result + in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -671,7 +683,7 @@ instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where mvecsUnsafeNew sh example | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh'))) where sh' = mshape example @@ -729,14 +741,14 @@ msumOuter1P :: forall sh n a. (Storable a, NumElt a) => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) msumOuter1P (M_Primitive (n :$% sh) arr) = let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX - in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) + in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) => Mixed (n : sh) a -> Mixed sh a msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a -msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShape sh) arr +msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr mappend :: forall n m sh a. Elt a => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a @@ -744,7 +756,7 @@ mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 where sn :$% sh = mshape arr1 sm :$% _ = mshape arr2 - ssh = ssxFromShape sh + ssh = ssxFromShX sh snm :: SMayNat () SNat (AddMaybe n m) snm = case (sn, sm) of (SUnknown{}, _) -> SUnknown () @@ -770,14 +782,10 @@ mtoVector arr = mtoVectorP (toPrimitive arr) mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = - let ssh = SUnknown () :!% ZKX - xarr = X.fromList1 ssh l - in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter +-- This forall is there so that a simple type application can constrain the +-- shape, in case the user wants to use OverloadedLists for the shape. +mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a +mfromListLinear sh l = mreshape sh (mfromList1 l) mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a mfromListPrim l = @@ -790,10 +798,8 @@ mfromListPrimLinear sh l = let M_Primitive _ xarr = toPrimitive (mfromListPrim l) in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) --- This forall is there so that a simple type application can constrain the --- shape, in case the user wants to use OverloadedLists for the shape. -mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a -mfromListLinear sh l = mreshape sh (mfromList1 l) +mtoList :: Elt a => Mixed '[n] a -> [a] +mtoList = map munScalar . mtoListOuter mtoListLinear :: Elt a => Mixed sh a -> [a] mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise @@ -807,8 +813,11 @@ mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a munNest (M_Nest _ arr) = arr -mzip :: Mixed sh a -> Mixed sh b -> Mixed sh (a, b) -mzip = M_Tup2 +-- | The arguments must have equal shapes. If they do not, an error is raised. +mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b) +mzip a b + | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b + | otherwise = error "mzip: unequal shapes" munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) munzip (M_Tup2 a b) = (a, b) @@ -818,13 +827,13 @@ mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) mrerankP ssh sh2 f (M_Primitive sh arr) = - let sh1 = shxDropSSX sh ssh - in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) - (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) + let sh1 = shxDropSSX ssh sh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2) + (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) arr) --- | See the caveats at @X.rerank@. +-- | See the caveats at 'Data.Array.XArray.rerank'. mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) => StaticShX sh -> IShX sh2 -> (Mixed sh1 a -> Mixed sh2 b) @@ -835,8 +844,8 @@ mrerank ssh sh2 f (toPrimitive -> arr) = mreplicate :: forall sh sh' a. Elt a => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a mreplicate sh arr = - let ssh' = ssxFromShape (mshape arr) - in mlift (ssxAppend (ssxFromShape sh) ssh') + let ssh' = ssxFromShX (mshape arr) + in mlift (ssxAppend (ssxFromShX sh) ssh') (\(sshT :: StaticShX shT) -> case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of Refl -> X.replicate sh (ssxAppend ssh' sshT)) @@ -852,18 +861,18 @@ mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a mslice i n arr = let _ :$% sh = mshape arr - in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr + in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr +msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr +mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a mreshape sh' arr = - mlift (ssxFromShape sh') - (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') + mlift (ssxFromShX sh') + (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh') arr mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a @@ -875,12 +884,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) -- | Throws if the array is empty. mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) + ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr)) -- | Throws if the array is empty. mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) + ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a @@ -890,7 +899,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi = case sh1 of _ :$% _ | sh1 == sh2 - , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> + , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) -> fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" ZSX -> error "unreachable" @@ -925,31 +934,3 @@ mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) -> Mixed sh a -> Mixed sh b -> Mixed sh c mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) - -mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) - => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a -mcast ssh2 arr - | Refl <- lemAppNil @sh1 - , Refl <- lemAppNil @sh2 - = mcastPartial (ssxFromShape (mshape arr)) ssh2 (Proxy @'[]) arr - --- TODO: This should be `type data` but a bug in GHC 9.10 means that that throws linker errors -data SafeMCastSpec - = MCastId - | MCastApp [Maybe Nat] [Maybe Nat] [Maybe Nat] [Maybe Nat] SafeMCastSpec SafeMCastSpec - | MCastForget - -type SafeMCast :: SafeMCastSpec -> [Maybe Nat] -> [Maybe Nat] -> Constraint -type family SafeMCast spec sh1 sh2 where - SafeMCast MCastId sh sh = () - SafeMCast (MCastApp sh1A sh1B sh2A sh2B specA specB) sh1 sh2 = (sh1 ~ sh1A ++ sh1B, sh2 ~ sh2A ++ sh2B, SafeMCast specA sh1A sh2A, SafeMCast specB sh1B sh2B) - SafeMCast MCastForget sh1 sh2 = sh2 ~ Replicate (Rank sh1) Nothing - --- | This is an O(1) operation: the 'SafeMCast' constraint ensures that --- type-level shape information can only be forgotten, not introduced, and thus --- that no runtime shape checks are required. The @spec@ describes to --- 'SafeMCast' how exactly you intend @sh2@ to be a weakening of @sh1@. --- --- To see how to construct the spec, read the equations of 'SafeMCast' closely. -mcastSafe :: forall spec sh1 sh2 a proxy. SafeMCast spec sh1 sh2 => proxy spec -> Mixed sh1 a -> Mixed sh2 a -mcastSafe _ = unsafeCoerce @(Mixed sh1 a) @(Mixed sh2 a) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs new file mode 100644 index 0000000..852dd5e --- /dev/null +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -0,0 +1,644 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Bifunctor (first) +import Data.Coerce +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Functor.Product +import Data.Kind (Constraint, Type) +import Data.Monoid (Sum(..)) +import Data.Type.Equality +import GHC.Exts (withDict) +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Types + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListX sh f) +#else +instance (forall n. Show (f n)) => Show (ListX sh f) where + showsPrec _ = listxShow shows +#endif + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where + rnf ZX = () + rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqType ZX ZX = Just Refl +listxEqType (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , Just Refl <- listxEqType sh sh' + = Just Refl +listxEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqual ZX ZX = Just Refl +listxEqual (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listxEqual sh sh' + = Just Refl +listxEqual _ _ = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxRank :: ListX sh f -> SNat (Rank sh) +listxRank ZX = SNat +listxRank (_ ::% l) | SNat <- listxRank l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' f -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) +listxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go ZKX [] = ZX + go (_ :!% sh) (i : is) = Const i ::% go sh is + go _ _ = error $ "listxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxHead :: ListX (mn ': sh) f -> f mn +listxHead (i ::% _) = i + +listxTail :: ListX (n : sh) i -> ListX sh i +listxTail (_ ::% sh) = sh + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' + +listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f +listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh +listxInit (_ ::% ZX) = ZX + +listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) +listxLast (_ ::% sh@(_ ::% _)) = listxLast sh +listxLast (x ::% ZX) = x + +listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) +listxZip ZX ZX = ZX +listxZip (i ::% irest) (j ::% jrest) = + Pair i j ::% listxZip irest jrest + +listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g + -> ListX sh h +listxZipWith _ ZX ZX = ZX +listxZipWith f (i ::% is) (j ::% js) = + f i j ::% listxZipWith f is js + + +-- * Mixed indices + +-- | An index into a mixed-typed array. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxX sh = IxX sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxX sh i) +#else +instance Show i => Show (IxX sh i) where + showsPrec _ (IxX l) = listxShow (shows . getConst) l +#endif + +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxLength :: IxX sh i -> Int +ixxLength (IxX l) = listxLength l + +ixxRank :: IxX sh i -> SNat (Rank sh) +ixxRank (IxX l) = listxRank l + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i +ixxFromList = coerce (listxFromList @_ @i) + +ixxHead :: IxX (n : sh) i -> i +ixxHead (IxX list) = getConst (listxHead list) + +ixxTail :: IxX (n : sh) i -> IxX sh i +ixxTail (IxX list) = IxX (listxTail list) + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i +ixxInit = coerce (listxInit @(Const i)) + +ixxLast :: forall n sh i. IxX (n : sh) i -> i +ixxLast = coerce (listxLast @(Const i)) + +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + +ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) +ixxZip ZIX ZIX = ZIX +ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js + +ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k +ixxZipWith _ ZIX ZIX = ZIX +ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$% sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$% sh) (i :.% ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +instance TestEquality f => TestEquality (SMayNat i f) where + testEquality SUnknown{} SUnknown{} = Just Refl + testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl + testEquality _ _ = Nothing + +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => f m -> r) + -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShX sh i) +#else +instance Show i => Show (ShX sh i) where + showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +#endif + +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where + rnf (ShX ZX) = () + rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +-- | This checks only whether the types are equal; unknown dimensions might +-- still differ. This corresponds to 'testEquality', except on the penultimate +-- type parameter. +shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqType ZSX ZSX = Just Refl +shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') + | Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType _ _ = Nothing + +-- | This checks whether all dimensions have the same value. This is more than +-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the +-- @some@ package (except on the penultimate type parameter). +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') + | i == j + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual _ _ = Nothing + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +shxRank :: ShX sh i -> SNat (Rank sh) +shxRank (ShX l) = listxRank l + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxFromList :: StaticShX sh -> [Int] -> IShX sh +shxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn :$% go sh is + | otherwise = error $ "shxFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is + go _ _ = error $ "shxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" + +-- | This may fail if @sh@ has @Nothing@s in it. +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead (ShX list) = listxHead list + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (ShX list) = ShX (listxTail list) + +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i +shxInit = coerce (listxInit @(SMayNat i SNat)) + +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh + +shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) + -> ShX sh i -> ShX sh j -> ShX sh k +shxZipWith _ ZSX ZSX = ZSX +shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] + where + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh +shxCast _ _ = Nothing + +-- | Partial version of 'shxCast'. +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of + Just sh' -> sh' + Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (StaticShX sh) +#else +instance Show (StaticShX sh) where + showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +#endif + +instance NFData (StaticShX sh) where + rnf (StaticShX ZX) = () + rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) + rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + +instance TestEquality StaticShX where + testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +ssxRank :: StaticShX sh -> SNat (Rank sh) +ssxRank (StaticShX l) = listxRank l + +-- | @ssxEqType = 'testEquality'@. Provided for consistency. +ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxEqType = testEquality + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead (StaticShX list) = listxHead list + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) + +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) + +ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) +ssxInit = coerce (listxInit @(SMayNat () SNat)) + +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat () SNat)) + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1) + +ssxFromShX :: ShX sh i -> StaticShX sh +ssxFromShX ZSX = ZKX +ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh + +ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) +ssxFromSNat SZ = ZKX +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + +withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r +withKnownShX = withDict @(KnownShX sh) + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) + where + go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go acc ZKX = SKnown acc + go _ (SUnknown () :!% _) = SUnknown () + go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where + type Item (ListX sh (Const i)) = i + fromList = listxFromList (knownShX @sh) + toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where + type Item (IxX sh i) = i + fromList = IxX . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList = shxFromList (knownShX @sh) + toList = shxToList diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs new file mode 100644 index 0000000..03d1640 --- /dev/null +++ b/src/Data/Array/Nested/Permutation.hs @@ -0,0 +1,283 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Permutation where + +import Data.Coerce (coerce) +import Data.Functor.Const +import Data.List (sort) +import Data.Maybe (fromMaybe) +import Data.Proxy +import Data.Type.Bool +import Data.Type.Equality +import Data.Type.Ord +import GHC.Exts (withDict) +import GHC.TypeError +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Types + + +-- * Permutations + +-- | A "backward" permutation of a dimension list. The operation on the +-- dimension list is most similar to @backpermute@ in the @vector@ package; see +-- 'Permute' for code that implements this. +data Perm list where + PNil :: Perm '[] + PCons :: SNat a -> Perm l -> Perm (a : l) +infixr 5 `PCons` +deriving instance Show (Perm list) +deriving instance Eq (Perm list) + +instance TestEquality Perm where + testEquality PNil PNil = Just Refl + testEquality (x `PCons` xs) (y `PCons` ys) + | Just Refl <- testEquality x y + , Just Refl <- testEquality xs ys = Just Refl + testEquality _ _ = Nothing + +permRank :: Perm list -> SNat (Rank list) +permRank PNil = SNat +permRank (_ `PCons` l) | SNat <- permRank l = SNat + +permFromList :: [Int] -> (forall list. Perm list -> r) -> r +permFromList [] k = k PNil +permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromList xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x + +permToList :: Perm list -> [Natural] +permToList PNil = mempty +permToList (x `PCons` l) = TN.fromSNat x : permToList l + +permToList' :: Perm list -> [Int] +permToList' = map fromIntegral . permToList + +-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of +-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, +-- then @Nothing@ is returned. +permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r +permCheckPermutation = \p k -> + let n = permRank p + in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of + (Just Refl, Just Refl) -> Just k + _ -> Nothing + where + lemElemCount :: (0 <= n, Compare n m ~ LT) + => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerceRefl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) + => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerceRefl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerceRefl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ PNil = Just Refl + provePerm1 p rtop@SNat (PCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> Perm is' + -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) + checkElem _ PNil = Nothing + checkElem i@SNat (PCons k@SNat perm :: Perm is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + +-- | Utility class for generating permutations from type class information. +class KnownPerm l where makePerm :: Perm l +instance KnownPerm '[] where makePerm = PNil +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm + +withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r +withKnownPerm = withDict @(KnownPerm l) + +-- | Untyped permutations for ranked arrays +type PermR = [Int] + + +-- ** Applying permutations + +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) +listxIndex _ _ SZ (n ::% _) = n +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh +listxIndex _ _ _ ZX = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) + + +-- * Operations on permutations + +permInverse :: Perm is + -> (forall is'. + IsPermutation is' + => Perm is' + -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) + -> r) + -> r +permInverse = \perm k -> + genPerm perm $ \(invperm :: Perm is') -> + fromMaybe + (error $ "permInverse: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + (permCheckPermutation invperm + (k invperm + (\ssh -> case permCheckInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm))) + where + genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r + genPerm perm = + let permList = permToList' perm + in toHList $ map snd (sort (zip permList [0..])) + where + toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r + toHList [] k = k PNil + toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + + permCheckInverse :: Perm is -> Perm is' -> StaticShX sh + -> Maybe (Permute is' (Permute is sh) :~: sh) + permCheckInverse perm perminv ssh = + ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh + +type family MapSucc is where + MapSucc '[] = '[] + MapSucc (i : is) = i + 1 : MapSucc is + +permShift1 :: Perm l -> Perm (0 : MapSucc l) +permShift1 = (SNat @0 `PCons`) . permMapSucc + where + permMapSucc :: Perm l -> Perm (MapSucc l) + permMapSucc PNil = PNil + permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + + +-- * Lemmas + +lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ PNil = Refl +lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKX PNil = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) PNil = Refl +lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l + -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs new file mode 100644 index 0000000..9778c54 --- /dev/null +++ b/src/Data/Array/Nested/Ranked.hs @@ -0,0 +1,323 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked ( + Ranked(Ranked), + rquotArray, rremArray, ratan2Array, + rshape, rrank, + module Data.Array.Nested.Ranked, + liftRanked1, liftRanked2, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.RankedS qualified as S +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X + + +remptyArray :: KnownElt a => Ranked 1 a +remptyArray = mtoRanked (memptyArray ZSX) + +-- | The total number of elements in the array. +rsize :: Elt a => Ranked n a -> Int +rsize = shrSize . rshape + +rindex :: Elt a => Ranked n a -> IIxR n -> a +rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) + +rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a +rindexPartial (Ranked arr) idx = + Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) + (ixxFromIxR idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgenerate sh f + | sn@SNat <- shrRank sh + , Dict <- lemKnownReplicate sn + , Refl <- lemRankReplicate sn + = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) + +-- | See the documentation of 'mlift'. +rlift :: forall n1 n2 a. Elt a + => SNat n2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a +rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) + +-- | See the documentation of 'mlift2'. +rlift2 :: forall n1 n2 n3 a. Elt a + => SNat n3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a +rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) + +rsumOuter1P :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1P (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (msumOuter1P arr) + +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive + +rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a +rsumAllPrim (Ranked arr) = msumAllPrim arr + +rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a +rtranspose perm arr + | sn@SNat <- rrank arr + , Dict <- lemKnownReplicate sn + , length perm <= fromIntegral (natVal (Proxy @n)) + = rlift sn + (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) + arr + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" + +rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a +rconcat + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce mconcat + +rappend :: forall n a. Elt a + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend arr1 arr2 + | sn@SNat <- rrank arr1 + , Dict <- lemKnownReplicate sn + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) + arr1 arr2 + +rscalar :: Elt a => a -> Ranked 0 a +rscalar x = Ranked (mscalar x) + +rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP sh v + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mfromVectorP (shxFromShR sh) v) + +rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a +rfromVector sh v = rfromPrimitive (rfromVectorP sh v) + +rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a +rtoVectorP = coerce mtoVectorP + +rtoVector :: PrimElt a => Ranked n a -> VS.Vector a +rtoVector = coerce mtoVector + +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) + +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = rreshape sh (rfromList1 l) + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = Ranked (mfromListPrim l) + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) + +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) + +rtoListLinear :: Elt a => Ranked n a -> [a] +rtoListLinear (Ranked arr) = mtoListLinear arr + +rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a +rfromOrthotope sn arr + | Refl <- lemRankReplicate sn + = let xarr = XArray arr + in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) + +rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + = arr + +runScalar :: Elt a => Ranked 0 a -> a +runScalar arr = rindex arr ZIR + +rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) +rnest n arr + | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) + = coerce (mnest (ssxFromSNat n) (coerce arr)) + +runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a +runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) + | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked arr + +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip = coerce mzip + +runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) +runzip = coerce munzip + +rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) + => SNat n -> IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) +rrerankP sn sh2 f (Ranked arr) + | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) + , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) + = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr) + +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 6 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@. +rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => SNat n -> IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked (n + n1) a -> Ranked (n + n2) b +rrerank sn sh2 f (rtoPrimitive -> arr) = + rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr + +rreplicate :: forall n m a. Elt a + => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) + | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked (mreplicate (shxFromShR sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mreplicateScalP (shxFromShR sh) x) + +rreplicateScal :: forall n a. PrimElt a + => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) + +rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n arr + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = rlift (rrank arr) + (\_ -> X.sliceU i n) + arr + +rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 arr = + rlift (rrank arr) + (\(_ :: StaticShX sh') -> + case lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) + arr + +rreshape :: forall n n' a. Elt a + => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' rarr@(Ranked arr) + | Dict <- lemKnownReplicate (rrank rarr) + , Dict <- lemKnownReplicate (shrRank sh') + = Ranked (mreshape (shxFromShR sh') arr) + +rflatten :: Elt a => Ranked n a -> Ranked 1 a +rflatten (Ranked arr) = mtoRanked (mflatten arr) + +riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a +riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota + +-- | Throws if the array is empty. +rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rminIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixrFromIxX (mminIndexPrim arr) + +-- | Throws if the array is empty. +rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rmaxIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixrFromIxX (mmaxIndexPrim arr) + +rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a +rdot1Inner arr1 arr2 + | SNat <- rrank arr1 + , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) + = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'rdot1Inner' if applicable. +rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a +rdot = coerce mdot + +rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr) + +rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr) + +rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) + +rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) + +rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a +rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) + +rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) +rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs new file mode 100644 index 0000000..babc809 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -0,0 +1,268 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Ranked.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +import Data.Foldable (toList) +#endif + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) +#endif +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Ranked n a) where + showsPrec d arr@(Ranked marr) = + let sh = show (toList (rshape arr)) + in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Ranked n a) where + rnf (Ranked arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where + mshape (M_Ranked arr) = mshape arr + mindex (M_Ranked arr) i = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i = + coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) + mlift ssh2 f (M_Ranked arr) = + coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) + mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = + coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) + @(NonEmpty (Mixed sh2 (Ranked n a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + + mconcat l = M_Ranked (mconcat (coerce l)) + + mrnf (M_Ranked arr) = mrnf arr + + type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + + mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Ranked arr) = marrayStrides arr + + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where + memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) + memptyArrayUnsafe i + | Dict <- lemKnownReplicate (SNat @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +liftRanked1 :: forall n a b. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) + -> Ranked n a -> Ranked n b +liftRanked1 = coerce + +liftRanked2 :: forall n a b c. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) + -> Ranked n a -> Ranked n b -> Ranked n c +liftRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where + (+) = liftRanked2 (+) + (-) = liftRanked2 (-) + (*) = liftRanked2 (*) + negate = liftRanked1 negate + abs = liftRanked1 abs + signum = liftRanked1 signum + fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where + fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" + recip = liftRanked1 recip + (/) = liftRanked2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where + pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" + exp = liftRanked1 exp + log = liftRanked1 log + sqrt = liftRanked1 sqrt + (**) = liftRanked2 (**) + logBase = liftRanked2 logBase + sin = liftRanked1 sin + cos = liftRanked1 cos + tan = liftRanked1 tan + asin = liftRanked1 asin + acos = liftRanked1 acos + atan = liftRanked1 atan + sinh = liftRanked1 sinh + cosh = liftRanked1 cosh + tanh = liftRanked1 tanh + asinh = liftRanked1 asinh + acosh = liftRanked1 acosh + atanh = liftRanked1 atanh + log1p = liftRanked1 GHC.Float.log1p + expm1 = liftRanked1 GHC.Float.expm1 + log1pexp = liftRanked1 GHC.Float.log1pexp + log1mexp = liftRanked1 GHC.Float.log1mexp + +rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +rquotArray = liftRanked2 mquotArray +rremArray = liftRanked2 mremArray + +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = liftRanked2 matan2Array + + +rshape :: Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shrFromShX2 (mshape arr) + +rrank :: Elt a => Ranked n a -> SNat n +rrank = shrRank . rshape + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh + | Refl <- lemRankReplicate (Proxy @n) + = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs new file mode 100644 index 0000000..8b670e5 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -0,0 +1,369 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Kind (Type) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Types + + +-- * Ranked lists + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListR n i) +#else +instance Show i => Show (ListR n i) where + showsPrec _ = listrShow shows +#endif + +instance NFData i => NFData (ListR n i) where + rnf ZR = () + rnf (x ::: l) = rnf x `seq` rnf l + +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) +listrUncons ZR = Nothing + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqRank ZR ZR = Just Refl +listrEqRank (_ ::: sh) (_ ::: sh') + | Just Refl <- listrEqRank sh sh' + = Just Refl +listrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqual ZR ZR = Just Refl +listrEqual (i ::: sh) (j ::: sh') + | Just Refl <- listrEqual sh sh' + , i == j + = Just Refl +listrEqual _ _ = Nothing + +listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS +listrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListR n' i -> ShowS + go _ ZR = id + go prefix (x ::: xs) = showString prefix . f x . go "," xs + +listrLength :: ListR n i -> Int +listrLength = length + +listrRank :: ListR n i -> SNat n +listrRank ZR = SNat +listrRank (_ ::: sh) = snatSucc (listrRank sh) + +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrHead :: ListR (n + 1) i -> i +listrHead (i ::: _) = i +listrHead ZR = error "unreachable" + +listrTail :: ListR (n + 1) i -> ListR n i +listrTail (_ ::: sh) = sh +listrTail ZR = error "unreachable" + +listrInit :: ListR (n + 1) i -> ListR n i +listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh +listrInit (_ ::: ZR) = ZR +listrInit ZR = error "unreachable" + +listrLast :: ListR (n + 1) i -> i +listrLast (_ ::: sh@(_ ::: _)) = listrLast sh +listrLast (n ::: ZR) = n +listrLast ZR = error "unreachable" + +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + +listrZip :: ListR n i -> ListR n j -> ListR n (i, j) +listrZip ZR ZR = ZR +listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest +listrZip _ _ = error "listrZip: impossible pattern needlessly required" + +listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k +listrZipWith _ ZR ZR = ZR +listrZipWith f (i ::: irest) (j ::: jrest) = + f i j ::: listrZipWith f irest jrest +listrZipWith _ _ _ = + error "listrZipWith: impossible pattern needlessly required" + +listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrPermutePrefix = \perm sh -> + listrFromList perm $ \sperm -> + case (listrRank sperm, listrRank sh) of + (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + where + listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) + listrSplitAt SZ sh = (ZR, sh) + listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) + listrSplitAt SS{} ZR = error "m' + 1 <= 0" + + applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i + applyPermRFull _ ZR _ = ZR + applyPermRFull sm@SNat (i ::: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> listrIndex si l ::: applyPermRFull sm perm l + EQI -> listrIndex si l ::: applyPermRFull sm perm l + GTI -> error "listrPermutePrefix: Index in permutation out of range" + + +-- * Ranked indices + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) + where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxR n = IxR n Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxR n i) +#else +instance Show i => Show (IxR n i) where + showsPrec _ (IxR l) = listrShow shows l +#endif + +instance NFData i => NFData (IxR sh i) + +ixrLength :: IxR sh i -> Int +ixrLength (IxR l) = listrLength l + +ixrRank :: IxR n i -> SNat n +ixrRank (IxR sh) = listrRank sh + +ixrZero :: SNat n -> IIxR n +ixrZero SZ = ZIR +ixrZero (SS n) = 0 :.: ixrZero n + +ixrHead :: IxR (n + 1) i -> i +ixrHead (IxR list) = listrHead list + +ixrTail :: IxR (n + 1) i -> IxR n i +ixrTail (IxR list) = IxR (listrTail list) + +ixrInit :: IxR (n + 1) i -> IxR n i +ixrInit (IxR list) = IxR (listrInit list) + +ixrLast :: IxR (n + 1) i -> i +ixrLast (IxR list) = listrLast list + +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) + +ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i +ixrAppend = coerce (listrAppend @_ @i) + +ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) +ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 + +ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k +ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 + +ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- * Ranked shapes + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) + where i :$: ShR sh = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + +type IShR n = ShR n Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShR n i) +#else +instance Show i => Show (ShR n i) where + showsPrec _ (ShR l) = listrShow shows l +#endif + +instance NFData i => NFData (ShR sh i) + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' + +-- | This compares the shapes for value equality. +shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' + +shrLength :: ShR sh i -> Int +shrLength (ShR l) = listrLength l + +-- | This function can also be used to conjure up a 'KnownNat' dictionary; +-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern +-- synonym yields 'KnownNat' evidence. +shrRank :: ShR n i -> SNat n +shrRank (ShR sh) = listrRank sh + +-- | The number of elements in an array described by this shape. +shrSize :: IShR n -> Int +shrSize ZSR = 1 +shrSize (n :$: sh) = n * shrSize sh + +shrHead :: ShR (n + 1) i -> i +shrHead (ShR list) = listrHead list + +shrTail :: ShR (n + 1) i -> ShR n i +shrTail (ShR list) = ShR (listrTail list) + +shrInit :: ShR (n + 1) i -> ShR n i +shrInit (ShR list) = ShR (listrInit list) + +shrLast :: ShR (n + 1) i -> i +shrLast (ShR list) = listrLast list + +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) + +shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i +shrAppend = coerce (listrAppend @_ @i) + +shrZip :: ShR n i -> ShR n j -> ShR n (i, j) +shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 + +shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k +shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 + +shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i +shrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where + type Item (ListR n i) = i + fromList topl = go (SNat @n) topl + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error $ "IsList(ListR): Mismatched list length (type says " + ++ show (fromSNat (SNat @n)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where + type Item (IxR n i) = i + fromList = IxR . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where + type Item (ShR n i) = i + fromList = ShR . IsList.fromList + toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs new file mode 100644 index 0000000..198a068 --- /dev/null +++ b/src/Data/Array/Nested/Shaped.hs @@ -0,0 +1,272 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Shaped ( + Shaped(Shaped), + squotArray, sremArray, satan2Array, + sshape, + module Data.Array.Nested.Shaped, + liftShaped1, liftShaped2, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS +import Data.Array.Internal.ShapedG qualified as SG +import Data.Array.Internal.ShapedS qualified as SS +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.TypeLits + +import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Base +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X + + +semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray sh = Shaped (memptyArray (shxFromShS sh)) + +srank :: Elt a => Shaped sh a -> SNat (Rank sh) +srank = shsRank . sshape + +-- | The total number of elements in the array. +ssize :: Elt a => Shaped sh a -> Int +ssize = shsSize . sshape + +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) + +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (ixxFromIxS idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) + +-- | See the documentation of 'mlift'. +slift :: forall sh1 sh2 a. Elt a + => ShS sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) + +-- | See the documentation of 'mlift'. +slift2 :: forall sh1 sh2 sh3 a. Elt a + => ShS sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) + +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) + +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive + +ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a +ssumAllPrim (Shaped arr) = msumAllPrim arr + +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + , Refl <- lemTakeLenMapJust perm (sshape sarr) + , Refl <- lemDropLenMapJust perm (sshape sarr) + , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) + = Shaped (mtranspose perm arr) + +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) + +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) + +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = sfromPrimitive (sfromVectorP sh v) + +stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a +stoVectorP = coerce mtoVectorP + +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a +stoVector = coerce mtoVector + +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 + +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) + +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) + +stoList :: Elt a => Shaped '[n] a -> [a] +stoList = map sunScalar . stoListOuter + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) + +stoListLinear :: Elt a => Shaped sh a -> [a] +stoListLinear (Shaped arr) = mtoListLinear arr + +sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a +sfromOrthotope sh (SS.A (SG.A arr)) = + Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) + +stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a +stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr ZIS + +snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) +snest sh arr + | Refl <- lemMapJustApp sh (Proxy @sh') + = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr)) + +sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a +sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) + | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') + = Shaped arr + +szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip = coerce mzip + +sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) +sunzip = coerce munzip + +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) +srerankP sh sh2 f sarr@(Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh1) + , Refl <- lemMapJustApp sh (Proxy @sh2) + = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr)))) + (shxFromShS sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr) + +-- | See the caveats at 'Data.Array.XArray.rerank'. +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b +srerank sh sh2 f (stoPrimitive -> arr) = + sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr + +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shxFromShS sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) + +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) + +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = + let _ :$$ sh = sshape arr + in slift (n :$$ sh) (\_ -> X.slice i n) arr + +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr + +sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) + +sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a +sflatten arr = + case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff + n@SNat -> sreshape (n :$$ ZSS) arr + +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + +-- | Throws if the array is empty. +sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) + +-- | Throws if the array is empty. +smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) + +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'sdot1Inner' if applicable. +sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a +sdot = coerce mdot + +stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr) + +stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr) + +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr) + +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr) + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs new file mode 100644 index 0000000..879e6b5 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -0,0 +1,255 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Shaped.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +#endif +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = + castWith (subst1 (sym (lemMapJustCons Refl))) $ + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 97b9456..5f9ba79 100644 --- a/src/Data/Array/Nested/Internal/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,11 +1,9 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} @@ -24,10 +22,9 @@ {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Nested.Internal.Shape where +module Data.Array.Nested.Shaped.Shape where import Control.DeepSeq (NFData(..)) -import Data.Array.Mixed.Types import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable @@ -42,333 +39,16 @@ import GHC.Generics (Generic) import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits -import GHC.TypeNats qualified as TN - -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape - - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where - ZR :: ListR 0 i - (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where - showsPrec _ = listrShow shows - -instance NFData i => NFData (ListR n i) where - rnf ZR = () - rnf (x ::: l) = rnf x `seq` rnf l - -data UnconsListRRes i n1 = - forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) -listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) -listrUncons ZR = Nothing - --- | This checks only whether the ranks are equal, not whether the actual --- values are. -listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqRank ZR ZR = Just Refl -listrEqRank (_ ::: sh) (_ ::: sh') - | Just Refl <- listrEqRank sh sh' - = Just Refl -listrEqRank _ _ = Nothing - --- | This compares the lists for value equality. -listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') -listrEqual ZR ZR = Just Refl -listrEqual (i ::: sh) (j ::: sh') - | Just Refl <- listrEqual sh sh' - , i == j - = Just Refl -listrEqual _ _ = Nothing - -listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS -listrShow f l = showString "[" . go "" l . showString "]" - where - go :: String -> ListR n' i -> ShowS - go _ ZR = id - go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrLength :: ListR n i -> Int -listrLength = length - -listrRank :: ListR n i -> SNat n -listrRank ZR = SNat -listrRank (_ ::: sh) = snatSucc (listrRank sh) - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrHead :: ListR (n + 1) i -> i -listrHead (i ::: _) = i -listrHead ZR = error "unreachable" - -listrTail :: ListR (n + 1) i -> ListR n i -listrTail (_ ::: sh) = sh -listrTail ZR = error "unreachable" - -listrInit :: ListR (n + 1) i -> ListR n i -listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh -listrInit (_ ::: ZR) = ZR -listrInit ZR = error "unreachable" - -listrLast :: ListR (n + 1) i -> i -listrLast (_ ::: sh@(_ ::: _)) = listrLast sh -listrLast (n ::: ZR) = n -listrLast ZR = error "unreachable" - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - -listrZip :: ListR n i -> ListR n j -> ListR n (i, j) -listrZip ZR ZR = ZR -listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest -listrZip _ _ = error "listrZip: impossible pattern needlessly required" - -listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k -listrZipWith _ ZR ZR = ZR -listrZipWith f (i ::: irest) (j ::: jrest) = - f i j ::: listrZipWith f irest jrest -listrZipWith _ _ _ = - error "listrZipWith: impossible pattern needlessly required" - -listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i -listrPermutePrefix = \perm sh -> - listrFromList perm $ \sperm -> - case (listrRank sperm, listrRank sh) of - (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of - LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post - GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" - ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" - where - listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) - listrSplitAt SZ sh = (ZR, sh) - listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) - listrSplitAt SS{} ZR = error "m' + 1 <= 0" - - applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i - applyPermRFull _ ZR _ = ZR - applyPermRFull sm@SNat (i ::: perm) l = - TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> - case cmpNat (SNat @(idx + 1)) sm of - LTI -> listrIndex si l ::: applyPermRFull sm perm l - EQI -> listrIndex si l ::: applyPermRFull sm perm l - GTI -> error "listrPermutePrefix: Index in permutation out of range" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) - where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where - showsPrec _ (IxR l) = listrShow shows l - -instance NFData i => NFData (IxR sh i) - -ixrLength :: IxR sh i -> Int -ixrLength (IxR l) = listrLength l - -ixrRank :: IxR n i -> SNat n -ixrRank (IxR sh) = listrRank sh - -ixrZero :: SNat n -> IIxR n -ixrZero SZ = ZIR -ixrZero (SS n) = 0 :.: ixrZero n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = - castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (n :.% ixCvtRX idx) - -ixrHead :: IxR (n + 1) i -> i -ixrHead (IxR list) = listrHead list - -ixrTail :: IxR (n + 1) i -> IxR n i -ixrTail (IxR list) = IxR (listrTail list) - -ixrInit :: IxR (n + 1) i -> IxR n i -ixrInit (IxR list) = IxR (listrInit list) - -ixrLast :: IxR (n + 1) i -> i -ixrLast (IxR list) = listrLast list - -ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i -ixrAppend = coerce (listrAppend @_ @i) - -ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) -ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 - -ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k -ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 - -ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i -ixrPermutePrefix = coerce (listrPermutePrefix @i) - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Eq, Ord, Generic) - deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (n + 1 ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) - where i :$: ShR sh = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where - showsPrec _ (ShR l) = listrShow shows l - -instance NFData i => NFData (ShR sh i) - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = - castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) - ZSR -shCvtXR' (n :$% (idx :: IShX sh)) - | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = - castWith (subst2 (lem1 @sh Refl)) - (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) - where - lem1 :: forall sh' n' k. - k : sh' :~: Replicate n' Nothing - -> Rank sh' + 1 :~: n' - lem1 Refl = unsafeCoerceRefl - - lem2 :: k : sh :~: Replicate n Nothing - -> sh :~: Replicate (Rank sh) Nothing - lem2 Refl = unsafeCoerceRefl - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = - castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) - (SUnknown n :$% shCvtRX idx) --- | This checks only whether the ranks are equal, not whether the actual --- values are. -shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types --- | This compares the shapes for value equality. -shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') -shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' - -shrLength :: ShR sh i -> Int -shrLength (ShR l) = listrLength l - --- | This function can also be used to conjure up a 'KnownNat' dictionary; --- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern --- synonym yields 'KnownNat' evidence. -shrRank :: ShR n i -> SNat n -shrRank (ShR sh) = listrRank sh - --- | The number of elements in an array described by this shape. -shrSize :: IShR n -> Int -shrSize ZSR = 1 -shrSize (n :$: sh) = n * shrSize sh - -shrHead :: ShR (n + 1) i -> i -shrHead (ShR list) = listrHead list - -shrTail :: ShR (n + 1) i -> ShR n i -shrTail (ShR list) = ShR (listrTail list) - -shrInit :: ShR (n + 1) i -> ShR n i -shrInit (ShR list) = ShR (listrInit list) - -shrLast :: ShR (n + 1) i -> i -shrLast (ShR list) = listrLast list - -shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i -shrAppend = coerce (listrAppend @_ @i) - -shrZip :: ShR n i -> ShR n j -> ShR n (i, j) -shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 - -shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k -shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 - -shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i -shrPermutePrefix = coerce (listrPermutePrefix @i) - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where - type Item (ListR n i) = i - fromList topl = go (SNat @n) topl - where - go :: SNat n' -> [i] -> ListR n' i - go SZ [] = ZR - go (SS n) (i : is) = i ::: go n is - go _ _ = error $ "IsList(ListR): Mismatched list length (type says " - ++ show (fromSNat (SNat @n)) ++ ", list has length " - ++ show (length topl) ++ ")" - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where - type Item (IxR n i) = i - fromList = IxR . IsList.fromList - toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where - type Item (ShR n i) = i - fromList = ShR . IsList.fromList - toList = Foldable.toList +-- * Shaped lists +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where @@ -379,8 +59,12 @@ deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) infixr 3 ::$ +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListS sh f) +#else instance (forall n. Show (f n)) => Show (ListS sh f) where showsPrec _ = listsShow shows +#endif instance (forall m. NFData (f m)) => NFData (ListS n f) where rnf ZS = () @@ -497,11 +181,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) +-- * Shaped indices -- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) @@ -510,6 +192,8 @@ newtype IxS sh i = IxS (ListS sh (Const i)) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. pattern (:.$) :: forall {sh1} {i}. forall n sh. (KnownNat n, n : sh ~ sh1) @@ -520,10 +204,16 @@ infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxS sh = IxS sh Int +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxS sh i) +#else instance Show i => Show (IxS sh i) where showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l +#endif instance Functor (IxS sh) where fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) @@ -543,14 +233,6 @@ ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh - ixsHead :: IxS (n : sh) i -> i ixsHead (IxS list) = getConst (listsHead list) @@ -563,6 +245,12 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -578,6 +266,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- * Shaped shapes + -- | The shape of a shape-typed array given as a list of 'SNat' values. -- -- Note that because the shape of a shape-typed array is known statically, you @@ -601,8 +291,12 @@ infixr 3 :$$ {-# COMPLETE ZSS, (:$$) #-} +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (ShS sh) +#else instance Show (ShS sh) where showsPrec _ (ShS l) = listsShow (shows . fromSNat) l +#endif instance NFData (ShS sh) where rnf (ShS ZS) = () @@ -630,23 +324,6 @@ shsToList :: ShS sh -> [Int] shsToList ZSS = [] shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shCvtXS' (SUnknown _ :$% _) = error "impossible" - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh - shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list @@ -690,7 +367,7 @@ instance KnownShS '[] where knownShS = ZSS instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k +withKnownShS = withDict @(KnownShS sh) shsKnownShS :: ShS sh -> Dict KnownShS sh shsKnownShS ZSS = Dict @@ -700,6 +377,17 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs index 838e2b0..8a29aa5 100644 --- a/src/Data/Array/Nested/Trace.hs +++ b/src/Data/Array/Nested/Trace.hs @@ -37,10 +37,12 @@ module Data.Array.Nested.Trace ( ShS(..), KnownShS(..), Mixed, + ListX(ZX, (::%)), IxX(..), IIxX, - ShX(..), KnownShX(..), + ShX(..), KnownShX(..), IShX, StaticShX(..), SMayNat(..), + Conversion(..), Elt, PrimElt, @@ -54,7 +56,7 @@ module Data.Array.Nested.Trace ( Perm(..), IsPermutation, KnownPerm(..), - NumElt, FloatElt, + NumElt, IntElt, FloatElt, Rank, Product, Replicate, MapJust, @@ -67,4 +69,4 @@ import Data.Array.Nested.Trace.TH $(concat <$> mapM convertFun - ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rfromListLinear, 'rfromListPrimLinear, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rtoMixed, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sfromListLinear, 'sfromListPrimLinear, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'stoMixed, 'sfromOrthotope, 'stoOrthotope, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mfromListLinear, 'mfromListPrimLinear, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped]) + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs new file mode 100644 index 0000000..4444acd --- /dev/null +++ b/src/Data/Array/Nested/Types.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Types ( + -- * Reasoning helpers + subst1, subst2, + + -- * Reified evidence of a type class + Dict(..), + + -- * Type-level naturals + pattern SZ, pattern SS, + fromSNat', sameNat', + snatPlus, snatMinus, snatMul, + snatSucc, + + -- * Type-level lists + type (++), + Replicate, + lemReplicateSucc, + MapJust, + lemMapJustEmpty, lemMapJustCons, + Head, + Tail, + Init, + Last, + + -- * Unsafe + unsafeCoerceRefl, +) where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits +import GHC.TypeNats qualified as TN +import Unsafe.Coerce qualified + + +-- Reasoning helpers + +subst1 :: forall f a b. a :~: b -> f a :~: f b +subst1 Refl = Refl + +subst2 :: forall f c a b. a :~: b -> f a c :~: f b c +subst2 Refl = Refl + +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a + +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + +sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) +sameNat' n@SNat m@SNat = sameNat n m + +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + +-- This should be a function in base +snatPlus :: SNat n -> SNat m -> SNat (n + m) +snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMinus :: SNat n -> SNat m -> SNat (n - m) +snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + + +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl + +type family MapJust l = r | r -> l where + MapJust '[] = '[] + MapJust (x : xs) = Just x : MapJust xs + +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl + +type family Head l where + Head (x : _) = x + +type family Tail l where + Tail (_ : xs) = xs + +type family Init l where + Init (x : y : xs) = x : Init (y : xs) + Init '[x] = '[] + +type family Last l where + Last (x : y : xs) = Last (y : xs) + Last '[x] = x + + +-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to +-- only typecheck for actual type equalities. One cannot, e.g. accidentally +-- write this: +-- +-- @ +-- foo :: Proxy a -> Proxy b -> a :~: b +-- foo = unsafeCoerceRefl +-- @ +-- +-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', +-- but would have resulted in interesting memory errors at runtime. +unsafeCoerceRefl :: a :~: b +unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl |