diff options
Diffstat (limited to 'src/Data/Array/Nested/Shaped')
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 255 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 425 |
2 files changed, 680 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs new file mode 100644 index 0000000..879e6b5 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -0,0 +1,255 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Shaped.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +#endif +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = + castWith (subst1 (sym (lemMapJustCons Refl))) $ + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs new file mode 100644 index 0000000..5f9ba79 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -0,0 +1,425 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Shape qualified as O +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Functor.Product qualified as Fun +import Data.Kind (Constraint, Type) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Exts (withDict) +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types + + +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. +type role ListS nominal representational +type ListS :: [Nat] -> (Nat -> Type) -> Type +data ListS sh f where + ZS :: ListS '[] f + -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity + (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +infixr 3 ::$ + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListS sh f) +#else +instance (forall n. Show (f n)) => Show (ListS sh f) where + showsPrec _ = listsShow shows +#endif + +instance (forall m. NFData (f m)) => NFData (ListS n f) where + rnf ZS = () + rnf (x ::$ l) = rnf x `seq` rnf l + +data UnconsListSRes f sh1 = + forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) +listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) +listsUncons ZS = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqType ZS ZS = Just Refl +listsEqType (n ::$ sh) (m ::$ sh') + | Just Refl <- testEquality n m + , Just Refl <- listsEqType sh sh' + = Just Refl +listsEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqual ZS ZS = Just Refl +listsEqual (n ::$ sh) (m ::$ sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listsEqual sh sh' + = Just Refl +listsEqual _ _ = Nothing + +listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g +listsFmap _ ZS = ZS +listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs + +listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFold _ ZS = mempty +listsFold f (x ::$ xs) = f x <> listsFold f xs + +listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListS sh' f -> ShowS + go _ ZS = id + go prefix (x ::$ xs) = showString prefix . f x . go "," xs + +listsLength :: ListS sh f -> Int +listsLength = getSum . listsFold (\_ -> Sum 1) + +listsRank :: ListS sh f -> SNat (Rank sh) +listsRank ZS = SNat +listsRank (_ ::$ sh) = snatSucc (listsRank sh) + +listsToList :: ListS sh (Const i) -> [i] +listsToList ZS = [] +listsToList (Const i ::$ is) = i : listsToList is + +listsHead :: ListS (n : sh) f -> f n +listsHead (i ::$ _) = i + +listsTail :: ListS (n : sh) f -> ListS sh f +listsTail (_ ::$ sh) = sh + +listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh +listsInit (_ ::$ ZS) = ZS + +listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh +listsLast (n ::$ ZS) = n + +listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend ZS idx' = idx' +listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' + +listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) +listsZip ZS ZS = ZS +listsZip (i ::$ is) (j ::$ js) = + Fun.Pair i j ::$ listsZip is js + +listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g + -> ListS sh h +listsZipWith _ ZS ZS = ZS +listsZipWith f (i ::$ is) (j ::$ js) = + f i j ::$ listsZipWith f is js + +listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm PNil _ = ZS +listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh +listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm PNil sh = sh +listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh +listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = + case listsIndex (Proxy @is') (Proxy @sh) i sh of + (item, SNat) -> item ::$ listsPermute is sh + +-- TODO: remove this SNat when the KnownNat constaint in ListS is removed +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) +listsIndex _ _ SZ (n ::$ _) = (n, SNat) +listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listsIndex p pT i sh +listsIndex _ _ _ ZS = error "Index into empty shape" + +listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) + +-- * Shaped indices + +-- | An index into a shape-typed array. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (ListS sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS = IxS ZS + +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => i -> IxS sh i -> IxS sh1 i +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) + where i :.$ IxS shl = IxS (Const i ::$ shl) +infixr 3 :.$ + +{-# COMPLETE ZIS, (:.$) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxS sh = IxS sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxS sh i) +#else +instance Show i => Show (IxS sh i) where + showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l +#endif + +instance Functor (IxS sh) where + fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) + +instance Foldable (IxS sh) where + foldMap f (IxS l) = listsFold (f . getConst) l + +instance NFData i => NFData (IxS sh i) + +ixsLength :: IxS sh i -> Int +ixsLength (IxS l) = listsLength l + +ixsRank :: IxS sh i -> SNat (Rank sh) +ixsRank (IxS l) = listsRank l + +ixsZero :: ShS sh -> IIxS sh +ixsZero ZSS = ZIS +ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh + +ixsHead :: IxS (n : sh) i -> i +ixsHead (IxS list) = getConst (listsHead list) + +ixsTail :: IxS (n : sh) i -> IxS sh i +ixsTail (IxS list) = IxS (listsTail list) + +ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i +ixsInit (IxS list) = IxS (listsInit list) + +ixsLast :: IxS (n : sh) i -> i +ixsLast (IxS list) = getConst (listsLast list) + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + +ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i +ixsAppend = coerce (listsAppend @_ @(Const i)) + +ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip ZIS ZIS = ZIS +ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js + +ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +ixsZipWith _ ZIS ZIS = ZIS +ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js + +ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i +ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) + + +-- * Shaped shapes + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +-- +-- Note that because the shape of a shape-typed array is known statically, you +-- can also retrieve the array shape from a 'KnownShS' dictionary. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ListS sh SNat) + deriving (Eq, Ord, Generic) + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS = ShS ZS + +pattern (:$$) + :: forall {sh1}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) + where i :$$ ShS shl = ShS (i ::$ shl) + +infixr 3 :$$ + +{-# COMPLETE ZSS, (:$$) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (ShS sh) +#else +instance Show (ShS sh) where + showsPrec _ (ShS l) = listsShow (shows . fromSNat) l +#endif + +instance NFData (ShS sh) where + rnf (ShS ZS) = () + rnf (ShS (SNat ::$ l)) = rnf (ShS l) + +instance TestEquality ShS where + testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + +-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are +-- equal if and only if values are equal.) +shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') +shsEqual = testEquality + +shsLength :: ShS sh -> Int +shsLength (ShS l) = listsLength l + +shsRank :: ShS sh -> SNat (Rank sh) +shsRank (ShS l) = listsRank l + +shsSize :: ShS sh -> Int +shsSize ZSS = 1 +shsSize (n :$$ sh) = fromSNat' n * shsSize sh + +shsToList :: ShS sh -> [Int] +shsToList ZSS = [] +shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh + +shsHead :: ShS (n : sh) -> SNat n +shsHead (ShS list) = listsHead list + +shsTail :: ShS (n : sh) -> ShS sh +shsTail (ShS list) = ShS (listsTail list) + +shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) +shsInit (ShS list) = ShS (listsInit list) + +shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS list) = listsLast list + +shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') +shsAppend = coerce (listsAppend @_ @SNat) + +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = coerce (listsTakeLenPerm @SNat) + +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = coerce (listsPermute @SNat) + +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) +shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) + +shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) +shsPermutePrefix = coerce (listsPermutePrefix @SNat) + +type family Product sh where + Product '[] = 1 + Product (n : ns) = n * Product ns + +shsProduct :: ShS sh -> SNat (Product sh) +shsProduct ZSS = SNat +shsProduct (n :$$ sh) = n `snatMul` shsProduct sh + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS + +withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r +withKnownShS = withDict @(KnownShS sh) + +shsKnownShS :: ShS sh -> Dict KnownShS sh +shsKnownShS ZSS = Dict +shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict + +shsOrthotopeShape :: ShS sh -> Dict O.Shape sh +shsOrthotopeShape ZSS = Dict +shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict + +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + + +-- | Untyped: length is checked at runtime. +instance KnownShS sh => IsList (ListS sh (Const i)) where + type Item (ListS sh (Const i)) = i + fromList topl = go (knownShS @sh) topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "IsList(ListS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = listsToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where + type Item (IxS sh i) = i + fromList = IxS . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where + type Item (ShS sh) = Int + fromList topl = ShS (go (knownShS @sh) topl) + where + go :: ShS sh' -> [Int] -> ListS sh' SNat + go ZSS [] = ZS + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = sn ::$ go sh is + | otherwise = error $ "IsList(ShS): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "IsList(ShS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = shsToList |