diff options
Diffstat (limited to 'src/Data/Array')
| -rw-r--r-- | src/Data/Array/Mixed.hs | 9 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Internal/Arith.hs | 11 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Lemmas.hs | 86 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Permutation.hs | 17 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Shape.hs | 17 | ||||
| -rw-r--r-- | src/Data/Array/Mixed/Types.hs | 48 | ||||
| -rw-r--r-- | src/Data/Array/Nested.hs | 8 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Convert.hs | 28 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal.hs | 2054 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Lemmas.hs | 59 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 741 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 446 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shape.hs | 467 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 379 | 
14 files changed, 2268 insertions, 2102 deletions
| diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index d5a8b78..4a338a2 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -2,10 +2,11 @@  {-# LANGUAGE DeriveGeneric #-}  {-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE NoStarIsType #-}  {-# LANGUAGE StrictData #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-} @@ -14,14 +15,14 @@  module Data.Array.Mixed where  import Control.DeepSeq (NFData(..)) -import qualified Data.Array.RankedS as S -import qualified Data.Array.Ranked as ORB +import Data.Array.Ranked qualified as ORB +import Data.Array.RankedS qualified as S  import Data.Coerce  import Data.Kind  import Data.Proxy  import Data.Type.Equality  import Data.Type.Ord -import qualified Data.Vector.Storable as VS +import Data.Vector.Storable qualified as VS  import Foreign.Storable (Storable)  import GHC.Generics (Generic)  import GHC.TypeLits diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs index cf6820b..bb3ee4a 100644 --- a/src/Data/Array/Mixed/Internal/Arith.hs +++ b/src/Data/Array/Mixed/Internal/Arith.hs @@ -1,4 +1,5 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TemplateHaskell #-} @@ -9,14 +10,14 @@  module Data.Array.Mixed.Internal.Arith where  import Control.Monad (forM, guard) -import qualified Data.Array.Internal as OI -import qualified Data.Array.Internal.RankedG as RG -import qualified Data.Array.Internal.RankedS as RS +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS  import Data.Bits  import Data.Int  import Data.List (sort) -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM  import Foreign.C.Types  import Foreign.Ptr  import Foreign.Storable (Storable) diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs index d0e8d24..ff2e45c 100644 --- a/src/Data/Array/Mixed/Lemmas.hs +++ b/src/Data/Array/Mixed/Lemmas.hs @@ -1,5 +1,6 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-} @@ -11,10 +12,45 @@ import Data.Proxy  import Data.Type.Equality  import GHC.TypeLits +import Data.Array.Mixed.Permutation  import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types +-- * Reasoning helpers + +subst1 :: forall f a b. a :~: b -> f a :~: f b +subst1 Refl = Refl + +subst2 :: forall f c a b. a :~: b -> f a c :~: f b c +subst2 Refl = Refl + + +-- * Lemmas + +-- ** Nat + +lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True +lemLeqSuccSucc _ _ = unsafeCoerceRefl + +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + + +-- ** Append + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + + +-- ** Rank +  lemRankApp :: forall sh1 sh2.                StaticShX sh1 -> StaticShX sh2             -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 @@ -38,10 +74,58 @@ lemRankAppComm :: StaticShX sh1 -> StaticShX sh2                 -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)  lemRankAppComm _ _ = unsafeCoerceRefl  -- TODO improve this -lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate SZ = Refl +lemRankReplicate (SS (n :: SNat nm1)) +  | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 +  , Refl <- lemRankReplicate n +  = Refl + + +-- ** Various type families + +lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a +                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp sn _ _ = go sn +  where +    go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a +    go SZ = Refl +    go (SS (n :: SNat n'm1)) +      | Refl <- lemReplicateSucc @a @n'm1 +      , Refl <- go n +      = sym (lemReplicateSucc @a @(n'm1 + m)) + +lemDropLenApp :: Rank l1 <= Rank l2 +              => Proxy l1 -> Proxy l2 -> Proxy rest +              -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerceRefl + +lemTakeLenApp :: Rank l1 <= Rank l2 +              => Proxy l1 -> Proxy l2 -> Proxy rest +              -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerceRefl + + +-- ** KnownNat + +lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1) +lemKnownNatSucc = Dict + +lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)  lemKnownNatRank ZSX = Dict  lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict  lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)  lemKnownNatRankSSX ZKX = Dict  lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + + +-- ** Known shapes + +lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) +lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) + +lemKnownShX :: StaticShX sh -> Dict KnownShX sh +lemKnownShX ZKX = Dict +lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict +lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs index 1df0ec7..83a5ee4 100644 --- a/src/Data/Array/Mixed/Permutation.hs +++ b/src/Data/Array/Mixed/Permutation.hs @@ -1,6 +1,7 @@  {-# LANGUAGE ConstraintKinds #-}  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -25,7 +26,7 @@ import Data.Type.Equality  import Data.Type.Ord  import GHC.TypeError  import GHC.TypeLits -import qualified GHC.TypeNats as TN +import GHC.TypeNats qualified as TN  import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types @@ -112,14 +113,14 @@ listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"  listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f  listxPermute PNil _ = ZX  listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = -  listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh) +  listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh -listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f -listxIndex _ _ SZ (n ::% _) rest = n ::% rest -listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) +listxIndex _ _ SZ (n ::% _) = n +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))    | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') -  = listxIndex p pT i sh rest -listxIndex _ _ _ ZX _ = error "Index into empty shape" +  = listxIndex p pT i sh +listxIndex _ _ _ ZX = error "Index into empty shape"  listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f  listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) @@ -136,7 +137,7 @@ ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))  ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)  ssxPermute = coerce (listxPermute @(SMayNat () SNat)) -ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)  ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)  ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs index 363b772..a13a176 100644 --- a/src/Data/Array/Mixed/Shape.hs +++ b/src/Data/Array/Mixed/Shape.hs @@ -2,6 +2,7 @@  {-# LANGUAGE DeriveGeneric #-}  {-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE NoStarIsType #-}  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-} @@ -22,7 +23,9 @@  module Data.Array.Mixed.Shape where  import Control.DeepSeq (NFData(..)) -import qualified Data.Foldable as Foldable +import Data.Bifunctor (first) +import Data.Coerce +import Data.Foldable qualified as Foldable  import Data.Functor.Const  import Data.Kind (Type, Constraint)  import Data.Monoid (Sum(..)) @@ -30,12 +33,10 @@ import Data.Proxy  import Data.Type.Equality  import GHC.Generics (Generic)  import GHC.IsList (IsList) -import qualified GHC.IsList as IsList +import GHC.IsList qualified as IsList  import GHC.TypeLits  import Data.Array.Mixed.Types -import Data.Coerce -import Data.Bifunctor (first)  -- | The length of a type-level list. If the argument is a shape, then the @@ -307,6 +308,10 @@ shxEnum = \sh -> go sh id []      go ZSX f = (f ZIX :)      go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] +shxRank :: ShX sh f -> SNat (Rank sh) +shxRank ZSX = SNat +shxRank (_ :$% sh) | SNat <- shxRank sh = SNat +  -- * Static mixed shapes @@ -377,6 +382,10 @@ ssxFromShape :: IShX sh -> StaticShX sh  ssxFromShape ZSX = ZKX  ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh +ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) +ssxFromSNat SZ = ZKX +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n +  -- | Evidence for the static part of a shape. This pops up only when you are  -- polymorphic in the element type of an array. diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs index d77513f..52201df 100644 --- a/src/Data/Array/Mixed/Types.hs +++ b/src/Data/Array/Mixed/Types.hs @@ -1,5 +1,6 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE NoStarIsType #-}  {-# LANGUAGE PatternSynonyms #-}  {-# LANGUAGE PolyKinds #-} @@ -22,10 +23,10 @@ module Data.Array.Mixed.Types (    -- * Type-level lists    type (++), -  lemAppNil, -  lemAppAssoc,    Replicate,    lemReplicateSucc, +  MapJust, +  Tail,    -- * Unsafe    unsafeCoerceRefl, @@ -34,8 +35,8 @@ module Data.Array.Mixed.Types (  import Data.Type.Equality  import Data.Proxy  import GHC.TypeLits -import qualified GHC.TypeNats as TN -import qualified Unsafe.Coerce +import GHC.TypeNats qualified as TN +import Unsafe.Coerce qualified  -- | Evidence for the constraint @c a@. @@ -76,6 +77,26 @@ snatMul :: SNat n -> SNat m -> SNat (n * m)  snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce +-- | Type-level list append. +type family l1 ++ l2 where +  '[] ++ l2 = l2 +  (x : xs) ++ l2 = x : xs ++ l2 + +type family Replicate n a where +  Replicate 0 a = '[] +  Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl + +type family MapJust l where +  MapJust '[] = '[] +  MapJust (x : xs) = Just x : MapJust xs + +type family Tail l where +  Tail (_ : xs) = xs + +  -- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to  -- only typecheck for actual type equalities. One cannot, e.g. accidentally  -- write this: @@ -89,22 +110,3 @@ snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsa  -- but would have resulted in interesting memory errors at runtime.  unsafeCoerceRefl :: a :~: b  unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl - - --- | Type-level list append. -type family l1 ++ l2 where -  '[] ++ l2 = l2 -  (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerceRefl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerceRefl - -type family Replicate n a where -  Replicate 0 a = '[] -  Replicate n a = a : Replicate (n - 1) a - -lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a -lemReplicateSucc = unsafeCoerceRefl diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 1a4e094..c982b4d 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -62,8 +62,6 @@ module Data.Array.Nested (    Perm(..),    IsPermutation,    KnownNatList(..), -  listSToList, -  shSToList,    NumElt, FloatElt,  ) where @@ -74,6 +72,10 @@ import Data.Array.Mixed.Internal.Arith  import Data.Array.Mixed.Permutation  import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types -import Data.Array.Nested.Internal +import Data.Array.Nested.Convert +import Data.Array.Nested.Mixed +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shape +import Data.Array.Nested.Shaped  import Foreign.Storable  import GHC.TypeLits diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..cb22c32 --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,28 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +module Data.Array.Nested.Convert where + +import Data.Type.Equality + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Shape +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Ranked +import Data.Array.Nested.Shape +import Data.Array.Nested.Shaped + + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) +  | Refl <- lemRankMapJust (sshape sarr) +  = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh +  | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh)) +  , Refl <- lemRankMapJust targetsh +  = mcastToShaped arr targetsh diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs deleted file mode 100644 index 0870789..0000000 --- a/src/Data/Array/Nested/Internal.hs +++ /dev/null @@ -1,2054 +0,0 @@ -{-# LANGUAGE ConstraintKinds #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} - -module Data.Array.Nested.Internal where - -import Prelude hiding (mappend) - -import Control.DeepSeq (NFData) -import Control.Monad (forM_, when) -import Control.Monad.ST -import qualified Data.Array.RankedS as S -import Data.Bifunctor (first) -import Data.Coerce (coerce, Coercible) -import Data.Foldable as Foldable (toList) -import Data.Functor.Const -import Data.Int -import Data.Kind -import Data.List.NonEmpty (NonEmpty(..)) -import Data.Monoid (Sum(..)) -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.C.Types (CInt(..)) -import Foreign.Storable (Storable) -import qualified GHC.Float (log1p, expm1, log1pexp, log1mexp) -import GHC.Generics (Generic) -import GHC.IsList (IsList) -import qualified GHC.IsList as IsList -import GHC.TypeLits -import qualified GHC.TypeNats as TypeNats -import Unsafe.Coerce - -import Data.Array.Mixed -import qualified Data.Array.Mixed as X -import Data.Array.Mixed.Lemmas -import Data.Array.Mixed.Permutation -import Data.Array.Mixed.Shape -import Data.Array.Mixed.Internal.Arith -import Data.Array.Mixed.Types - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- ---   arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) ---   rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? ---   -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - -type family MapJust l where -  MapJust '[] = '[] -  MapJust (x : xs) = Just x : MapJust xs - - --- Stupid things that the type checker should be able to figure out in-line, but can't - -subst1 :: forall f a b. a :~: b -> f a :~: f b -subst1 Refl = Refl - -subst2 :: forall f c a b. a :~: b -> f a c :~: f b c -subst2 Refl = Refl - -lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l -lemAppLeft _ Refl = Refl - -knownNatSucc :: KnownNat n => Dict KnownNat (n + 1) -knownNatSucc = Dict - - -lemKnownShX :: StaticShX sh -> Dict KnownShX sh -lemKnownShX ZKX = Dict -lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict -lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict - -ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) -ssxFromSNat SZ = ZKX -ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n - -lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) -lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) - -lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate SZ = Refl -lemRankReplicate (SS (n :: SNat nm1)) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 -  , Refl <- lemRankReplicate n -  = Refl - -lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh -lemRankMapJust ZSS = Refl -lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl - -lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a -                    -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp sn _ _ = go sn -  where -    go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a -    go SZ = Refl -    go (SS (n :: SNat n'm1)) -      | Refl <- lemReplicateSucc @a @n'm1 -      , Refl <- go n -      = sym (lemReplicateSucc @a @(n'm1 + m)) - -lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True -lemLeqPlus _ _ _ = Refl - -lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True -lemLeqSuccSucc _ _ = unsafeCoerce Refl - -lemDropLenApp :: Rank l1 <= Rank l2 -              => Proxy l1 -> Proxy l2 -> Proxy rest -              -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) -lemDropLenApp _ _ _ = unsafeCoerce Refl - -lemTakeLenApp :: Rank l1 <= Rank l2 -              => Proxy l1 -> Proxy l2 -> Proxy rest -              -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) -lemTakeLenApp _ _ _ = unsafeCoerce Refl - -srankSh :: ShX sh f -> SNat (Rank sh) -srankSh ZSX = SNat -srankSh (_ :$% sh) | SNat <- srankSh sh = SNat - - --- === NEW INDEX TYPES === -- - -type role ListR nominal representational -type ListR :: Nat -> Type -> Type -data ListR n i where -  ZR :: ListR 0 i -  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -deriving instance Foldable (ListR n) -infixr 3 ::: - -instance Show i => Show (ListR n i) where -  showsPrec _ = showListR shows - -data UnconsListRRes i n1 = -  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i -unconsListR :: ListR n1 i -> Maybe (UnconsListRRes i n1) -unconsListR (i ::: sh') = Just (UnconsListRRes sh' i) -unconsListR ZR = Nothing - -showListR :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS -showListR f l = showString "[" . go "" l . showString "]" -  where -    go :: String -> ListR sh' i -> ShowS -    go _ ZR = id -    go prefix (x ::: xs) = showString prefix . f x . go "," xs - -listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i -listrAppend ZR sh = sh -listrAppend (x ::: xs) sh = x ::: listrAppend xs sh - -listrFromList :: [i] -> (forall n. ListR n i -> r) -> r -listrFromList [] k = k ZR -listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) - -listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i -listrIndex SZ (x ::: _) = x -listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs -listrIndex _ ZR = error "k + 1 <= 0" - - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: Nat -> Type -> Type -newtype IxR n i = IxR (ListR n i) -  deriving (Eq, Ord) -  deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ 0 => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) -  :: forall {n1} {i}. -     forall n. (n + 1 ~ n1) -  => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- IxR (unconsListR -> Just (UnconsListRRes (IxR -> sh) i)) -  where i :.: IxR sh = IxR (i ::: sh) -infixr 3 :.: - -{-# COMPLETE ZIR, (:.:) #-} - -type IIxR n = IxR n Int - -instance Show i => Show (IxR n i) where -  showsPrec _ (IxR l) = showListR shows l - - -type role ShR nominal representational -type ShR :: Nat -> Type -> Type -newtype ShR n i = ShR (ListR n i) -  deriving (Eq, Ord) -  deriving newtype (Functor, Foldable) - -pattern ZSR :: forall n i. () => n ~ 0 => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) -  :: forall {n1} {i}. -     forall n. (n + 1 ~ n1) -  => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- ShR (unconsListR -> Just (UnconsListRRes (ShR -> sh) i)) -  where i :$: (ShR sh) = ShR (i ::: sh) -infixr 3 :$: - -{-# COMPLETE ZSR, (:$:) #-} - -type IShR n = ShR n Int - -instance Show i => Show (ShR n i) where -  showsPrec _ (ShR l) = showListR shows l - - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ListR n i) where -  type Item (ListR n i) = i -  fromList = go (SNat @n) -    where -      go :: SNat n' -> [i] -> ListR n' i -      go SZ [] = ZR -      go (SS n) (i : is) = i ::: go n is -      go _ _ = error "IsList(ListR): Mismatched list length" -  toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (IxR n i) where -  type Item (IxR n i) = i -  fromList = IxR . IsList.fromList -  toList = Foldable.toList - --- | Untyped: length is checked at runtime. -instance KnownNat n => IsList (ShR n i) where -  type Item (ShR n i) = i -  fromList = ShR . IsList.fromList -  toList = Foldable.toList - - -type role ListS nominal representational -type ListS :: [Nat] -> (Nat -> Type) -> Type -data ListS sh f where -  ZS :: ListS '[] f -  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f -deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) -deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) -infixr 3 ::$ - -instance (forall n. Show (f n)) => Show (ListS sh f) where -  showsPrec _ = showListS shows - -data UnconsListSRes f sh1 = -  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) -unconsListS :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) -unconsListS (x ::$ sh') = Just (UnconsListSRes sh' x) -unconsListS ZS = Nothing - -fmapListS :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g -fmapListS _ ZS = ZS -fmapListS f (x ::$ xs) = f x ::$ fmapListS f xs - -foldListS :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m -foldListS _ ZS = mempty -foldListS f (x ::$ xs) = f x <> foldListS f xs - -showListS :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS -showListS f l = showString "[" . go "" l . showString "]" -  where -    go :: String -> ListS sh' f -> ShowS -    go _ ZS = id -    go prefix (x ::$ xs) = showString prefix . f x . go "," xs - -listSToList :: ListS sh (Const i) -> [i] -listSToList ZS = [] -listSToList (Const i ::$ is) = i : listSToList is - - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). Note that because the shape of a --- shape-typed array is known statically, you can also retrieve the array shape --- from a 'KnownShape' dictionary. -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh (Const i)) -  deriving (Eq, Ord) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) -  :: forall {sh1} {i}. -     forall n sh. (KnownNat n, n : sh ~ sh1) -  => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- IxS (unconsListS -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) -  where i :.$ IxS shl = IxS (Const i ::$ shl) -infixr 3 :.$ - -{-# COMPLETE ZIS, (:.$) #-} - -type IIxS sh = IxS sh Int - -instance Show i => Show (IxS sh i) where -  showsPrec _ (IxS l) = showListS (\(Const i) -> shows i) l - -instance Functor (IxS sh) where -  fmap f (IxS l) = IxS (fmapListS (Const . f . getConst) l) - -instance Foldable (IxS sh) where -  foldMap f (IxS l) = foldListS (f . getConst) l - - --- | The shape of a shape-typed array given as a list of 'SNat' values. -type role ShS nominal -type ShS :: [Nat] -> Type -newtype ShS sh = ShS (ListS sh SNat) -  deriving (Eq, Ord) - -pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh -pattern ZSS = ShS ZS - -pattern (:$$) -  :: forall {sh1}. -     forall n sh. (KnownNat n, n : sh ~ sh1) -  => SNat n -> ShS sh -> ShS sh1 -pattern i :$$ shl <- ShS (unconsListS -> Just (UnconsListSRes (ShS -> shl) i)) -  where i :$$ ShS shl = ShS (i ::$ shl) - -infixr 3 :$$ - -{-# COMPLETE ZSS, (:$$) #-} - -instance Show (ShS sh) where -  showsPrec _ (ShS l) = showListS (shows . fromSNat) l - -lengthShS :: ShS sh -> Int -lengthShS (ShS l) = getSum (foldListS (\_ -> Sum 1) l) - -shSToList :: ShS sh -> [Int] -shSToList ZSS = [] -shSToList (sn :$$ sh) = fromSNat' sn : shSToList sh - - --- | Untyped: length is checked at runtime. -instance KnownShS sh => IsList (ListS sh (Const i)) where -  type Item (ListS sh (Const i)) = i -  fromList topl = go (knownShS @sh) topl -    where -      go :: ShS sh' -> [i] -> ListS sh' (Const i) -      go ZSS [] = ZS -      go (_ :$$ sh) (i : is) = Const i ::$ go sh is -      go _ _ = error $ "IsList(ListS): Mismatched list length (type says " -                         ++ show (lengthShS (knownShS @sh)) ++ ", list has length " -                         ++ show (length topl) ++ ")" -  toList = listSToList - --- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. -instance KnownShS sh => IsList (IxS sh i) where -  type Item (IxS sh i) = i -  fromList = IxS . IsList.fromList -  toList = Foldable.toList - --- | Untyped: length and values are checked at runtime. -instance KnownShS sh => IsList (ShS sh) where -  type Item (ShS sh) = Int -  fromList topl = ShS (go (knownShS @sh) topl) -    where -      go :: ShS sh' -> [Int] -> ListS sh' SNat -      go ZSS [] = ZS -      go (sn :$$ sh) (i : is) -        | i == fromSNat' sn = sn ::$ go sh is -        | otherwise = error $ "IsList(ShS): Value does not match typing (type says " -                                ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" -      go _ _ = error $ "IsList(ShS): Mismatched list length (type says " -                         ++ show (lengthShS (knownShS @sh)) ++ ", list has length " -                         ++ show (length topl) ++ ")" -  toList = shSToList - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class Storable a => PrimElt a where -  fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a -  toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - -  default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a -  fromPrimitive = coerce - -  default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) -  toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Int -instance PrimElt Int64 -instance PrimElt Int32 -instance PrimElt CInt -instance PrimElt Float -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) -  deriving (Show, Eq, Generic) - --- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. -deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a)) - -instance NFData a => NFData (Mixed sh (Primitive a)) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic) -newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic) -newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic) -newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic)  -- no content, orthotope optimises this (via Vector) --- etc. - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int) -deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64) -deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32) -deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt) -deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float) -deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double) -deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ()) - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) -deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) -instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b)) --- etc., larger tuples (perhaps use generics to allow arbitrary product types) - -data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) -deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) -instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) -newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) -newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ())  -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - --- | Tree giving the shape of every array component. -type family ShapeTree a where -  ShapeTree (a, b) = (ShapeTree a, ShapeTree b) -  ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) -  ShapeTree (Ranked n a) = (IShR n, ShapeTree a) -  ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - -  -- to avoid having to list all of the primitive types: -  ShapeTree _ = () - - --- | Allowable element types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where -  -- ====== PUBLIC METHODS ====== -- - -  mshape :: Mixed sh a -> IShX sh -  mindex :: Mixed sh a -> IIxX sh -> a -  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a -  mscalar :: a -> Mixed '[] a - -  -- | All arrays in the list, even subarrays inside @a@, must have the same -  -- shape; if they do not, a runtime error will be thrown. See the -  -- documentation of 'mgenerate' for more information about this restriction. -  -- Furthermore, the length of the list must correspond with @n@: if @n@ is -  -- @Just m@ and @m@ does not equal the length of the list, a runtime error is -  -- thrown. -  -- -  -- Consider also 'mfromListPrim', which can avoid intermediate arrays. -  mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a - -  mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] - -  -- | Note: this library makes no particular guarantees about the shapes of -  -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the -  -- full 'XArray' and as such you can distinguish different empty arrays by -  -- the "shapes" of their elements. This information is meaningless, so you -  -- should not use it. -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -        -> Mixed sh1 a -> Mixed sh2 a - -  -- | See the documentation for 'mlift'. -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -         -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - -  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 -        => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a - -  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) -             => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a - -  -- ====== PRIVATE METHODS ====== -- - -  mshapeTree :: a -> ShapeTree a - -  mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - -  mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - -  mshowShapeTree :: Proxy a -> ShapeTree a -> String - -  -- | Given the shape of this array, an index and a value, write the value at -  -- that index in the vectors. -  mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - -  -- | Given the shape of this array, an index and a value, write the value at -  -- that index in the vectors. -  mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - -  -- | Given the shape of this array, finalise the vectors into 'XArray's. -  mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- | Element types for which we have evidence of the (static part of the) shape --- in a type class constraint. Compare the instance contexts of the instances --- of this class with those of 'Elt': some instances have an additional --- "known-shape" constraint. --- --- This class is (currently) only required for 'mgenerate' / 'rgenerate' / --- 'sgenerate'. -class Elt a => KnownElt a where -  -- | Create an empty array. The given shape must have size zero; this may or may not be checked. -  memptyArray :: IShX sh -> Mixed sh a - -  -- | Create uninitialised vectors for this array type, given the shape of -  -- this vector and an example for the contents. -  mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - -  mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where -  mshape (M_Primitive sh _) = sh -  mindex (M_Primitive _ a) i = Primitive (X.index a i) -  mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) -  mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) -  mfromListOuter l@(arr1 :| _) = -    let sh = SUnknown (length l) :$% mshape arr1 -    in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) -  mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) -        -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -  mlift ssh2 f (M_Primitive _ a) -    | Refl <- lemAppNil @sh1 -    , Refl <- lemAppNil @sh2 -    , let result = f ZKX a -    = M_Primitive (X.shape ssh2 result) result - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) -         -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) -  mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) -    | Refl <- lemAppNil @sh1 -    , Refl <- lemAppNil @sh2 -    , Refl <- lemAppNil @sh3 -    , let result = f ZKX a b -    = M_Primitive (X.shape ssh3 result) result - -  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 -        => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) -  mcast ssh1 sh2 _ (M_Primitive sh1' arr) = -    let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' -    in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) - -  mtranspose perm (M_Primitive sh arr) = -    M_Primitive (shxPermutePrefix perm sh) -                (X.transpose (ssxFromShape sh) perm arr) - -  mshapeTree _ = () -  mshapeTreeEq _ () () = True -  mshapeTreeEmpty _ () = False -  mshowShapeTree _ () = "()" -  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x - -  -- TODO: this use of toVector is suboptimal -  mvecsWritePartial -    :: forall sh' sh s. -       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () -  mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do -    let arrsh = X.shape (ssxFromShape sh') arr -        offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) -    VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) - -  mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance Elt Int -deriving via Primitive Int64 instance Elt Int64 -deriving via Primitive Int32 instance Elt Int32 -deriving via Primitive CInt instance Elt CInt -deriving via Primitive Double instance Elt Double -deriving via Primitive Float instance Elt Float -deriving via Primitive () instance Elt () - -instance Storable a => KnownElt (Primitive a) where -  memptyArray sh = M_Primitive sh (X.empty sh) -  mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) -  mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance KnownElt Int -deriving via Primitive Int64 instance KnownElt Int64 -deriving via Primitive Int32 instance KnownElt Int32 -deriving via Primitive CInt instance KnownElt CInt -deriving via Primitive Double instance KnownElt Double -deriving via Primitive Float instance KnownElt Float -deriving via Primitive () instance KnownElt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where -  mshape (M_Tup2 a _) = mshape a -  mindex (M_Tup2 a b) i = (mindex a i, mindex b i) -  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) -  mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) -  mfromListOuter l = -    M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) -           (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) -  mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) -  mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) -  mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) - -  mcast ssh1 sh2 psh' (M_Tup2 a b) = -    M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) - -  mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) - -  mshapeTree (x, y) = (mshapeTree x, mshapeTree y) -  mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' -  mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 -  mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" -  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do -    mvecsWrite sh i x a -    mvecsWrite sh i y b -  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do -    mvecsWritePartial sh i x a -    mvecsWritePartial sh i y b -  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - -instance (KnownElt a, KnownElt b) => KnownElt (a, b) where -  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) -  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y -  mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - --- Arrays of arrays are just arrays, but with more dimensions. -instance Elt a => Elt (Mixed sh' a) where -  -- TODO: this is quadratic in the nesting depth because it repeatedly -  -- truncates the shape vector to one a little shorter. Fix with a -  -- moverlongShape method, a prefix of which is mshape. -  mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh -  mshape (M_Nest sh arr) -    = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) - -  mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a -  mindex (M_Nest _ arr) i = mindexPartial arr i - -  mindexPartial :: forall sh1 sh2. -                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -  mindexPartial (M_Nest sh arr) i -    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - -  mscalar = M_Nest ZSX - -  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) -  mfromListOuter l@(arr :| _) = -    M_Nest (SUnknown (length l) :$% mshape arr) -           (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) - -  mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) -        -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -  mlift ssh2 f (M_Nest sh1 arr) = -    let result = mlift (ssxAppend ssh2 ssh') f' arr -        (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) -    in M_Nest sh2 result -    where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) - -      f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -      f' sshT -        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) -        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) -        = f (ssxAppend ssh' sshT) - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) -         -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) -  mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = -    let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 -        (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) -    in M_Nest sh3 result -    where -      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) - -      f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b -      f' sshT -        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) -        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) -        , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) -        = f (ssxAppend ssh' sshT) - -  mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 -        => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) -  mcast ssh1 sh2 _ (M_Nest sh1T arr) -    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') -    , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') -    = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T -      in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) - -  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) -             => Perm is -> Mixed sh (Mixed sh' a) -             -> Mixed (PermutePrefix is sh) (Mixed sh' a) -  mtranspose perm (M_Nest sh arr) -    | let sh' = shxDropSh @sh @sh' (mshape arr) sh -    , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') -    , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) -    , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') -    , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') -    , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') -    = M_Nest (shxPermutePrefix perm sh) -             (mtranspose perm arr) - -  mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) -  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) - -  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - -  mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - -  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - -  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs - -  mvecsWritePartial :: forall sh1 sh2 s. -                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -                    -> ST s () -  mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) -    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') -    = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs - -  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs - -instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where -  memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) - -  mvecsUnsafeNew sh example -    | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) -    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) -    where -      sh' = mshape example - -  mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) - - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- >         mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- >           ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case shxEnum sh of -  [] -> memptyArray sh -  firstidx : restidxs -> -    let firstelem = f (ixxZero' sh) -        shapetree = mshapeTree firstelem -    in if mshapeTreeEmpty (Proxy @a) shapetree -         then memptyArray sh -         else runST $ do -                vecs <- mvecsUnsafeNew sh firstelem -                mvecsWrite sh firstidx firstelem vecs -                -- TODO: This is likely fine if @a@ is big, but if @a@ is a -                -- scalar this array copying inefficient. Should improve this. -                forM_ restidxs $ \idx -> do -                  let val = f idx -                  when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ -                    error "Data.Array.Nested mgenerate: generated values do not have equal shapes" -                  mvecsWrite sh idx val vecs -                mvecsFreeze sh vecs - -msumOuter1P :: forall sh n a. (Storable a, NumElt a) -            => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) -msumOuter1P (M_Primitive (n :$% sh) arr) = -  let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX -  in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) - -msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) -           => Mixed (n : sh) a -> Mixed sh a -msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive - -mappend :: forall n m sh a. Elt a -        => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a -mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 -  where -    sn :$% sh = mshape arr1 -    sm :$% _ = mshape arr2 -    ssh = ssxFromShape sh -    snm :: SMayNat () SNat (AddMaybe n m) -    snm = case (sn, sm) of -            (SUnknown{}, _) -> SUnknown () -            (SKnown{}, SUnknown{}) -> SUnknown () -            (SKnown n, SKnown m) -> SKnown (snatPlus n m) - -    f :: forall sh' b. Storable b -      => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b -    f ssh' = X.append (ssxAppend ssh ssh') - -mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) - -mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive _ v) = X.toVector v - -mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (toPrimitive arr) - -mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a -mfromList1 = mfromListOuter . fmap mscalar  -- TODO: optimise? - -mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromList1Prim l = -  let ssh = SUnknown () :!% ZKX -      xarr = X.fromList1 ssh l -  in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mtoList1 :: Elt a => Mixed '[n] a -> [a] -mtoList1 = map munScalar . mtoListOuter - -mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a -mfromListPrim l = -  let ssh = SUnknown () :!% ZKX -      xarr = X.fromList1 ssh l -  in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a -mfromListPrimLinear sh l = -  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -         => StaticShX sh -> IShX sh2 -         -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) -         -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) -mrerankP ssh sh2 f (M_Primitive sh arr) = -  let sh1 = shxDropSSX sh ssh -  in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) -                 (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) -                           (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) -                           arr) - --- | See the caveats at @X.rerank@. -mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) -        => StaticShX sh -> IShX sh2 -        -> (Mixed sh1 a -> Mixed sh2 b) -        -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b -mrerank ssh sh2 f (toPrimitive -> arr) = -  fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr - -mreplicate :: forall sh sh' a. Elt a -           => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a -mreplicate sh arr = -  let ssh' = ssxFromShape (mshape arr) -  in mlift (ssxAppend (ssxFromShape sh) ssh') -           (\(sshT :: StaticShX shT) -> -              case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of -                Refl -> X.replicate sh (ssxAppend ssh' sshT)) -           arr - -mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) -mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) - -mreplicateScal :: forall sh a. PrimElt a -               => IShX sh -> a -> Mixed sh a -mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) - -mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a -mslice i n arr = -  let _ :$% sh = mshape arr -  in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr - -msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a -msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr - -mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr - -mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a -mreshape sh' arr = -  mlift (ssxFromShape sh') -        (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') -        arr - -miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a -miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) - -masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) -masXArrayPrimP (M_Primitive sh arr) = (sh, arr) - -masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) -masXArrayPrim = masXArrayPrimP . toPrimitive - -mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) -mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr - -mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a -mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP - -mliftPrim :: PrimElt a -          => (a -> a) -          -> Mixed sh a -> Mixed sh a -mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) - -mliftPrim2 :: PrimElt a -           => (a -> a -> a) -           -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = -  fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) - -mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a -mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr)) - -mliftNumElt2 :: PrimElt a -             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -             -> Mixed sh a -> Mixed sh a -> Mixed sh a -mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) -  | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2)) -  | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 - -instance (NumElt a, PrimElt a) => Num (Mixed sh a) where -  (+) = mliftNumElt2 numEltAdd -  (-) = mliftNumElt2 numEltSub -  (*) = mliftNumElt2 numEltMul -  negate = mliftNumElt1 numEltNeg -  abs = mliftNumElt1 numEltAbs -  signum = mliftNumElt1 numEltSignum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where -  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" -  recip = mliftNumElt1 floatEltRecip -  (/) = mliftNumElt2 floatEltDiv - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where -  pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" -  exp = mliftNumElt1 floatEltExp -  log = mliftNumElt1 floatEltLog -  sqrt = mliftNumElt1 floatEltSqrt - -  (**) = mliftNumElt2 floatEltPow -  logBase = mliftNumElt2 floatEltLogbase - -  sin = mliftNumElt1 floatEltSin -  cos = mliftNumElt1 floatEltCos -  tan = mliftNumElt1 floatEltTan -  asin = mliftNumElt1 floatEltAsin -  acos = mliftNumElt1 floatEltAcos -  atan = mliftNumElt1 floatEltAtan -  sinh = mliftNumElt1 floatEltSinh -  cosh = mliftNumElt1 floatEltCosh -  tanh = mliftNumElt1 floatEltTanh -  asinh = mliftNumElt1 floatEltAsinh -  acosh = mliftNumElt1 floatEltAcosh -  atanh = mliftNumElt1 floatEltAtanh -  log1p = mliftNumElt1 floatEltLog1p -  expm1 = mliftNumElt1 floatEltExpm1 -  log1pexp = mliftNumElt1 floatEltLog1pexp -  log1mexp = mliftNumElt1 floatEltLog1mexp - -mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a -mtoRanked arr -  | Refl <- lemAppNil @sh -  , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) -  , Refl <- lemRankReplicate (srankSh (mshape arr)) -  = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) -  where -    convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) -    convSh ZSX = ZSX -    convSh (smn :$% (sh :: IShX sh'T)) -      | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) -      = SUnknown (fromSMayNat' smn) :$% convSh sh - -mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') -              => Mixed sh a -> ShS sh' -> Shaped sh' a -mcastToShaped arr targetsh -  | Refl <- lemAppNil @sh -  , Refl <- lemAppNil @(MapJust sh') -  , Refl <- lemRankMapJust targetsh -  = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'Nat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) -deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) -deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a) -deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a) - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) -deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) -deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a) -deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n   a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh'        ) a)) -deriving instance Show (Mixed sh (Mixed (MapJust sh'        ) a)) => Show (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Ranked n   a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh'        ) a)) - - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance Elt a => Elt (Ranked n a) where -  mshape (M_Ranked arr) = mshape arr -  mindex (M_Ranked arr) i = Ranked (mindex arr i) - -  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) -  mindexPartial (M_Ranked arr) i = -    coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ -        mindexPartial arr i - -  mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) - -  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) -  mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) - -  mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] -  mtoListOuter (M_Ranked arr) = -    coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -  mlift ssh2 f (M_Ranked arr) = -    coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ -      mlift ssh2 f arr - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -         -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) -  mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = -    coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ -        mlift2 ssh3 f arr1 arr2 - -  mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) - -  mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) - -  mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) - -  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - -  mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t - -  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - -  mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () -  mvecsWrite sh idx (Ranked arr) vecs = -    mvecsWrite sh idx arr -      (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) -         vecs) - -  mvecsWritePartial :: forall sh sh' s. -                       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) -                    -> MixedVecs s (sh ++ sh') (Ranked n a) -                    -> ST s () -  mvecsWritePartial sh idx arr vecs = -    mvecsWritePartial sh idx -      (coerce @(Mixed sh' (Ranked n a)) -              @(Mixed sh' (Mixed (Replicate n Nothing) a)) -         arr) -      (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) -              @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) -         vecs) - -  mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) -  mvecsFreeze sh vecs = -    coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) -           @(Mixed sh (Ranked n a)) -      <$> mvecsFreeze sh -            (coerce @(MixedVecs s sh (Ranked n a)) -                    @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) -                    vecs) - -instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where -  memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) -  memptyArray i -    | Dict <- lemKnownReplicate (SNat @n) -    = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ -        memptyArray i - -  mvecsUnsafeNew idx (Ranked arr) -    | Dict <- lemKnownReplicate (SNat @n) -    = MV_Ranked <$> mvecsUnsafeNew idx arr - -  mvecsNewEmpty _ -    | Dict <- lemKnownReplicate (SNat @n) -    = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - --- sshapeKnown :: ShS sh -> Dict KnownShape sh --- sshapeKnown ZSS = Dict --- sshapeKnown (SNat :$$ sh) | Dict <- sshapeKnown sh = Dict - -lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2 -                  -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemCommMapJustApp ZSS _ = Refl -lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl - -instance Elt a => Elt (Shaped sh a) where -  mshape (M_Shaped arr) = mshape arr -  mindex (M_Shaped arr) i = Shaped (mindex arr i) - -  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -  mindexPartial (M_Shaped arr) i = -    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ -      mindexPartial arr i - -  mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) - -  mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) -  mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) - -  mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] -  mtoListOuter (M_Shaped arr) -    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) - -  mlift :: forall sh1 sh2. -           StaticShX sh2 -        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -  mlift ssh2 f (M_Shaped arr) = -    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ -      mlift ssh2 f arr - -  mlift2 :: forall sh1 sh2 sh3. -            StaticShX sh3 -         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) -         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) -  mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = -    coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ -      mlift2 ssh3 f arr1 arr2 - -  mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) - -  mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) - -  mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) - -  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - -  mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t - -  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - -  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () -  mvecsWrite sh idx (Shaped arr) vecs = -    mvecsWrite sh idx arr -      (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) -         vecs) - -  mvecsWritePartial :: forall sh1 sh2 s. -                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -                    -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) -                    -> ST s () -  mvecsWritePartial sh idx arr vecs = -    mvecsWritePartial sh idx -      (coerce @(Mixed sh2 (Shaped sh a)) -              @(Mixed sh2 (Mixed (MapJust sh) a)) -         arr) -      (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) -              @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) -         vecs) - -  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) -  mvecsFreeze sh vecs = -    coerce @(Mixed sh' (Mixed (MapJust sh) a)) -           @(Mixed sh' (Shaped sh a)) -      <$> mvecsFreeze sh -            (coerce @(MixedVecs s sh' (Shaped sh a)) -                    @(MixedVecs s sh' (Mixed (MapJust sh) a)) -                    vecs) - --- | Evidence for the static part of a shape. This pops up only when you are --- polymorphic in the element type of an array. -type KnownShS :: [Nat] -> Constraint -class KnownShS sh where knownShS :: ShS sh -instance KnownShS '[] where knownShS = ZSS -instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS - -lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) -lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) -  where -    go :: ShS sh' -> StaticShX (MapJust sh') -    go ZSS = ZKX -    go (n :$$ sh) = SKnown n :!% go sh - -instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where -  memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) -  memptyArray i -    | Dict <- lemKnownMapJust (Proxy @sh) -    = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ -        memptyArray i - -  mvecsUnsafeNew idx (Shaped arr) -    | Dict <- lemKnownMapJust (Proxy @sh) -    = MV_Shaped <$> mvecsUnsafeNew idx arr - -  mvecsNewEmpty _ -    | Dict <- lemKnownMapJust (Proxy @sh) -    = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - --- ====== API OF RANKED ARRAYS ====== -- - -arithPromoteRanked :: forall n a. PrimElt a -                   => (forall sh. Mixed sh a -> Mixed sh a) -                   -> Ranked n a -> Ranked n a -arithPromoteRanked = coerce - -arithPromoteRanked2 :: forall n a. PrimElt a -                    => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a) -                    -> Ranked n a -> Ranked n a -> Ranked n a -arithPromoteRanked2 = coerce - -instance (NumElt a, PrimElt a) => Num (Ranked n a) where -  (+) = arithPromoteRanked2 (+) -  (-) = arithPromoteRanked2 (-) -  (*) = arithPromoteRanked2 (*) -  negate = arithPromoteRanked negate -  abs = arithPromoteRanked abs -  signum = arithPromoteRanked signum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where -  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" -  recip = arithPromoteRanked recip -  (/) = arithPromoteRanked2 (/) - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where -  pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" -  exp = arithPromoteRanked exp -  log = arithPromoteRanked log -  sqrt = arithPromoteRanked sqrt -  (**) = arithPromoteRanked2 (**) -  logBase = arithPromoteRanked2 logBase -  sin = arithPromoteRanked sin -  cos = arithPromoteRanked cos -  tan = arithPromoteRanked tan -  asin = arithPromoteRanked asin -  acos = arithPromoteRanked acos -  atan = arithPromoteRanked atan -  sinh = arithPromoteRanked sinh -  cosh = arithPromoteRanked cosh -  tanh = arithPromoteRanked tanh -  asinh = arithPromoteRanked asinh -  acosh = arithPromoteRanked acosh -  atanh = arithPromoteRanked atanh -  log1p = arithPromoteRanked GHC.Float.log1p -  expm1 = arithPromoteRanked GHC.Float.expm1 -  log1pexp = arithPromoteRanked GHC.Float.log1pexp -  log1mexp = arithPromoteRanked GHC.Float.log1mexp - -zeroIxR :: SNat n -> IIxR n -zeroIxR SZ = ZIR -zeroIxR (SS n) = 0 :.: zeroIxR n - -ixCvtXR :: IIxX sh -> IIxR (Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.% idx) = n :.: ixCvtXR idx - -shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n -shCvtXR' ZSX = -  castWith (subst2 (unsafeCoerce Refl :: 0 :~: n)) -    ZSR -shCvtXR' (n :$% (idx :: IShX sh)) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = -  castWith (subst2 (lem1 @sh Refl)) -    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) -  where -    lem1 :: forall sh' n' k. -            k : sh' :~: Replicate n' Nothing -         -> Rank sh' + 1 :~: n' -    lem1 Refl = unsafeCoerce Refl - -    lem2 :: k : sh :~: Replicate n Nothing -         -> sh :~: Replicate (Rank sh) Nothing -    lem2 Refl = unsafeCoerce Refl - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: (idx :: IxR m Int)) = -  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) -    (n :.% ixCvtRX idx) - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: (idx :: ShR m Int)) = -  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) -    (SUnknown n :$% shCvtRX idx) - -shapeSizeR :: IShR n -> Int -shapeSizeR ZSR = 1 -shapeSizeR (n :$: sh) = n * shapeSizeR sh - - -rshape :: forall n a. Elt a => Ranked n a -> IShR n -rshape (Ranked arr) = shCvtXR' (mshape arr) - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -snatFromListR :: ListR n i -> SNat n -snatFromListR ZR = SNat -snatFromListR (_ ::: (l :: ListR n i)) | SNat <- snatFromListR l, Dict <- knownNatSucc @n = SNat - -snatFromIxR :: IxR n i -> SNat n -snatFromIxR (IxR sh) = snatFromListR sh - -snatFromShR :: ShR n i -> SNat n -snatFromShR (ShR sh) = snatFromListR sh - -rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = -  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) -            (castWith (subst2 (lemReplicatePlusApp (snatFromIxR idx) (Proxy @m) (Proxy @Nothing))) arr) -            (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f -  | sn@SNat <- snatFromShR sh -  , Dict <- lemKnownReplicate sn -  , Refl <- lemRankReplicate sn -  = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. Elt a -      => SNat n2 -      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -      -> Ranked n1 a -> Ranked n2 a -rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) - --- | See the documentation of 'mlift2'. -rlift2 :: forall n1 n2 n3 a. Elt a -       => SNat n3 -       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) -       -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a -rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) - -rsumOuter1P :: forall n a. -               (Storable a, NumElt a) -            => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = Ranked (msumOuter1P arr) - -rsumOuter1 :: forall n a. (NumElt a, PrimElt a) -           => Ranked (n + 1) a -> Ranked n a -rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive - -applyPermR :: forall i n. [Int] -> ListR n i -> ListR n i -applyPermR = \perm sh -> -  listrFromList perm $ \sperm -> -    case (snatFromListR sperm, snatFromListR sh) of -      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of -        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post -        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post -        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" -                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" -  where -    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) -    listrSplitAt SZ sh = (ZR, sh) -    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) -    listrSplitAt SS{} ZR = error "m' + 1 <= 0" - -    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i -    applyPermRFull _ ZR _ = ZR -    applyPermRFull sm@SNat (i ::: perm) l = -      TypeNats.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> -        case cmpNat (SNat @(idx + 1)) sm of -          LTI -> listrIndex si l ::: applyPermRFull sm perm l -          EQI -> listrIndex si l ::: applyPermRFull sm perm l -          GTI -> error "applyPermR: Index in permutation out of range" - -applyPermIxR :: forall n i. [Int] -> IxR n i -> IxR n i -applyPermIxR = coerce (applyPermR @i) - -applyPermShR :: forall n i. [Int] -> ShR n i -> ShR n i -applyPermShR = coerce (applyPermR @i) - -rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a -rtranspose perm arr -  | sn@SNat <- snatFromShR (rshape arr) -  , Dict <- lemKnownReplicate sn -  , length perm <= fromIntegral (natVal (Proxy @n)) -  = rlift sn -          (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) -          arr -  | otherwise -  = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" - -rappend :: forall n a. Elt a -        => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a -rappend arr1 arr2 -  | sn@SNat <- snatFromShR (rshape arr1) -  , Dict <- lemKnownReplicate sn -  , Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) -      arr1 arr2 - -rscalar :: Elt a => a -> Ranked 0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v -  | Dict <- lemKnownReplicate (snatFromShR sh) -  = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = rfromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: PrimElt a => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a -rfromListOuter l -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) - -rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a -rfromList1 l = Ranked (mfromList1 l) - -rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a -rfromList1Prim l = Ranked (mfromList1Prim l) - -rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] -rtoListOuter (Ranked arr) -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) - -rtoList1 :: Elt a => Ranked 1 a -> [a] -rtoList1 = map runScalar . rtoListOuter - -rfromListPrim :: PrimElt a => [a] -> Ranked 1 a -rfromListPrim l = -  let ssh = SUnknown () :!% ZKX -      xarr = X.fromList1 ssh l -  in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr - -rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a -rfromListPrimLinear sh l = -  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) - -rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a -rfromOrthotope sn arr -  | Refl <- lemRankReplicate sn -  = let xarr = XArray arr -    in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) - -runScalar :: Elt a => Ranked 0 a -> a -runScalar arr = rindex arr ZIR - -rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) -         => SNat n -> IShR n2 -         -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) -         -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) -rrerankP sn sh2 f (Ranked arr) -  | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) -  , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) -  = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) -                     (\a -> let Ranked r = f (Ranked a) in r) -                     arr) - --- | If there is a zero-sized dimension in the @n@-prefix of the shape of the --- input array, then there is no way to deduce the full shape of the output --- array (more precisely, the @n2@ part): that could only come from calling --- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in --- this case; we choose to fill the @n2@ part of the output shape with zeros. --- --- For example, if: --- --- @ --- arr :: Ranked 5 Int   -- of shape [3, 0, 4, 2, 21] --- f :: Ranked 2 Int -> Ranked 3 Float --- @ --- --- then: --- --- @ --- rrerank _ _ _ f arr :: Ranked 5 Float --- @ --- --- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the --- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended --- to return an array with shape all-0 here (it probably didn't), but there is --- no better number to put here absent a subarray of the input to pass to @f@. -rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) -         => SNat n -> IShR n2 -         -> (Ranked n1 a -> Ranked n2 b) -         -> Ranked (n + n1) a -> Ranked (n + n2) b -rrerank ssh sh2 f (rtoPrimitive -> arr) = -  rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr - -rreplicate :: forall n m a. Elt a -           => IShR n -> Ranked m a -> Ranked (n + m) a -rreplicate sh (Ranked arr) -  | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat)) -  = Ranked (mreplicate (shCvtRX sh) arr) - -rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) -rreplicateScalP sh x -  | Dict <- lemKnownReplicate (snatFromShR sh) -  = Ranked (mreplicateScalP (shCvtRX sh) x) - -rreplicateScal :: forall n a. PrimElt a -               => IShR n -> a -> Ranked n a -rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) - -rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a -rslice i n arr -  | Refl <- lemReplicateSucc @(Nothing @Nat) @n -  = rlift (snatFromShR (rshape arr)) -          (\_ -> X.sliceU i n) -          arr - -rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a -rrev1 arr = -  rlift (snatFromShR (rshape arr)) -        (\(_ :: StaticShX sh') -> -          case lemReplicateSucc @(Nothing @Nat) @n of -            Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) -        arr - -rreshape :: forall n n' a. Elt a -         => IShR n' -> Ranked n a -> Ranked n' a -rreshape sh' rarr@(Ranked arr) -  | Dict <- lemKnownReplicate (snatFromShR (rshape rarr)) -  , Dict <- lemKnownReplicate (snatFromShR sh') -  = Ranked (mreshape (shCvtRX sh') arr) - -riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a -riota n = TypeNats.withSomeSNat (fromIntegral n) $ mtoRanked . miota - -rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) -rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr) - -rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) -rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) - -rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) -rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a -rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) - -rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a -rcastToShaped (Ranked arr) targetsh -  | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh)) -  , Refl <- lemRankMapJust targetsh -  = mcastToShaped arr targetsh - -rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a -rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) - -rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) -rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) - - --- ====== API OF SHAPED ARRAYS ====== -- - -arithPromoteShaped :: forall sh a. PrimElt a -                   => (forall shx. Mixed shx a -> Mixed shx a) -                   -> Shaped sh a -> Shaped sh a -arithPromoteShaped = coerce - -arithPromoteShaped2 :: forall sh a. PrimElt a -                    => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a) -                    -> Shaped sh a -> Shaped sh a -> Shaped sh a -arithPromoteShaped2 = coerce - -instance (NumElt a, PrimElt a) => Num (Shaped sh a) where -  (+) = arithPromoteShaped2 (+) -  (-) = arithPromoteShaped2 (-) -  (*) = arithPromoteShaped2 (*) -  negate = arithPromoteShaped negate -  abs = arithPromoteShaped abs -  signum = arithPromoteShaped signum -  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" - -instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where -  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" -  recip = arithPromoteShaped recip -  (/) = arithPromoteShaped2 (/) - -instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where -  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" -  exp = arithPromoteShaped exp -  log = arithPromoteShaped log -  sqrt = arithPromoteShaped sqrt -  (**) = arithPromoteShaped2 (**) -  logBase = arithPromoteShaped2 logBase -  sin = arithPromoteShaped sin -  cos = arithPromoteShaped cos -  tan = arithPromoteShaped tan -  asin = arithPromoteShaped asin -  acos = arithPromoteShaped acos -  atan = arithPromoteShaped atan -  sinh = arithPromoteShaped sinh -  cosh = arithPromoteShaped cosh -  tanh = arithPromoteShaped tanh -  asinh = arithPromoteShaped asinh -  acosh = arithPromoteShaped acosh -  atanh = arithPromoteShaped atanh -  log1p = arithPromoteShaped GHC.Float.log1p -  expm1 = arithPromoteShaped GHC.Float.expm1 -  log1pexp = arithPromoteShaped GHC.Float.log1pexp -  log1mexp = arithPromoteShaped GHC.Float.log1mexp - -zeroIxS :: ShS sh -> IIxS sh -zeroIxS ZSS = ZIS -zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh - -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx - -type family Tail l where -  Tail (_ : xs) = xs - -shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh -shCvtXS' ZSX = castWith (subst1 (unsafeCoerce Refl :: '[] :~: sh)) ZSS -shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = -  castWith (subst1 (lem Refl)) $ -    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerce Refl :: mjshT :~: MapJust (Tail sh))) -                                 idx) -  where -    lem :: forall sh1 sh' n. -           Just n : sh1 :~: MapJust sh' -        -> n : Tail sh' :~: sh' -    lem Refl = unsafeCoerce Refl -shCvtXS' (SUnknown _ :$% _) = error "impossible" - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh - -shapeSizeS :: ShS sh -> Int -shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh - - -sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh -sshape (Shaped arr) = shCvtXS' (mshape arr) - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh -shsTakeIx _ _ ZIS = ZSS -shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx - -sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial sarr@(Shaped arr) idx = -  Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) -            (castWith (subst2 (lemCommMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) -            (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a -sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. Elt a -      => ShS sh2 -      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -      -> Shaped sh1 a -> Shaped sh2 a -slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) - --- | See the documentation of 'mlift'. -slift2 :: forall sh1 sh2 sh3 a. Elt a -       => ShS sh3 -       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) -       -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a -slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) - -ssumOuter1P :: forall sh n a. (Storable a, NumElt a) -            => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) - -ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) -           => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive - -lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) -lemCommMapJustTakeLen PNil _ = Refl -lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl -lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" - -lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) -lemCommMapJustDropLen PNil _ = Refl -lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl -lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" - -lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) -lemCommMapJustIndex SZ (_ :$$ _) = Refl -lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) -  | Refl <- lemCommMapJustIndex i sh -  , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) -  , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') -  = Refl -lemCommMapJustIndex _ ZSS = error "Index of empty" - -lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) -lemCommMapJustPermute PNil _ = Refl -lemCommMapJustPermute (i `PCons` is) sh -  | Refl <- lemCommMapJustPermute is sh -  , Refl <- lemCommMapJustIndex i sh -  = Refl - -listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f -listsAppend ZS idx' = idx' -listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' - -listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f -listsTakeLen PNil _ = ZS -listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh -listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f -listsDropLen PNil sh = sh -listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh -listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape" - -listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f -listsPermute PNil _ = ZS -listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh) - -listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f -listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest -listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest -  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') -  = listsIndex p pT i sh rest -listsIndex _ _ _ ZS _ = error "Index into empty shape" - -shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) -shsTakeLen = coerce (listsTakeLen @SNat) - -shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) -shsPermute = coerce (listsPermute @SNat) - -shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT) -shsIndex pis pshT = coerce (listsIndex @SNat pis pshT) - -applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f -applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh) - -applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i -applyPermIxS = coerce (applyPermS @(Const i)) - -applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) -applyPermShS = coerce (applyPermS @SNat) - -stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) -           => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a -stranspose perm sarr@(Shaped arr) -  | Refl <- lemRankMapJust (sshape sarr) -  , Refl <- lemCommMapJustTakeLen perm (sshape sarr) -  , Refl <- lemCommMapJustDropLen perm (sshape sarr) -  , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr)) -  , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) -  = Shaped (mtranspose perm arr) - -sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) - -sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a -sfromVector sh v = sfromPrimitive (sfromVectorP sh v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: PrimElt a => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) - -sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a -sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 - -sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a -sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim - -stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoListOuter (Shaped arr) = coerce (mtoListOuter arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoListOuter - -sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a -sfromListPrim sn l -  | Refl <- lemAppNil @'[Just n] -  = let ssh = SUnknown () :!% ZKX -        xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) -    in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr - -sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a -sfromListPrimLinear sh l = -  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) -  in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) -         => ShS sh -> ShS sh2 -         -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) -         -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) -srerankP sh sh2 f sarr@(Shaped arr) -  | Refl <- lemCommMapJustApp sh (Proxy @sh1) -  , Refl <- lemCommMapJustApp sh (Proxy @sh2) -  = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) -                     (shCvtSX sh2) -                     (\a -> let Shaped r = f (Shaped a) in r) -                     arr) - -srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) -        => ShS sh -> ShS sh2 -        -> (Shaped sh1 a -> Shaped sh2 b) -        -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b -srerank sh sh2 f (stoPrimitive -> arr) = -  sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr - -sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a -sreplicate sh (Shaped arr) -  | Refl <- lemCommMapJustApp sh (Proxy @sh') -  = Shaped (mreplicate (shCvtSX sh) arr) - -sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) -sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) - -sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a -sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) - -sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a -sslice i n@SNat arr = -  let _ :$$ sh = sshape arr -  in slift (n :$$ sh) (\_ -> X.slice i n) arr - -srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr - -sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a -sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) - -siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a -siota sn = Shaped (miota sn) - -sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) -sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) - -sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) -sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) - -sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) -sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) - -sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a -sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) - -stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a -stoRanked sarr@(Shaped arr) -  | Refl <- lemRankMapJust (sshape sarr) -  = mtoRanked arr - -sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a -sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) - -stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) -stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs new file mode 100644 index 0000000..c4fe066 --- /dev/null +++ b/src/Data/Array/Nested/Lemmas.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module Data.Array.Nested.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Nested.Shape + + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 +              -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl + +lemMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemMapJustTakeLen PNil _ = Refl +lemMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustTakeLen is sh = Refl +lemMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemMapJustDropLen PNil _ = Refl +lemMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustDropLen is sh = Refl +lemMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty" + +lemMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemMapJustIndex SZ (_ :$$ _) = Refl +lemMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) +  | Refl <- lemMapJustIndex i sh +  , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) +  , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') +  = Refl +lemMapJustIndex _ ZSS = error "Index of empty" + +lemMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemMapJustPermute PNil _ = Refl +lemMapJustPermute (i `PCons` is) sh +  | Refl <- lemMapJustPermute is sh +  , Refl <- lemMapJustIndex i sh +  = Refl + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) +  where +    go :: ShS sh' -> StaticShX (MapJust sh') +    go ZSS = ZKX +    go (n :$$ sh) = SKnown n :!% go sh diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs new file mode 100644 index 0000000..84e16b3 --- /dev/null +++ b/src/Data/Array/Nested/Mixed.hs @@ -0,0 +1,741 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Mixed where + +import Control.DeepSeq (NFData) +import Control.Monad (forM_, when) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Foldable (toList) +import Data.Int +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty(..)) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM +import Foreign.C.Types (CInt) +import Foreign.Storable (Storable) +import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Mixed (XArray(..)) +import Data.Array.Mixed qualified as X +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Lemmas + + +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +--   arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +--   rshape arr == 0 :.: 0 :.: 0 :.: ZIR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +--   -> no, because mlift and also any kind of internals probing from outsiders + + +-- Primitive element types +-- ======================= +-- +-- There are a few primitive element types; arrays containing elements of such +-- type are a newtype over an XArray, which it itself a newtype over a Vector. +-- Unfortunately, the setup of the library requires us to list these primitive +-- element types multiple times; to aid in extending the list, all these lists +-- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. +newtype Primitive a = Primitive a + +-- | Element types that are primitive; arrays of these types are just a newtype +-- wrapper over an array. +class Storable a => PrimElt a where +  fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a +  toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) + +  default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a +  fromPrimitive = coerce + +  default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) +  toPrimitive = coerce + +-- [PRIMITIVE ELEMENT TYPES LIST] +instance PrimElt Int +instance PrimElt Int64 +instance PrimElt Int32 +instance PrimElt CInt +instance PrimElt Float +instance PrimElt Double +instance PrimElt () + + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a data family so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'mtranspose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, you might see (nonempty) +-- sizes of the elements of an empty array, which is information that should +-- ostensibly not exist; the full array is still empty. + +data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) +  deriving (Show, Eq, Generic) + +-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays. +deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a)) + +instance NFData a => NFData (Mixed sh (Primitive a)) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic)  -- no content, orthotope optimises this (via Vector) +-- etc. + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int) +deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64) +deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32) +deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt) +deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float) +deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double) +deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ()) + +data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) +deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) +instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b)) +-- etc., larger tuples (perhaps use generics to allow arbitrary product types) + +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) +deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) +instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a)) + + +-- | Internal helper data family mirroring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) +newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ())  -- no content, MVector optimises this +-- etc. + +data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) +-- etc. + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) + + +mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) + +mliftNumElt2 :: PrimElt a +             => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a) +             -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) +  | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) +  | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where +  (+) = mliftNumElt2 numEltAdd +  (-) = mliftNumElt2 numEltSub +  (*) = mliftNumElt2 numEltMul +  negate = mliftNumElt1 numEltNeg +  abs = mliftNumElt1 numEltAbs +  signum = mliftNumElt1 numEltSignum +  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate" + +instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where +  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" +  recip = mliftNumElt1 floatEltRecip +  (/) = mliftNumElt2 floatEltDiv + +instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where +  pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" +  exp = mliftNumElt1 floatEltExp +  log = mliftNumElt1 floatEltLog +  sqrt = mliftNumElt1 floatEltSqrt + +  (**) = mliftNumElt2 floatEltPow +  logBase = mliftNumElt2 floatEltLogbase + +  sin = mliftNumElt1 floatEltSin +  cos = mliftNumElt1 floatEltCos +  tan = mliftNumElt1 floatEltTan +  asin = mliftNumElt1 floatEltAsin +  acos = mliftNumElt1 floatEltAcos +  atan = mliftNumElt1 floatEltAtan +  sinh = mliftNumElt1 floatEltSinh +  cosh = mliftNumElt1 floatEltCosh +  tanh = mliftNumElt1 floatEltTanh +  asinh = mliftNumElt1 floatEltAsinh +  acosh = mliftNumElt1 floatEltAcosh +  atanh = mliftNumElt1 floatEltAtanh +  log1p = mliftNumElt1 floatEltLog1p +  expm1 = mliftNumElt1 floatEltExpm1 +  log1pexp = mliftNumElt1 floatEltLog1pexp +  log1mexp = mliftNumElt1 floatEltLog1mexp + + +-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. +class Elt a where +  -- ====== PUBLIC METHODS ====== -- + +  mshape :: Mixed sh a -> IShX sh +  mindex :: Mixed sh a -> IIxX sh -> a +  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a +  mscalar :: a -> Mixed '[] a + +  -- | All arrays in the list, even subarrays inside @a@, must have the same +  -- shape; if they do not, a runtime error will be thrown. See the +  -- documentation of 'mgenerate' for more information about this restriction. +  -- Furthermore, the length of the list must correspond with @n@: if @n@ is +  -- @Just m@ and @m@ does not equal the length of the list, a runtime error is +  -- thrown. +  -- +  -- Consider also 'mfromListPrim', which can avoid intermediate arrays. +  mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + +  mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] + +  -- | Note: this library makes no particular guarantees about the shapes of +  -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the +  -- full 'XArray' and as such you can distinguish different empty arrays by +  -- the "shapes" of their elements. This information is meaningless, so you +  -- should not use it. +  mlift :: forall sh1 sh2. +           StaticShX sh2 +        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        -> Mixed sh1 a -> Mixed sh2 a + +  -- | See the documentation for 'mlift'. +  mlift2 :: forall sh1 sh2 sh3. +            StaticShX sh3 +         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) +         -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + +  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 +        => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + +  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) +             => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + +  -- ====== PRIVATE METHODS ====== -- + +  -- | Tree giving the shape of every array component. +  type ShapeTree a + +  mshapeTree :: a -> ShapeTree a + +  mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool + +  mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + +  mshowShapeTree :: Proxy a -> ShapeTree a -> String + +  -- | Given the shape of this array, an index and a value, write the value at +  -- that index in the vectors. +  mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + +  -- | Given the shape of this array, an index and a value, write the value at +  -- that index in the vectors. +  mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + +  -- | Given the shape of this array, finalise the vectors into 'XArray's. +  mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + + +-- | Element types for which we have evidence of the (static part of the) shape +-- in a type class constraint. Compare the instance contexts of the instances +-- of this class with those of 'Elt': some instances have an additional +-- "known-shape" constraint. +-- +-- This class is (currently) only required for 'mgenerate' / 'rgenerate' / +-- 'sgenerate'. +class Elt a => KnownElt a where +  -- | Create an empty array. The given shape must have size zero; this may or may not be checked. +  memptyArray :: IShX sh -> Mixed sh a + +  -- | Create uninitialised vectors for this array type, given the shape of +  -- this vector and an example for the contents. +  mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + +  mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + + +-- Arrays of scalars are basically just arrays of scalars. +instance Storable a => Elt (Primitive a) where +  mshape (M_Primitive sh _) = sh +  mindex (M_Primitive _ a) i = Primitive (X.index a i) +  mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i) +  mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) +  mfromListOuter l@(arr1 :| _) = +    let sh = SUnknown (length l) :$% mshape arr1 +    in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l))) +  mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + +  mlift :: forall sh1 sh2. +           StaticShX sh2 +        -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) +        -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) +  mlift ssh2 f (M_Primitive _ a) +    | Refl <- lemAppNil @sh1 +    , Refl <- lemAppNil @sh2 +    , let result = f ZKX a +    = M_Primitive (X.shape ssh2 result) result + +  mlift2 :: forall sh1 sh2 sh3. +            StaticShX sh3 +         -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) +         -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) +  mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) +    | Refl <- lemAppNil @sh1 +    , Refl <- lemAppNil @sh2 +    , Refl <- lemAppNil @sh3 +    , let result = f ZKX a b +    = M_Primitive (X.shape ssh3 result) result + +  mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 +        => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) +  mcast ssh1 sh2 _ (M_Primitive sh1' arr) = +    let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' +    in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr) + +  mtranspose perm (M_Primitive sh arr) = +    M_Primitive (shxPermutePrefix perm sh) +                (X.transpose (ssxFromShape sh) perm arr) + +  type ShapeTree (Primitive a) = () +  mshapeTree _ = () +  mshapeTreeEq _ () () = True +  mshapeTreeEmpty _ () = False +  mshowShapeTree _ () = "()" +  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + +  -- TODO: this use of toVector is suboptimal +  mvecsWritePartial +    :: forall sh' sh s. +       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () +  mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do +    let arrsh = X.shape (ssxFromShape sh') arr +        offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) +    VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + +  mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Int instance Elt Int +deriving via Primitive Int64 instance Elt Int64 +deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive CInt instance Elt CInt +deriving via Primitive Double instance Elt Double +deriving via Primitive Float instance Elt Float +deriving via Primitive () instance Elt () + +instance Storable a => KnownElt (Primitive a) where +  memptyArray sh = M_Primitive sh (X.empty sh) +  mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) +  mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Int instance KnownElt Int +deriving via Primitive Int64 instance KnownElt Int64 +deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive CInt instance KnownElt CInt +deriving via Primitive Double instance KnownElt Double +deriving via Primitive Float instance KnownElt Float +deriving via Primitive () instance KnownElt () + +-- Arrays of pairs are pairs of arrays. +instance (Elt a, Elt b) => Elt (a, b) where +  mshape (M_Tup2 a _) = mshape a +  mindex (M_Tup2 a b) i = (mindex a i, mindex b i) +  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) +  mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) +  mfromListOuter l = +    M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) +           (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) +  mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) +  mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) +  mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + +  mcast ssh1 sh2 psh' (M_Tup2 a b) = +    M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b) + +  mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + +  type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) +  mshapeTree (x, y) = (mshapeTree x, mshapeTree y) +  mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' +  mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 +  mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" +  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do +    mvecsWrite sh i x a +    mvecsWrite sh i y b +  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do +    mvecsWritePartial sh i x a +    mvecsWritePartial sh i y b +  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (KnownElt a, KnownElt b) => KnownElt (a, b) where +  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) +  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y +  mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) + +-- Arrays of arrays are just arrays, but with more dimensions. +instance Elt a => Elt (Mixed sh' a) where +  -- TODO: this is quadratic in the nesting depth because it repeatedly +  -- truncates the shape vector to one a little shorter. Fix with a +  -- moverlongShape method, a prefix of which is mshape. +  mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh +  mshape (M_Nest sh arr) +    = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr)) + +  mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a +  mindex (M_Nest _ arr) i = mindexPartial arr i + +  mindexPartial :: forall sh1 sh2. +                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) +  mindexPartial (M_Nest sh arr) i +    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') +    = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + +  mscalar = M_Nest ZSX + +  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) +  mfromListOuter l@(arr :| _) = +    M_Nest (SUnknown (length l) :$% mshape arr) +           (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + +  mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + +  mlift :: forall sh1 sh2. +           StaticShX sh2 +        -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) +        -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) +  mlift ssh2 f (M_Nest sh1 arr) = +    let result = mlift (ssxAppend ssh2 ssh') f' arr +        (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) +    in M_Nest sh2 result +    where +      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr))) + +      f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b +      f' sshT +        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) +        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) +        = f (ssxAppend ssh' sshT) + +  mlift2 :: forall sh1 sh2 sh3. +            StaticShX sh3 +         -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) +         -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) +  mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = +    let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 +        (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) +    in M_Nest sh3 result +    where +      ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1))) + +      f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b +      f' sshT +        | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) +        , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) +        , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) +        = f (ssxAppend ssh' sshT) + +  mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 +        => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) +  mcast ssh1 sh2 _ (M_Nest sh1T arr) +    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') +    , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') +    = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T +      in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr) + +  mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) +             => Perm is -> Mixed sh (Mixed sh' a) +             -> Mixed (PermutePrefix is sh) (Mixed sh' a) +  mtranspose perm (M_Nest sh arr) +    | let sh' = shxDropSh @sh @sh' (mshape arr) sh +    , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh') +    , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) +    , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') +    , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') +    , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') +    = M_Nest (shxPermutePrefix perm sh) +             (mtranspose perm arr) + +  type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + +  mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) +  mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr))))) + +  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + +  mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + +  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + +  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + +  mvecsWritePartial :: forall sh1 sh2 s. +                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) +                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) +                    -> ST s () +  mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) +    | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') +    = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + +  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + +instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where +  memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh')))) + +  mvecsUnsafeNew sh example +    | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) +    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh'))) +    where +      sh' = mshape example + +  mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + + +-- | Create an array given a size and a function that computes the element at a +-- given index. +-- +-- __WARNING__: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) +-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> +-- >         mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> +-- >           ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check. +mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case shxEnum sh of +  [] -> memptyArray sh +  firstidx : restidxs -> +    let firstelem = f (ixxZero' sh) +        shapetree = mshapeTree firstelem +    in if mshapeTreeEmpty (Proxy @a) shapetree +         then memptyArray sh +         else runST $ do +                vecs <- mvecsUnsafeNew sh firstelem +                mvecsWrite sh firstidx firstelem vecs +                -- TODO: This is likely fine if @a@ is big, but if @a@ is a +                -- scalar this array copying inefficient. Should improve this. +                forM_ restidxs $ \idx -> do +                  let val = f idx +                  when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ +                    error "Data.Array.Nested mgenerate: generated values do not have equal shapes" +                  mvecsWrite sh idx val vecs +                mvecsFreeze sh vecs + +msumOuter1P :: forall sh n a. (Storable a, NumElt a) +            => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = +  let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX +  in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr) + +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) +           => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive + +mappend :: forall n m sh a. Elt a +        => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a +mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 +  where +    sn :$% sh = mshape arr1 +    sm :$% _ = mshape arr2 +    ssh = ssxFromShape sh +    snm :: SMayNat () SNat (AddMaybe n m) +    snm = case (sn, sm) of +            (SUnknown{}, _) -> SUnknown () +            (SKnown{}, SUnknown{}) -> SUnknown () +            (SKnown n, SKnown m) -> SKnown (snatPlus n m) + +    f :: forall sh' b. Storable b +      => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b +    f ssh' = X.append (ssxAppend ssh ssh') + +mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) + +mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector sh v = fromPrimitive (mfromVectorP sh v) + +mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a +mtoVectorP (M_Primitive _ v) = X.toVector v + +mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a +mtoVector arr = mtoVectorP (toPrimitive arr) + +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar  -- TODO: optimise? + +mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromList1Prim l = +  let ssh = SUnknown () :!% ZKX +      xarr = X.fromList1 ssh l +  in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mtoList1 :: Elt a => Mixed '[n] a -> [a] +mtoList1 = map munScalar . mtoListOuter + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = +  let ssh = SUnknown () :!% ZKX +      xarr = X.fromList1 ssh l +  in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = +  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) +  in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) + +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr ZIX + +mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) +         => StaticShX sh -> IShX sh2 +         -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) +         -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) +mrerankP ssh sh2 f (M_Primitive sh arr) = +  let sh1 = shxDropSSX sh ssh +  in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2) +                 (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2) +                           (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) +                           arr) + +-- | See the caveats at @X.rerank@. +mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) +        => StaticShX sh -> IShX sh2 +        -> (Mixed sh1 a -> Mixed sh2 b) +        -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b +mrerank ssh sh2 f (toPrimitive -> arr) = +  fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + +mreplicate :: forall sh sh' a. Elt a +           => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = +  let ssh' = ssxFromShape (mshape arr) +  in mlift (ssxAppend (ssxFromShape sh) ssh') +           (\(sshT :: StaticShX shT) -> +              case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of +                Refl -> X.replicate sh (ssxAppend ssh' sshT)) +           arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a +               => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) + +mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice i n arr = +  let _ :$% sh = mshape arr +  in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr + +msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr + +mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a +mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr + +mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' arr = +  mlift (ssxFromShape sh') +        (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh') +        arr + +miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a +miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) + +masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) +masXArrayPrimP (M_Primitive sh arr) = (sh, arr) + +masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) +masXArrayPrim = masXArrayPrimP . toPrimitive + +mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr + +mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP + +mliftPrim :: PrimElt a +          => (a -> a) +          -> Mixed sh a -> Mixed sh a +mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) + +mliftPrim2 :: PrimElt a +           => (a -> a -> a) +           -> Mixed sh a -> Mixed sh a -> Mixed sh a +mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = +  fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs new file mode 100644 index 0000000..c2f9405 --- /dev/null +++ b/src/Data/Array/Nested/Ranked.hs @@ -0,0 +1,446 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked where + +import Prelude hiding (mappend) + +import Control.DeepSeq (NFData) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Mixed (XArray(..)) +import Data.Array.Mixed qualified as X +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Nested.Mixed +import Data.Array.Nested.Shape + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a) +deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a) + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n   a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where +  mshape (M_Ranked arr) = mshape arr +  mindex (M_Ranked arr) i = Ranked (mindex arr i) + +  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) +  mindexPartial (M_Ranked arr) i = +    coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ +        mindexPartial arr i + +  mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + +  mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) +  mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + +  mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] +  mtoListOuter (M_Ranked arr) = +    coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + +  mlift :: forall sh1 sh2. +           StaticShX sh2 +        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) +  mlift ssh2 f (M_Ranked arr) = +    coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ +      mlift ssh2 f arr + +  mlift2 :: forall sh1 sh2 sh3. +            StaticShX sh3 +         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) +         -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) +  mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = +    coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ +        mlift2 ssh3 f arr1 arr2 + +  mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr) + +  mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + +  type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + +  mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr) + +  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + +  mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + +  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + +  mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () +  mvecsWrite sh idx (Ranked arr) vecs = +    mvecsWrite sh idx arr +      (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) +         vecs) + +  mvecsWritePartial :: forall sh sh' s. +                       IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) +                    -> MixedVecs s (sh ++ sh') (Ranked n a) +                    -> ST s () +  mvecsWritePartial sh idx arr vecs = +    mvecsWritePartial sh idx +      (coerce @(Mixed sh' (Ranked n a)) +              @(Mixed sh' (Mixed (Replicate n Nothing) a)) +         arr) +      (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) +              @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) +         vecs) + +  mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) +  mvecsFreeze sh vecs = +    coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) +           @(Mixed sh (Ranked n a)) +      <$> mvecsFreeze sh +            (coerce @(MixedVecs s sh (Ranked n a)) +                    @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) +                    vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where +  memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) +  memptyArray i +    | Dict <- lemKnownReplicate (SNat @n) +    = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ +        memptyArray i + +  mvecsUnsafeNew idx (Ranked arr) +    | Dict <- lemKnownReplicate (SNat @n) +    = MV_Ranked <$> mvecsUnsafeNew idx arr + +  mvecsNewEmpty _ +    | Dict <- lemKnownReplicate (SNat @n) +    = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +arithPromoteRanked :: forall n a. PrimElt a +                   => (forall sh. Mixed sh a -> Mixed sh a) +                   -> Ranked n a -> Ranked n a +arithPromoteRanked = coerce + +arithPromoteRanked2 :: forall n a. PrimElt a +                    => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a) +                    -> Ranked n a -> Ranked n a -> Ranked n a +arithPromoteRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where +  (+) = arithPromoteRanked2 (+) +  (-) = arithPromoteRanked2 (-) +  (*) = arithPromoteRanked2 (*) +  negate = arithPromoteRanked negate +  abs = arithPromoteRanked abs +  signum = arithPromoteRanked signum +  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where +  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal" +  recip = arithPromoteRanked recip +  (/) = arithPromoteRanked2 (/) + +instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where +  pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal" +  exp = arithPromoteRanked exp +  log = arithPromoteRanked log +  sqrt = arithPromoteRanked sqrt +  (**) = arithPromoteRanked2 (**) +  logBase = arithPromoteRanked2 logBase +  sin = arithPromoteRanked sin +  cos = arithPromoteRanked cos +  tan = arithPromoteRanked tan +  asin = arithPromoteRanked asin +  acos = arithPromoteRanked acos +  atan = arithPromoteRanked atan +  sinh = arithPromoteRanked sinh +  cosh = arithPromoteRanked cosh +  tanh = arithPromoteRanked tanh +  asinh = arithPromoteRanked asinh +  acosh = arithPromoteRanked acosh +  atanh = arithPromoteRanked atanh +  log1p = arithPromoteRanked GHC.Float.log1p +  expm1 = arithPromoteRanked GHC.Float.expm1 +  log1pexp = arithPromoteRanked GHC.Float.log1pexp +  log1mexp = arithPromoteRanked GHC.Float.log1mexp + + +rshape :: forall n a. Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shCvtXR' (mshape arr) + +rindex :: Elt a => Ranked n a -> IIxR n -> a +rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) + +rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a +rindexPartial (Ranked arr) idx = +  Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) +            (castWith (subst2 (lemReplicatePlusApp (ixrToSNat idx) (Proxy @m) (Proxy @Nothing))) arr) +            (ixCvtRX idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgenerate sh f +  | sn@SNat <- shrToSNat sh +  , Dict <- lemKnownReplicate sn +  , Refl <- lemRankReplicate sn +  = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) + +-- | See the documentation of 'mlift'. +rlift :: forall n1 n2 a. Elt a +      => SNat n2 +      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) +      -> Ranked n1 a -> Ranked n2 a +rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) + +-- | See the documentation of 'mlift2'. +rlift2 :: forall n1 n2 n3 a. Elt a +       => SNat n3 +       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) +       -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a +rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) + +rsumOuter1P :: forall n a. +               (Storable a, NumElt a) +            => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1P (Ranked arr) +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = Ranked (msumOuter1P arr) + +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) +           => Ranked (n + 1) a -> Ranked n a +rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive + +rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a +rtranspose perm arr +  | sn@SNat <- shrToSNat (rshape arr) +  , Dict <- lemKnownReplicate sn +  , length perm <= fromIntegral (natVal (Proxy @n)) +  = rlift sn +          (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) +          arr +  | otherwise +  = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" + +rappend :: forall n a. Elt a +        => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend arr1 arr2 +  | sn@SNat <- shrToSNat (rshape arr1) +  , Dict <- lemKnownReplicate sn +  , Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) +      arr1 arr2 + +rscalar :: Elt a => a -> Ranked 0 a +rscalar x = Ranked (mscalar x) + +rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP sh v +  | Dict <- lemKnownReplicate (shrToSNat sh) +  = Ranked (mfromVectorP (shCvtRX sh) v) + +rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a +rfromVector sh v = rfromPrimitive (rfromVectorP sh v) + +rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a +rtoVectorP = coerce mtoVectorP + +rtoVector :: PrimElt a => Ranked n a -> VS.Vector a +rtoVector = coerce mtoVector + +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) + +rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a +rfromList1Prim l = Ranked (mfromList1Prim l) + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) + +rtoList1 :: Elt a => Ranked 1 a -> [a] +rtoList1 = map runScalar . rtoListOuter + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = +  let ssh = SUnknown () :!% ZKX +      xarr = X.fromList1 ssh l +  in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = +  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) +  in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr) + +rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a +rfromOrthotope sn arr +  | Refl <- lemRankReplicate sn +  = let xarr = XArray arr +    in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) + +runScalar :: Elt a => Ranked 0 a -> a +runScalar arr = rindex arr ZIR + +rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) +         => SNat n -> IShR n2 +         -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) +         -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) +rrerankP sn sh2 f (Ranked arr) +  | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) +  , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) +  = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2) +                     (\a -> let Ranked r = f (Ranked a) in r) +                     arr) + +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int   -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 5 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@. +rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) +         => SNat n -> IShR n2 +         -> (Ranked n1 a -> Ranked n2 b) +         -> Ranked (n + n1) a -> Ranked (n + n2) b +rrerank ssh sh2 f (rtoPrimitive -> arr) = +  rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr + +rreplicate :: forall n m a. Elt a +           => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) +  | Refl <- lemReplicatePlusApp (shrToSNat sh) (Proxy @m) (Proxy @(Nothing @Nat)) +  = Ranked (mreplicate (shCvtRX sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x +  | Dict <- lemKnownReplicate (shrToSNat sh) +  = Ranked (mreplicateScalP (shCvtRX sh) x) + +rreplicateScal :: forall n a. PrimElt a +               => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) + +rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n arr +  | Refl <- lemReplicateSucc @(Nothing @Nat) @n +  = rlift (shrToSNat (rshape arr)) +          (\_ -> X.sliceU i n) +          arr + +rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 arr = +  rlift (shrToSNat (rshape arr)) +        (\(_ :: StaticShX sh') -> +          case lemReplicateSucc @(Nothing @Nat) @n of +            Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) +        arr + +rreshape :: forall n n' a. Elt a +         => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' rarr@(Ranked arr) +  | Dict <- lemKnownReplicate (shrToSNat (rshape rarr)) +  , Dict <- lemKnownReplicate (shrToSNat sh') +  = Ranked (mreshape (shCvtRX sh') arr) + +riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a +riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota + +rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) +rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr) + +rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) +rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr) + +rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) + +rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr) + +rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a +rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) + +rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) +rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked arr +  | Refl <- lemAppNil @sh +  , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat)) +  , Refl <- lemRankReplicate (shxRank (mshape arr)) +  = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr) +  where +    convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing) +    convSh ZSX = ZSX +    convSh (smn :$% (sh :: IShX sh'T)) +      | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T) +      = SUnknown (fromSMayNat' smn) :$% convSh sh diff --git a/src/Data/Array/Nested/Shape.hs b/src/Data/Array/Nested/Shape.hs new file mode 100644 index 0000000..774b4bd --- /dev/null +++ b/src/Data/Array/Nested/Shape.hs @@ -0,0 +1,467 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shape where + +import Data.Array.Mixed.Types +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Kind (Type, Constraint) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape + + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where +  ZR :: ListR 0 i +  (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +instance Show i => Show (ListR n i) where +  showsPrec _ = listrShow shows + +data UnconsListRRes i n1 = +  forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) +listrUncons ZR = Nothing + +listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS +listrShow f l = showString "[" . go "" l . showString "]" +  where +    go :: String -> ListR sh' i -> ShowS +    go _ ZR = id +    go prefix (x ::: xs) = showString prefix . f x . go "," xs + +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + +listrToSNat :: ListR n i -> SNat n +listrToSNat ZR = SNat +listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat + +listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrPermutePrefix = \perm sh -> +  listrFromList perm $ \sperm -> +    case (listrToSNat sperm, listrToSNat sh) of +      (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of +        LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post +        EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post +        GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" +                       ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" +  where +    listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) +    listrSplitAt SZ sh = (ZR, sh) +    listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) +    listrSplitAt SS{} ZR = error "m' + 1 <= 0" + +    applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i +    applyPermRFull _ ZR _ = ZR +    applyPermRFull sm@SNat (i ::: perm) l = +      TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> +        case cmpNat (SNat @(idx + 1)) sm of +          LTI -> listrIndex si l ::: applyPermRFull sm perm l +          EQI -> listrIndex si l ::: applyPermRFull sm perm l +          GTI -> error "listrPermutePrefix: Index in permutation out of range" + + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) +  deriving (Eq, Ord) +  deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) +  :: forall {n1} {i}. +     forall n. (n + 1 ~ n1) +  => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) +  where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +type IIxR n = IxR n Int + +instance Show i => Show (IxR n i) where +  showsPrec _ (IxR l) = listrShow shows l + +ixrZero :: SNat n -> IIxR n +ixrZero SZ = ZIR +ixrZero (SS n) = 0 :.: ixrZero n + +ixCvtXR :: IIxX sh -> IIxR (Rank sh) +ixCvtXR ZIX = ZIR +ixCvtXR (n :.% idx) = n :.: ixCvtXR idx + +ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) +ixCvtRX ZIR = ZIX +ixCvtRX (n :.: (idx :: IxR m Int)) = +  castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m)) +    (n :.% ixCvtRX idx) + +ixrToSNat :: IxR n i -> SNat n +ixrToSNat (IxR sh) = listrToSNat sh + +ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix = coerce (listrPermutePrefix @i) + + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) +  deriving (Eq, Ord) +  deriving newtype (Functor, Foldable) + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) +  :: forall {n1} {i}. +     forall n. (n + 1 ~ n1) +  => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) +  where i :$: (ShR sh) = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + +type IShR n = ShR n Int + +instance Show i => Show (ShR n i) where +  showsPrec _ (ShR l) = listrShow shows l + +shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n +shCvtXR' ZSX = +  castWith (subst2 (unsafeCoerceRefl :: 0 :~: n)) +    ZSR +shCvtXR' (n :$% (idx :: IShX sh)) +  | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) = +  castWith (subst2 (lem1 @sh Refl)) +    (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx)) +  where +    lem1 :: forall sh' n' k. +            k : sh' :~: Replicate n' Nothing +         -> Rank sh' + 1 :~: n' +    lem1 Refl = unsafeCoerceRefl + +    lem2 :: k : sh :~: Replicate n Nothing +         -> sh :~: Replicate (Rank sh) Nothing +    lem2 Refl = unsafeCoerceRefl + +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: (idx :: ShR m Int)) = +  castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m)) +    (SUnknown n :$% shCvtRX idx) + +-- | The number of elements in an array described by this shape. +shrSize :: IShR n -> Int +shrSize ZSR = 1 +shrSize (n :$: sh) = n * shrSize sh + +shrToSNat :: ShR n i -> SNat n +shrToSNat (ShR sh) = listrToSNat sh + +shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i +shrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where +  type Item (ListR n i) = i +  fromList = go (SNat @n) +    where +      go :: SNat n' -> [i] -> ListR n' i +      go SZ [] = ZR +      go (SS n) (i : is) = i ::: go n is +      go _ _ = error "IsList(ListR): Mismatched list length" +  toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where +  type Item (IxR n i) = i +  fromList = IxR . IsList.fromList +  toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where +  type Item (ShR n i) = i +  fromList = ShR . IsList.fromList +  toList = Foldable.toList + + +type role ListS nominal representational +type ListS :: [Nat] -> (Nat -> Type) -> Type +data ListS sh f where +  ZS :: ListS '[] f +  -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity +  (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +infixr 3 ::$ + +instance (forall n. Show (f n)) => Show (ListS sh f) where +  showsPrec _ = listsShow shows + +data UnconsListSRes f sh1 = +  forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) +listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) +listsUncons ZS = Nothing + +listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g +listsFmap _ ZS = ZS +listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs + +listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFold _ ZS = mempty +listsFold f (x ::$ xs) = f x <> listsFold f xs + +listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow f l = showString "[" . go "" l . showString "]" +  where +    go :: String -> ListS sh' f -> ShowS +    go _ ZS = id +    go prefix (x ::$ xs) = showString prefix . f x . go "," xs + +listsToList :: ListS sh (Const i) -> [i] +listsToList ZS = [] +listsToList (Const i ::$ is) = i : listsToList is + +listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend ZS idx' = idx' +listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' + +listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm PNil _ = ZS +listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh +listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm PNil sh = sh +listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh +listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = +  case listsIndex (Proxy @is') (Proxy @sh) i sh of +    (item, SNat) -> item ::$ listsPermute is sh + +-- TODO: remove this SNat when the KnownNat constaint in ListS is removed +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) +listsIndex _ _ SZ (n ::$ _) = (n, SNat) +listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) +  | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') +  = listsIndex p pT i sh +listsIndex _ _ _ ZS = error "Index into empty shape" + +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = coerce (listsTakeLenPerm @SNat) + +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = coerce (listsPermute @SNat) + +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) +shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) + +applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) + +applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i +applyPermIxS = coerce (applyPermS @(Const i)) + +applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) +applyPermShS = coerce (applyPermS @SNat) + + +-- | An index into a shape-typed array. +-- +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). Note that because the shape of a +-- shape-typed array is known statically, you can also retrieve the array shape +-- from a 'KnownShape' dictionary. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (ListS sh (Const i)) +  deriving (Eq, Ord) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS = IxS ZS + +pattern (:.$) +  :: forall {sh1} {i}. +     forall n sh. (KnownNat n, n : sh ~ sh1) +  => i -> IxS sh i -> IxS sh1 i +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) +  where i :.$ IxS shl = IxS (Const i ::$ shl) +infixr 3 :.$ + +{-# COMPLETE ZIS, (:.$) #-} + +type IIxS sh = IxS sh Int + +instance Show i => Show (IxS sh i) where +  showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l + +instance Functor (IxS sh) where +  fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) + +instance Foldable (IxS sh) where +  foldMap f (IxS l) = listsFold (f . getConst) l + +ixsZero :: ShS sh -> IIxS sh +ixsZero ZSS = ZIS +ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh + +ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh +ixCvtXS ZSS ZIX = ZIS +ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx + +ixCvtSX :: IIxS sh -> IIxX (MapJust sh) +ixCvtSX ZIS = ZIX +ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh + + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ListS sh SNat) +  deriving (Eq, Ord) + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS = ShS ZS + +pattern (:$$) +  :: forall {sh1}. +     forall n sh. (KnownNat n, n : sh ~ sh1) +  => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) +  where i :$$ ShS shl = ShS (i ::$ shl) + +infixr 3 :$$ + +{-# COMPLETE ZSS, (:$$) #-} + +instance Show (ShS sh) where +  showsPrec _ (ShS l) = listsShow (shows . fromSNat) l + +shsLength :: ShS sh -> Int +shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l) + +shsToList :: ShS sh -> [Int] +shsToList ZSS = [] +shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh + +shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh +shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) = +  castWith (subst1 (lem Refl)) $ +    n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) +                                 idx) +  where +    lem :: forall sh1 sh' n. +           Just n : sh1 :~: MapJust sh' +        -> n : Tail sh' :~: sh' +    lem Refl = unsafeCoerceRefl +shCvtXS' (SUnknown _ :$% _) = error "impossible" + +shCvtSX :: ShS sh -> IShX (MapJust sh) +shCvtSX ZSS = ZSX +shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh + +shsSize :: ShS sh -> Int +shsSize ZSS = 1 +shsSize (n :$$ sh) = fromSNat' n * shsSize sh + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS + + +-- | Untyped: length is checked at runtime. +instance KnownShS sh => IsList (ListS sh (Const i)) where +  type Item (ListS sh (Const i)) = i +  fromList topl = go (knownShS @sh) topl +    where +      go :: ShS sh' -> [i] -> ListS sh' (Const i) +      go ZSS [] = ZS +      go (_ :$$ sh) (i : is) = Const i ::$ go sh is +      go _ _ = error $ "IsList(ListS): Mismatched list length (type says " +                         ++ show (shsLength (knownShS @sh)) ++ ", list has length " +                         ++ show (length topl) ++ ")" +  toList = listsToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where +  type Item (IxS sh i) = i +  fromList = IxS . IsList.fromList +  toList = Foldable.toList + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where +  type Item (ShS sh) = Int +  fromList topl = ShS (go (knownShS @sh) topl) +    where +      go :: ShS sh' -> [Int] -> ListS sh' SNat +      go ZSS [] = ZS +      go (sn :$$ sh) (i : is) +        | i == fromSNat' sn = sn ::$ go sh is +        | otherwise = error $ "IsList(ShS): Value does not match typing (type says " +                                ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" +      go _ _ = error $ "IsList(ShS): Mismatched list length (type says " +                         ++ show (shsLength (knownShS @sh)) ++ ", list has length " +                         ++ show (length topl) ++ ")" +  toList = shsToList diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs new file mode 100644 index 0000000..934433e --- /dev/null +++ b/src/Data/Array/Nested/Shaped.hs @@ -0,0 +1,379 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped where + +import Prelude hiding (mappend) + +import Control.DeepSeq (NFData) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp) +import GHC.TypeLits + +import Data.Array.Mixed (XArray) +import Data.Array.Mixed qualified as X +import Data.Array.Mixed.Internal.Arith +import Data.Array.Mixed.Lemmas +import Data.Array.Mixed.Permutation +import Data.Array.Mixed.Shape +import Data.Array.Mixed.Types +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Shape + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a) +deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a) + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) +deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh'        ) a)) + +instance Elt a => Elt (Shaped sh a) where +  mshape (M_Shaped arr) = mshape arr +  mindex (M_Shaped arr) i = Shaped (mindex arr i) + +  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) +  mindexPartial (M_Shaped arr) i = +    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ +      mindexPartial arr i + +  mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + +  mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) +  mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + +  mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] +  mtoListOuter (M_Shaped arr) +    = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + +  mlift :: forall sh1 sh2. +           StaticShX sh2 +        -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) +  mlift ssh2 f (M_Shaped arr) = +    coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ +      mlift ssh2 f arr + +  mlift2 :: forall sh1 sh2 sh3. +            StaticShX sh3 +         -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) +         -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) +  mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = +    coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ +      mlift2 ssh3 f arr1 arr2 + +  mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr) + +  mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + +  type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + +  mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr) + +  mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + +  mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + +  mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + +  mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () +  mvecsWrite sh idx (Shaped arr) vecs = +    mvecsWrite sh idx arr +      (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) +         vecs) + +  mvecsWritePartial :: forall sh1 sh2 s. +                       IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) +                    -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) +                    -> ST s () +  mvecsWritePartial sh idx arr vecs = +    mvecsWritePartial sh idx +      (coerce @(Mixed sh2 (Shaped sh a)) +              @(Mixed sh2 (Mixed (MapJust sh) a)) +         arr) +      (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) +              @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) +         vecs) + +  mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) +  mvecsFreeze sh vecs = +    coerce @(Mixed sh' (Mixed (MapJust sh) a)) +           @(Mixed sh' (Shaped sh a)) +      <$> mvecsFreeze sh +            (coerce @(MixedVecs s sh' (Shaped sh a)) +                    @(MixedVecs s sh' (Mixed (MapJust sh) a)) +                    vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where +  memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) +  memptyArray i +    | Dict <- lemKnownMapJust (Proxy @sh) +    = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ +        memptyArray i + +  mvecsUnsafeNew idx (Shaped arr) +    | Dict <- lemKnownMapJust (Proxy @sh) +    = MV_Shaped <$> mvecsUnsafeNew idx arr + +  mvecsNewEmpty _ +    | Dict <- lemKnownMapJust (Proxy @sh) +    = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +arithPromoteShaped :: forall sh a. PrimElt a +                   => (forall shx. Mixed shx a -> Mixed shx a) +                   -> Shaped sh a -> Shaped sh a +arithPromoteShaped = coerce + +arithPromoteShaped2 :: forall sh a. PrimElt a +                    => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a) +                    -> Shaped sh a -> Shaped sh a -> Shaped sh a +arithPromoteShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where +  (+) = arithPromoteShaped2 (+) +  (-) = arithPromoteShaped2 (-) +  (*) = arithPromoteShaped2 (*) +  negate = arithPromoteShaped negate +  abs = arithPromoteShaped abs +  signum = arithPromoteShaped signum +  fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where +  fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" +  recip = arithPromoteShaped recip +  (/) = arithPromoteShaped2 (/) + +instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where +  pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" +  exp = arithPromoteShaped exp +  log = arithPromoteShaped log +  sqrt = arithPromoteShaped sqrt +  (**) = arithPromoteShaped2 (**) +  logBase = arithPromoteShaped2 logBase +  sin = arithPromoteShaped sin +  cos = arithPromoteShaped cos +  tan = arithPromoteShaped tan +  asin = arithPromoteShaped asin +  acos = arithPromoteShaped acos +  atan = arithPromoteShaped atan +  sinh = arithPromoteShaped sinh +  cosh = arithPromoteShaped cosh +  tanh = arithPromoteShaped tanh +  asinh = arithPromoteShaped asinh +  acosh = arithPromoteShaped acosh +  atanh = arithPromoteShaped atanh +  log1p = arithPromoteShaped GHC.Float.log1p +  expm1 = arithPromoteShaped GHC.Float.expm1 +  log1pexp = arithPromoteShaped GHC.Float.log1pexp +  log1mexp = arithPromoteShaped GHC.Float.log1mexp + + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shCvtXS' (mshape arr) + +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) + +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = +  Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) +            (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) +            (ixCvtSX idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh)) + +-- | See the documentation of 'mlift'. +slift :: forall sh1 sh2 a. Elt a +      => ShS sh2 +      -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) +      -> Shaped sh1 a -> Shaped sh2 a +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr) + +-- | See the documentation of 'mlift'. +slift2 :: forall sh1 sh2 sh3 a. Elt a +       => ShS sh3 +       -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) +       -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2) + +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) +            => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) + +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) +           => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive + +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) +           => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) +  | Refl <- lemRankMapJust (sshape sarr) +  , Refl <- lemMapJustTakeLen perm (sshape sarr) +  , Refl <- lemMapJustDropLen perm (sshape sarr) +  , Refl <- lemMapJustPermute perm (shsTakeLen perm (sshape sarr)) +  , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) +  = Shaped (mtranspose perm arr) + +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) + +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v) + +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = sfromPrimitive (sfromVectorP sh v) + +stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a +stoVectorP = coerce mtoVectorP + +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a +stoVector = coerce mtoVector + +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l)) + +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1 + +sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a +sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) + +stoList1 :: Elt a => Shaped '[n] a -> [a] +stoList1 = map sunScalar . stoListOuter + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l +  | Refl <- lemAppNil @'[Just n] +  = let ssh = SUnknown () :!% ZKX +        xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) +    in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = +  let M_Primitive _ xarr = toPrimitive (mfromListPrim l) +  in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr ZIS + +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) +         => ShS sh -> ShS sh2 +         -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) +         -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) +srerankP sh sh2 f sarr@(Shaped arr) +  | Refl <- lemMapJustApp sh (Proxy @sh1) +  , Refl <- lemMapJustApp sh (Proxy @sh2) +  = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh)))) +                     (shCvtSX sh2) +                     (\a -> let Shaped r = f (Shaped a) in r) +                     arr) + +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) +        => ShS sh -> ShS sh2 +        -> (Shaped sh1 a -> Shaped sh2 b) +        -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b +srerank sh sh2 f (stoPrimitive -> arr) = +  sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr + +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) +  | Refl <- lemMapJustApp sh (Proxy @sh') +  = Shaped (mreplicate (shCvtSX sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x) + +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) + +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = +  let _ :$$ sh = sshape arr +  in slift (n :$$ sh) (\_ -> X.slice i n) arr + +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr + +sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr) + +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + +sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr) + +sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr) + +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr) + +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr) + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') +              => Mixed sh a -> ShS sh' -> Shaped sh' a +mcastToShaped arr targetsh +  | Refl <- lemAppNil @sh +  , Refl <- lemAppNil @(MapJust sh') +  , Refl <- lemRankMapJust targetsh +  = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr) | 
