diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 266 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 399 |
2 files changed, 665 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs new file mode 100644 index 0000000..a5e6247 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -0,0 +1,266 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Shaped.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Data.Array.Nested.Ranked.Base.Ranked', +-- the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +#endif +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + {-# INLINE mshape #-} + mshape (M_Shaped arr) = mshape arr + {-# INLINE mindex #-} + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + {-# INLINE mindexPartial #-} + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuterSN :: SNat n -> NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Just n : sh') (Shaped sh a) + mfromListOuterSN sn l = M_Shaped (mfromListOuterSN sn (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + {-# INLINE mlift #-} + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + {-# INLINE mlift2 #-} + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + {-# INLINE mliftL #-} + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first coerce (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeIsEmpty _ (sh, t) = shsSize sh == 0 || mshapeTreeIsEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWriteLinear :: forall sh' s. Int -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWriteLinear idx (Shaped arr) vecs = + mvecsWriteLinear idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartialLinear + :: forall sh1 sh2 s. + Proxy sh1 -> Int -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartialLinear proxy idx arr vecs = + mvecsWritePartialLinear proxy idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + mvecsUnsafeFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsUnsafeFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsUnsafeFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe sh + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe sh + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsReplicate idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsReplicate idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicatePrim" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicatePrim" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicatePrim" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +{-# INLINE sshape #-} +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = coerce (mshape arr) diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs new file mode 100644 index 0000000..60e0252 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -0,0 +1,399 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped.Shape where + +import Control.DeepSeq (NFData(..)) +import Control.Exception (assert) +import Data.Array.Shape qualified as O +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Kind (Constraint, Type) +import Data.Proxy +import Data.Type.Equality +import GHC.Exts (build, withDict) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Mixed.ListX +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types + + +-- * Shaped indices + +-- | An index into a shape-typed array. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (IxX (MapJust sh) i) + deriving (Eq, Ord, NFData, Functor, Foldable) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS <- IxS (matchZIX -> Just Refl) + where ZIS = IxS ZIX + +matchZIX :: forall sh i. IxX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZIX ZIX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZIX _ = Nothing + +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxS sh i -> IxS sh1 i +pattern i :.$ l <- (ixsUncons -> Just (UnconsIxSRes i l)) + where i :.$ IxS l = IxS (i :.% l) +infixr 3 :.$ + +data UnconsIxSRes i sh1 = + forall n sh. (n : sh ~ sh1) => UnconsIxSRes i (IxS sh i) +ixsUncons :: forall sh1 i. IxS sh1 i -> Maybe (UnconsIxSRes i sh1) +ixsUncons (IxS (i :.% l)) | Refl <- lemMapJustHead (Proxy @sh1) + , Refl <- lemMapJustCons @sh1 Refl = + Just (UnconsIxSRes i (IxS l)) +ixsUncons (IxS _) = Nothing + +{-# COMPLETE ZIS, (:.$) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxS sh = IxS sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxS sh i) +#else +instance Show i => Show (IxS sh i) where + showsPrec _ l = ixsShow shows l +#endif + +ixsShow :: forall sh i. (i -> ShowS) -> IxS sh i -> ShowS +ixsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> IxS sh' i -> ShowS + go _ ZIS = id + go prefix (x :.$ xs) = showString prefix . f x . go "," xs + +ixsRank :: IxS sh i -> SNat (Rank sh) +ixsRank ZIS = SNat +ixsRank (_ :.$ sh) = snatSucc (ixsRank sh) + +{-# INLINE ixsFromList #-} +ixsFromList :: ShS sh -> [i] -> IxS sh i +ixsFromList sh l = assert (shsLength sh == length l) + $ IxS $ IsList.fromList l + +{-# INLINE ixsFromIxS #-} +ixsFromIxS :: IxS sh i0 -> [i] -> IxS sh i +ixsFromIxS sh l = assert (length sh == length l) + $ IxS $ IsList.fromList l + +ixsZero :: ShS sh -> IIxS sh +ixsZero ZSS = ZIS +ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh + +ixsHead :: IxS (n : sh) i -> i +ixsHead (i :.$ _) = i + +ixsTail :: IxS (n : sh) i -> IxS sh i +ixsTail (_ :.$ sh) = sh + +ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i +ixsInit (n :.$ sh@(_ :.$ _)) = n :.$ ixsInit sh +ixsInit (_ :.$ ZIS) = ZIS + +ixsLast :: IxS (n : sh) i -> i +ixsLast (_ :.$ sh@(_ :.$ _)) = ixsLast sh +ixsLast (n :.$ ZIS) = n + +ixsCast :: IxS sh i -> IxS sh i +ixsCast ZIS = ZIS +ixsCast (i :.$ idx) = i :.$ ixsCast idx + +ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i +ixsAppend = gcastWith (unsafeCoerceRefl :: MapJust (sh ++ sh') :~: MapJust sh ++ MapJust sh') $ + coerce (ixxAppend @_ @_ @i) + +ixsZip :: IxS sh i -> IxS sh j -> IxS sh (i, j) +ixsZip ZIS ZIS = ZIS +ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js + +{-# INLINE ixsZipWith #-} +ixsZipWith :: (i -> j -> k) -> IxS sh i -> IxS sh j -> IxS sh k +ixsZipWith _ ZIS ZIS = ZIS +ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js + +ixsTakeLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (TakeLen is sh) i +ixsTakeLenPerm PNil _ = ZIS +ixsTakeLenPerm (_ `PCons` is) (n :.$ sh) = n :.$ ixsTakeLenPerm is sh +ixsTakeLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsDropLenPerm :: forall i is sh. Perm is -> IxS sh i -> IxS (DropLen is sh) i +ixsDropLenPerm PNil sh = sh +ixsDropLenPerm (_ `PCons` is) (_ :.$ sh) = ixsDropLenPerm is sh +ixsDropLenPerm (_ `PCons` _) ZIS = error "Permutation longer than shape" + +ixsPermute :: forall i is sh. Perm is -> IxS sh i -> IxS (Permute is sh) i +ixsPermute PNil _ = ZIS +ixsPermute (i `PCons` (is :: Perm is')) (sh :: IxS sh f) = + case ixsIndex i sh of + item -> item :.$ ixsPermute is sh + +ixsIndex :: forall j i sh. SNat i -> IxS sh j -> j +ixsIndex SZ (n :.$ _) = n +ixsIndex (SS i) (_ :.$ sh) = ixsIndex i sh +ixsIndex _ ZIS = error "Index into empty shape" + +ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i +ixsPermutePrefix perm sh = ixsAppend (ixsPermute perm (ixsTakeLenPerm perm sh)) (ixsDropLenPerm perm sh) + +-- | Given a multidimensional index, get the corresponding linear +-- index into the buffer. +{-# INLINEABLE ixsToLinear #-} +ixsToLinear :: Num i => ShS sh -> IxS sh i -> i +ixsToLinear (ShS sh) ix = ixxToLinear sh (ixxFromIxS ix) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS = coerce + +{-# INLINEABLE ixsFromLinear #-} +ixsFromLinear :: Num i => ShS sh -> Int -> IxS sh i +ixsFromLinear (ShS sh) i = ixsFromIxX $ ixxFromLinear sh i + +ixsFromIxX :: IxX (MapJust sh) i -> IxS sh i +ixsFromIxX = coerce + +shsEnum :: ShS sh -> [IIxS sh] +shsEnum = shsEnum' + +{-# INLINABLE shsEnum' #-} -- ensure this can be specialised at use site +shsEnum' :: Num i => ShS sh -> [IxS sh i] +shsEnum' (ShS sh) = (coerce :: [IxX (MapJust sh) i] -> [IxS sh i]) $ shxEnum' sh + +-- * Shaped shapes + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +-- +-- Note that because the shape of a shape-typed array is known statically, you +-- can also retrieve the array shape from a 'KnownShS' dictionary. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ShX (MapJust sh) Int) + deriving (NFData) + +instance Eq (ShS sh) where _ == _ = True +instance Ord (ShS sh) where compare _ _ = EQ + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS <- ShS (matchZSX -> Just Refl) + where ZSS = ShS ZSX + +matchZSX :: forall sh i. ShX (MapJust sh) i -> Maybe (sh :~: '[]) +matchZSX ZSX | Refl <- lemMapJustEmpty @sh Refl = Just Refl +matchZSX _ = Nothing + +pattern (:$$) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ sh <- (shsUncons -> Just (UnconsShSRes i sh)) + where i :$$ ShS sh = ShS (SKnown i :$% sh) +infixr 3 :$$ + +data UnconsShSRes sh1 = + forall n sh. (n : sh ~ sh1) => UnconsShSRes (SNat n) (ShS sh) +shsUncons :: forall sh1. ShS sh1 -> Maybe (UnconsShSRes sh1) +shsUncons (ShS (SKnown x :$% sh')) | Refl <- lemMapJustCons @sh1 Refl + = Just (UnconsShSRes x (ShS sh')) +shsUncons (ShS _) = Nothing + +{-# COMPLETE ZSS, (:$$) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (ShS sh) +#else +instance Show (ShS sh) where + showsPrec d (ShS shx) = showsPrec d shx +#endif + +instance TestEquality ShS where + testEquality (ShS shx1) (ShS shx2) = case shxEqType shx1 shx2 of + Nothing -> Nothing + Just Refl -> Just unsafeCoerceRefl + +-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are +-- equal if and only if values are equal.) +shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') +shsEqual = testEquality + +shsLength :: ShS sh -> Int +shsLength (ShS shx) = shxLength shx + +shsRank :: forall sh. ShS sh -> SNat (Rank sh) +shsRank (ShS shx) | Refl <- lemRankMapJust (Proxy @sh) = + shxRank shx + +lemRankMapJust :: proxy sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust _ = unsafeCoerceRefl + +shsSize :: ShS sh -> Int +shsSize (ShS sh) = shxSize sh + +-- | This is a partial @const@ that fails when the second argument +-- doesn't match the first. We don't report the size of the list +-- in case of errors in order not to retain the list. +{-# INLINEABLE shsFromList #-} +shsFromList :: ShS sh -> [Int] -> ShS sh +shsFromList sh0@(ShS topsh) topl = go topsh topl `seq` sh0 + where + go :: ShX sh' Int -> [Int] -> () + go ZSX [] = () + go ZSX _ = error $ "shsFromList: List too long (type says " ++ show (shxLength topsh) ++ ")" + go (ConsKnown sn sh) (i : is) + | i == fromSNat' sn = go sh is + | otherwise = error "shsFromList: Value does not match typing" + go ConsUnknown{} _ = error "shsFromList: impossible case" + go _ _ = error $ "shsFromList: List too short (type says " ++ show (shxLength topsh) ++ ")" + +-- This is equivalent to but faster than @coerce shxToList@. +{-# INLINEABLE shsToList #-} +shsToList :: ShS sh -> [Int] +shsToList (ShS l) = build (\(cons :: i -> is -> is) (nil :: is) -> + let go :: ShX sh Int -> is + go ZSX = nil + go ConsUnknown{} = error "shsToList: impossible case" + go (ConsKnown snat rest) = fromSNat' snat `cons` go rest + in go l) + +shsHead :: ShS (n : sh) -> SNat n +shsHead (ShS shx) = case shxHead shx of + SKnown SNat -> SNat + +shsTail :: forall n sh. ShS (n : sh) -> ShS sh +shsTail = coerce (shxTail @_ @_ @Int) + +{-# INLINEABLE shsTakeIx #-} +shsTakeIx :: forall sh sh' j. Proxy sh' -> IxS sh j -> ShS (sh ++ sh') -> ShS sh +shsTakeIx _ ZIS _ = ZSS +shsTakeIx p (_ :.$ idx) sh = case sh of n :$$ sh' -> n :$$ shsTakeIx p idx sh' + +{-# INLINEABLE shsDropIx #-} +shsDropIx :: forall sh sh' j. IxS sh j -> ShS (sh ++ sh') -> ShS sh' +shsDropIx ZIS long = long +shsDropIx (_ :.$ short) long = case long of _ :$$ long' -> shsDropIx short long' + +shsInit :: forall n sh. ShS (n : sh) -> ShS (Init (n : sh)) +shsInit = + gcastWith (unsafeCoerceRefl + :: Init (Just n : MapJust sh) :~: MapJust (Init (n : sh))) $ + coerce (shxInit @Int) + +shsLast :: forall n sh. ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS shx) = + gcastWith (unsafeCoerceRefl + :: Last (Just n : MapJust sh) :~: Just (Last (n : sh))) $ + case shxLast shx of + SKnown SNat -> SNat + +shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') +shsAppend = + gcastWith (unsafeCoerceRefl + :: MapJust sh ++ MapJust sh' :~: MapJust (sh ++ sh')) $ + coerce (shxAppend @_ @Int) + +shsTakeLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLenPerm = + gcastWith (unsafeCoerceRefl + :: TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)) $ + coerce (shxTakeLenPerm @Int) + +shsDropLenPerm :: forall is sh. Perm is -> ShS sh -> ShS (DropLen is sh) +shsDropLenPerm = + gcastWith (unsafeCoerceRefl + :: DropLen is (MapJust sh) :~: MapJust (DropLen is sh)) $ + coerce (shxDropLenPerm @Int) + +shsPermute :: forall is sh. Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = + gcastWith (unsafeCoerceRefl + :: Permute is (MapJust sh) :~: MapJust (Permute is sh)) $ + coerce (shxPermute @Int) + +shsIndex :: forall i sh. SNat i -> ShS sh -> SNat (Index i sh) +shsIndex i (ShS sh) = + gcastWith (unsafeCoerceRefl + :: Index i (MapJust sh) :~: Just (Index i sh)) $ + case shxIndex @Int i sh of + SKnown SNat -> SNat + +shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) +shsPermutePrefix perm (ShS shx) + {- TODO: here and elsewhere, solve the module dependency cycle and add this: + | Refl <- lemTakeLenMapJust perm sh + , Refl <- lemDropLenMapJust perm sh + , Refl <- lemPermuteMapJust perm sh + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLenPerm perm sh)) (shsDropLenPerm perm sh) -} + = gcastWith (unsafeCoerceRefl + :: Permute is (TakeLen is (MapJust sh)) + ++ DropLen is (MapJust sh) + :~: MapJust (Permute is (TakeLen is sh) ++ DropLen is sh)) $ + ShS (shxPermutePrefix perm shx) + +type family Product sh where + Product '[] = 1 + Product (n : ns) = n * Product ns + +shsProduct :: ShS sh -> SNat (Product sh) +shsProduct ZSS = SNat +shsProduct (n :$$ sh) = n `snatMul` shsProduct sh + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS + +withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r +withKnownShS = withDict @(KnownShS sh) + +shsKnownShS :: ShS sh -> Dict KnownShS sh +shsKnownShS ZSS = Dict +shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict + +shsOrthotopeShape :: ShS sh -> Dict O.Shape sh +shsOrthotopeShape ZSS = Dict +shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict + + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where + type Item (IxS sh i) = i + fromList = ixsFromList (knownShS @sh) + toList = Foldable.toList + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where + type Item (ShS sh) = Int + fromList = shsFromList (knownShS @sh) + toList = shsToList |
