diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped/Shape.hs')
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 79 |
1 files changed, 44 insertions, 35 deletions
diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs index 6c43fa7..5f9ba79 100644 --- a/src/Data/Array/Nested/Shaped/Shape.hs +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -1,11 +1,9 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE NoStarIsType #-} {-# LANGUAGE PatternSynonyms #-} @@ -27,7 +25,6 @@ module Data.Array.Nested.Shaped.Shape where import Control.DeepSeq (NFData(..)) -import Data.Array.Mixed.Types import Data.Array.Shape qualified as O import Data.Coerce (coerce) import Data.Foldable qualified as Foldable @@ -43,11 +40,15 @@ import GHC.IsList (IsList) import GHC.IsList qualified as IsList import GHC.TypeLits -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. type role ListS nominal representational type ListS :: [Nat] -> (Nat -> Type) -> Type data ListS sh f where @@ -58,8 +59,12 @@ deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) infixr 3 ::$ +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListS sh f) +#else instance (forall n. Show (f n)) => Show (ListS sh f) where showsPrec _ = listsShow shows +#endif instance (forall m. NFData (f m)) => NFData (ListS n f) where rnf ZS = () @@ -176,11 +181,9 @@ listsIndex _ _ _ ZS = error "Index into empty shape" listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) +-- * Shaped indices -- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). type role IxS nominal representational type IxS :: [Nat] -> Type -> Type newtype IxS sh i = IxS (ListS sh (Const i)) @@ -189,6 +192,8 @@ newtype IxS sh i = IxS (ListS sh (Const i)) pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i pattern ZIS = IxS ZS +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. pattern (:.$) :: forall {sh1} {i}. forall n sh. (KnownNat n, n : sh ~ sh1) @@ -199,10 +204,16 @@ infixr 3 :.$ {-# COMPLETE ZIS, (:.$) #-} +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). type IIxS sh = IxS sh Int +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxS sh i) +#else instance Show i => Show (IxS sh i) where showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l +#endif instance Functor (IxS sh) where fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) @@ -222,14 +233,6 @@ ixsZero :: ShS sh -> IIxS sh ixsZero ZSS = ZIS ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh - ixsHead :: IxS (n : sh) i -> i ixsHead (IxS list) = getConst (listsHead list) @@ -242,6 +245,12 @@ ixsInit (IxS list) = IxS (listsInit list) ixsLast :: IxS (n : sh) i -> i ixsLast (IxS list) = getConst (listsLast list) +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i ixsAppend = coerce (listsAppend @_ @(Const i)) @@ -257,6 +266,8 @@ ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) +-- * Shaped shapes + -- | The shape of a shape-typed array given as a list of 'SNat' values. -- -- Note that because the shape of a shape-typed array is known statically, you @@ -280,8 +291,12 @@ infixr 3 :$$ {-# COMPLETE ZSS, (:$$) #-} +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (ShS sh) +#else instance Show (ShS sh) where showsPrec _ (ShS l) = listsShow (shows . fromSNat) l +#endif instance NFData (ShS sh) where rnf (ShS ZS) = () @@ -309,23 +324,6 @@ shsToList :: ShS sh -> [Int] shsToList ZSS = [] shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = - castWith (subst1 (lem Refl)) $ - n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) - idx) - where - lem :: forall sh1 sh' n. - Just n : sh1 :~: MapJust sh' - -> n : Tail sh' :~: sh' - lem Refl = unsafeCoerceRefl -shCvtXS' (SUnknown _ :$% _) = error "impossible" - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh - shsHead :: ShS (n : sh) -> SNat n shsHead (ShS list) = listsHead list @@ -369,7 +367,7 @@ instance KnownShS '[] where knownShS = ZSS instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r -withKnownShS k = withDict @(KnownShS sh) k +withKnownShS = withDict @(KnownShS sh) shsKnownShS :: ShS sh -> Dict KnownShS sh shsKnownShS ZSS = Dict @@ -379,6 +377,17 @@ shsOrthotopeShape :: ShS sh -> Dict O.Shape sh shsOrthotopeShape ZSS = Dict shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + -- | Untyped: length is checked at runtime. instance KnownShS sh => IsList (ListS sh (Const i)) where |