aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-05-30 22:20:57 +0200
committerTom Smeding <tom@tomsmeding.com>2024-05-30 22:20:57 +0200
commitf0752d67cd188f438280e1f0c692dc1f5f14a190 (patch)
tree2dd05c13aef3b3c6384bfa091b14633bc86e65a4
parent19eab026f4f4c6a2d38ceb1fffa6062ba2637a46 (diff)
Refactor Nested (modules, function names)
-rw-r--r--bench/Main.hs3
-rw-r--r--ox-arrays.cabal8
-rw-r--r--src/Data/Array/Mixed.hs9
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs11
-rw-r--r--src/Data/Array/Mixed/Lemmas.hs86
-rw-r--r--src/Data/Array/Mixed/Permutation.hs17
-rw-r--r--src/Data/Array/Mixed/Shape.hs17
-rw-r--r--src/Data/Array/Mixed/Types.hs48
-rw-r--r--src/Data/Array/Nested.hs8
-rw-r--r--src/Data/Array/Nested/Convert.hs28
-rw-r--r--src/Data/Array/Nested/Internal.hs2054
-rw-r--r--src/Data/Array/Nested/Lemmas.hs59
-rw-r--r--src/Data/Array/Nested/Mixed.hs741
-rw-r--r--src/Data/Array/Nested/Ranked.hs446
-rw-r--r--src/Data/Array/Nested/Shape.hs467
-rw-r--r--src/Data/Array/Nested/Shaped.hs379
-rw-r--r--test/Gen.hs1
-rw-r--r--test/Tests/C.hs20
18 files changed, 2287 insertions, 2115 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index 41eb3b3..08fde04 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -6,7 +6,8 @@ import qualified Numeric.LinearAlgebra as LA
import Test.Tasty.Bench
import Data.Array.Nested
-import Data.Array.Nested.Internal (mliftPrim, mliftPrim2, arithPromoteRanked, arithPromoteRanked2)
+import Data.Array.Nested.Mixed (mliftPrim, mliftPrim2)
+import Data.Array.Nested.Ranked (arithPromoteRanked, arithPromoteRanked2)
main :: IO ()
diff --git a/ox-arrays.cabal b/ox-arrays.cabal
index 2356e72..f290ca2 100644
--- a/ox-arrays.cabal
+++ b/ox-arrays.cabal
@@ -19,7 +19,12 @@ library
Data.Array.Mixed.Shape
Data.Array.Mixed.Types
Data.Array.Nested
- Data.Array.Nested.Internal
+ Data.Array.Nested.Convert
+ Data.Array.Nested.Mixed
+ Data.Array.Nested.Lemmas
+ Data.Array.Nested.Ranked
+ Data.Array.Nested.Shape
+ Data.Array.Nested.Shaped
build-depends:
base >=4.18 && <4.20,
deepseq,
@@ -42,7 +47,6 @@ test-suite test
other-modules:
Gen
Tests.C
- Tests.Mixed
Util
build-depends:
ox-arrays,
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index d5a8b78..4a338a2 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -2,10 +2,11 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
@@ -14,14 +15,14 @@
module Data.Array.Mixed where
import Control.DeepSeq (NFData(..))
-import qualified Data.Array.RankedS as S
-import qualified Data.Array.Ranked as ORB
+import Data.Array.Ranked qualified as ORB
+import Data.Array.RankedS qualified as S
import Data.Coerce
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
-import qualified Data.Vector.Storable as VS
+import Data.Vector.Storable qualified as VS
import Foreign.Storable (Storable)
import GHC.Generics (Generic)
import GHC.TypeLits
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index cf6820b..bb3ee4a 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
@@ -9,14 +10,14 @@
module Data.Array.Mixed.Internal.Arith where
import Control.Monad (forM, guard)
-import qualified Data.Array.Internal as OI
-import qualified Data.Array.Internal.RankedG as RG
-import qualified Data.Array.Internal.RankedS as RS
+import Data.Array.Internal qualified as OI
+import Data.Array.Internal.RankedG qualified as RG
+import Data.Array.Internal.RankedS qualified as RS
import Data.Bits
import Data.Int
import Data.List (sort)
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable (Storable)
diff --git a/src/Data/Array/Mixed/Lemmas.hs b/src/Data/Array/Mixed/Lemmas.hs
index d0e8d24..ff2e45c 100644
--- a/src/Data/Array/Mixed/Lemmas.hs
+++ b/src/Data/Array/Mixed/Lemmas.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
@@ -11,10 +12,45 @@ import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
+import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
+-- * Reasoning helpers
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
+
+
+-- * Lemmas
+
+-- ** Nat
+
+lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
+lemLeqSuccSucc _ _ = unsafeCoerceRefl
+
+lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
+lemLeqPlus _ _ _ = Refl
+
+
+-- ** Append
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerceRefl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerceRefl
+
+lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
+lemAppLeft _ Refl = Refl
+
+
+-- ** Rank
+
lemRankApp :: forall sh1 sh2.
StaticShX sh1 -> StaticShX sh2
-> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
@@ -38,10 +74,58 @@ lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
-> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerceRefl -- TODO improve this
-lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate SZ = Refl
+lemRankReplicate (SS (n :: SNat nm1))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
+ , Refl <- lemRankReplicate n
+ = Refl
+
+
+-- ** Various type families
+
+lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp sn _ _ = go sn
+ where
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
+ go SZ = Refl
+ go (SS (n :: SNat n'm1))
+ | Refl <- lemReplicateSucc @a @n'm1
+ , Refl <- go n
+ = sym (lemReplicateSucc @a @(n'm1 + m))
+
+lemDropLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
+lemDropLenApp _ _ _ = unsafeCoerceRefl
+
+lemTakeLenApp :: Rank l1 <= Rank l2
+ => Proxy l1 -> Proxy l2 -> Proxy rest
+ -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
+lemTakeLenApp _ _ _ = unsafeCoerceRefl
+
+
+-- ** KnownNat
+
+lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
+lemKnownNatSucc = Dict
+
+lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh)
lemKnownNatRank ZSX = Dict
lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX ZKX = Dict
lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+
+
+-- ** Known shapes
+
+lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
+lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
+
+lemKnownShX :: StaticShX sh -> Dict KnownShX sh
+lemKnownShX ZKX = Dict
+lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
+lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index 1df0ec7..83a5ee4 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -25,7 +26,7 @@ import Data.Type.Equality
import Data.Type.Ord
import GHC.TypeError
import GHC.TypeLits
-import qualified GHC.TypeNats as TN
+import GHC.TypeNats qualified as TN
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
@@ -112,14 +113,14 @@ listxDropLen (_ `PCons` _) ZX = error "IsPermutation longer than shape"
listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f
listxPermute PNil _ = ZX
listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) =
- listxIndex (Proxy @is') (Proxy @sh) i sh (listxPermute is sh)
+ listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh
-listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> ListX (Permute is shT) f -> ListX (Index i sh : Permute is shT) f
-listxIndex _ _ SZ (n ::% _) rest = n ::% rest
-listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) rest
+listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh)
+listxIndex _ _ SZ (n ::% _) = n
+listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f))
| Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listxIndex p pT i sh rest
-listxIndex _ _ _ ZX _ = error "Index into empty shape"
+ = listxIndex p pT i sh
+listxIndex _ _ _ ZX = error "Index into empty shape"
listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f
listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh)
@@ -136,7 +137,7 @@ ssxDropLen = coerce (listxDropLen @(SMayNat () SNat))
ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh)
ssxPermute = coerce (listxPermute @(SMayNat () SNat))
-ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh)
ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2)
ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh)
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index 363b772..a13a176 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -2,6 +2,7 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -22,7 +23,9 @@
module Data.Array.Mixed.Shape where
import Control.DeepSeq (NFData(..))
-import qualified Data.Foldable as Foldable
+import Data.Bifunctor (first)
+import Data.Coerce
+import Data.Foldable qualified as Foldable
import Data.Functor.Const
import Data.Kind (Type, Constraint)
import Data.Monoid (Sum(..))
@@ -30,12 +33,10 @@ import Data.Proxy
import Data.Type.Equality
import GHC.Generics (Generic)
import GHC.IsList (IsList)
-import qualified GHC.IsList as IsList
+import GHC.IsList qualified as IsList
import GHC.TypeLits
import Data.Array.Mixed.Types
-import Data.Coerce
-import Data.Bifunctor (first)
-- | The length of a type-level list. If the argument is a shape, then the
@@ -307,6 +308,10 @@ shxEnum = \sh -> go sh id []
go ZSX f = (f ZIX :)
go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]]
+shxRank :: ShX sh f -> SNat (Rank sh)
+shxRank ZSX = SNat
+shxRank (_ :$% sh) | SNat <- shxRank sh = SNat
+
-- * Static mixed shapes
@@ -377,6 +382,10 @@ ssxFromShape :: IShX sh -> StaticShX sh
ssxFromShape ZSX = ZKX
ssxFromShape (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShape sh
+ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
+ssxFromSNat SZ = ZKX
+ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
+
-- | Evidence for the static part of a shape. This pops up only when you are
-- polymorphic in the element type of an array.
diff --git a/src/Data/Array/Mixed/Types.hs b/src/Data/Array/Mixed/Types.hs
index d77513f..52201df 100644
--- a/src/Data/Array/Mixed/Types.hs
+++ b/src/Data/Array/Mixed/Types.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@@ -22,10 +23,10 @@ module Data.Array.Mixed.Types (
-- * Type-level lists
type (++),
- lemAppNil,
- lemAppAssoc,
Replicate,
lemReplicateSucc,
+ MapJust,
+ Tail,
-- * Unsafe
unsafeCoerceRefl,
@@ -34,8 +35,8 @@ module Data.Array.Mixed.Types (
import Data.Type.Equality
import Data.Proxy
import GHC.TypeLits
-import qualified GHC.TypeNats as TN
-import qualified Unsafe.Coerce
+import GHC.TypeNats qualified as TN
+import Unsafe.Coerce qualified
-- | Evidence for the constraint @c a@.
@@ -76,6 +77,26 @@ snatMul :: SNat n -> SNat m -> SNat (n * m)
snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce
+-- | Type-level list append.
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerceRefl
+
+type family MapJust l where
+ MapJust '[] = '[]
+ MapJust (x : xs) = Just x : MapJust xs
+
+type family Tail l where
+ Tail (_ : xs) = xs
+
+
-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to
-- only typecheck for actual type equalities. One cannot, e.g. accidentally
-- write this:
@@ -89,22 +110,3 @@ snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsa
-- but would have resulted in interesting memory errors at runtime.
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl
-
-
--- | Type-level list append.
-type family l1 ++ l2 where
- '[] ++ l2 = l2
- (x : xs) ++ l2 = x : xs ++ l2
-
-lemAppNil :: l ++ '[] :~: l
-lemAppNil = unsafeCoerceRefl
-
-lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
-lemAppAssoc _ _ _ = unsafeCoerceRefl
-
-type family Replicate n a where
- Replicate 0 a = '[]
- Replicate n a = a : Replicate (n - 1) a
-
-lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
-lemReplicateSucc = unsafeCoerceRefl
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index 1a4e094..c982b4d 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -62,8 +62,6 @@ module Data.Array.Nested (
Perm(..),
IsPermutation,
KnownNatList(..),
- listSToList,
- shSToList,
NumElt, FloatElt,
) where
@@ -74,6 +72,10 @@ import Data.Array.Mixed.Internal.Arith
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
-import Data.Array.Nested.Internal
+import Data.Array.Nested.Convert
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Ranked
+import Data.Array.Nested.Shape
+import Data.Array.Nested.Shaped
import Foreign.Storable
import GHC.TypeLits
diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs
new file mode 100644
index 0000000..cb22c32
--- /dev/null
+++ b/src/Data/Array/Nested/Convert.hs
@@ -0,0 +1,28 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+module Data.Array.Nested.Convert where
+
+import Data.Type.Equality
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Shape
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Ranked
+import Data.Array.Nested.Shape
+import Data.Array.Nested.Shaped
+
+
+stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
+stoRanked sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ = mtoRanked arr
+
+rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
+rcastToShaped (Ranked arr) targetsh
+ | Refl <- lemRankReplicate (shxRank (shCvtSX targetsh))
+ , Refl <- lemRankMapJust targetsh
+ = mcastToShaped arr targetsh
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
deleted file mode 100644
index 0870789..0000000
--- a/src/Data/Array/Nested/Internal.hs
+++ /dev/null
@@ -1,2054 +0,0 @@
-{-# LANGUAGE ConstraintKinds #-}
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DefaultSignatures #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE PatternSynonyms #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RoleAnnotations #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE StrictData #-}
-{-# LANGUAGE TypeAbstractions #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE ViewPatterns #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-
-module Data.Array.Nested.Internal where
-
-import Prelude hiding (mappend)
-
-import Control.DeepSeq (NFData)
-import Control.Monad (forM_, when)
-import Control.Monad.ST
-import qualified Data.Array.RankedS as S
-import Data.Bifunctor (first)
-import Data.Coerce (coerce, Coercible)
-import Data.Foldable as Foldable (toList)
-import Data.Functor.Const
-import Data.Int
-import Data.Kind
-import Data.List.NonEmpty (NonEmpty(..))
-import Data.Monoid (Sum(..))
-import Data.Proxy
-import Data.Type.Equality
-import qualified Data.Vector.Storable as VS
-import qualified Data.Vector.Storable.Mutable as VSM
-import Foreign.C.Types (CInt(..))
-import Foreign.Storable (Storable)
-import qualified GHC.Float (log1p, expm1, log1pexp, log1mexp)
-import GHC.Generics (Generic)
-import GHC.IsList (IsList)
-import qualified GHC.IsList as IsList
-import GHC.TypeLits
-import qualified GHC.TypeNats as TypeNats
-import Unsafe.Coerce
-
-import Data.Array.Mixed
-import qualified Data.Array.Mixed as X
-import Data.Array.Mixed.Lemmas
-import Data.Array.Mixed.Permutation
-import Data.Array.Mixed.Shape
-import Data.Array.Mixed.Internal.Arith
-import Data.Array.Mixed.Types
-
-
--- Invariant in the API
--- ====================
---
--- In the underlying XArray, there is some shape for elements of an empty
--- array. For example, for this array:
---
--- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
--- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
---
--- the two underlying XArrays have a shape, and those shapes might be anything.
--- The invariant is that these element shapes are unobservable in the API.
--- (This is possible because you ought to not be able to get to such an element
--- without indexing out of bounds.)
---
--- Note, though, that the converse situation may arise: the outer array might
--- be nonempty but then the inner arrays might. This is fine, an invariant only
--- applies if the _outer_ array is empty.
---
--- TODO: can we enforce that the elements of an empty (nested) array have
--- all-zero shape?
--- -> no, because mlift and also any kind of internals probing from outsiders
-
-
--- Primitive element types
--- =======================
---
--- There are a few primitive element types; arrays containing elements of such
--- type are a newtype over an XArray, which it itself a newtype over a Vector.
--- Unfortunately, the setup of the library requires us to list these primitive
--- element types multiple times; to aid in extending the list, all these lists
--- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
-
-
-type family MapJust l where
- MapJust '[] = '[]
- MapJust (x : xs) = Just x : MapJust xs
-
-
--- Stupid things that the type checker should be able to figure out in-line, but can't
-
-subst1 :: forall f a b. a :~: b -> f a :~: f b
-subst1 Refl = Refl
-
-subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
-subst2 Refl = Refl
-
-lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
-lemAppLeft _ Refl = Refl
-
-knownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
-knownNatSucc = Dict
-
-
-lemKnownShX :: StaticShX sh -> Dict KnownShX sh
-lemKnownShX ZKX = Dict
-lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict
-lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict
-
-ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing)
-ssxFromSNat SZ = ZKX
-ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n
-
-lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing)
-lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn)
-
-lemRankReplicate :: SNat n -> Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate SZ = Refl
-lemRankReplicate (SS (n :: SNat nm1))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1
- , Refl <- lemRankReplicate n
- = Refl
-
-lemRankMapJust :: forall sh. ShS sh -> Rank (MapJust sh) :~: Rank sh
-lemRankMapJust ZSS = Refl
-lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
-
-lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a
- -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp sn _ _ = go sn
- where
- go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
- go SZ = Refl
- go (SS (n :: SNat n'm1))
- | Refl <- lemReplicateSucc @a @n'm1
- , Refl <- go n
- = sym (lemReplicateSucc @a @(n'm1 + m))
-
-lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True
-lemLeqPlus _ _ _ = Refl
-
-lemLeqSuccSucc :: (k + 1 <= n) => Proxy k -> Proxy n -> (k <=? n - 1) :~: True
-lemLeqSuccSucc _ _ = unsafeCoerce Refl
-
-lemDropLenApp :: Rank l1 <= Rank l2
- => Proxy l1 -> Proxy l2 -> Proxy rest
- -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest)
-lemDropLenApp _ _ _ = unsafeCoerce Refl
-
-lemTakeLenApp :: Rank l1 <= Rank l2
- => Proxy l1 -> Proxy l2 -> Proxy rest
- -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest)
-lemTakeLenApp _ _ _ = unsafeCoerce Refl
-
-srankSh :: ShX sh f -> SNat (Rank sh)
-srankSh ZSX = SNat
-srankSh (_ :$% sh) | SNat <- srankSh sh = SNat
-
-
--- === NEW INDEX TYPES === --
-
-type role ListR nominal representational
-type ListR :: Nat -> Type -> Type
-data ListR n i where
- ZR :: ListR 0 i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
-deriving instance Eq i => Eq (ListR n i)
-deriving instance Ord i => Ord (ListR n i)
-deriving instance Functor (ListR n)
-deriving instance Foldable (ListR n)
-infixr 3 :::
-
-instance Show i => Show (ListR n i) where
- showsPrec _ = showListR shows
-
-data UnconsListRRes i n1 =
- forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
-unconsListR :: ListR n1 i -> Maybe (UnconsListRRes i n1)
-unconsListR (i ::: sh') = Just (UnconsListRRes sh' i)
-unconsListR ZR = Nothing
-
-showListR :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS
-showListR f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListR sh' i -> ShowS
- go _ ZR = id
- go prefix (x ::: xs) = showString prefix . f x . go "," xs
-
-listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
-listrAppend ZR sh = sh
-listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
-
-listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
-listrFromList [] k = k ZR
-listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
-
-listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
-listrIndex SZ (x ::: _) = x
-listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
-listrIndex _ ZR = error "k + 1 <= 0"
-
-
--- | An index into a rank-typed array.
-type role IxR nominal representational
-type IxR :: Nat -> Type -> Type
-newtype IxR n i = IxR (ListR n i)
- deriving (Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
-pattern ZIR = IxR ZR
-
-pattern (:.:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> IxR n i -> IxR n1 i
-pattern i :.: sh <- IxR (unconsListR -> Just (UnconsListRRes (IxR -> sh) i))
- where i :.: IxR sh = IxR (i ::: sh)
-infixr 3 :.:
-
-{-# COMPLETE ZIR, (:.:) #-}
-
-type IIxR n = IxR n Int
-
-instance Show i => Show (IxR n i) where
- showsPrec _ (IxR l) = showListR shows l
-
-
-type role ShR nominal representational
-type ShR :: Nat -> Type -> Type
-newtype ShR n i = ShR (ListR n i)
- deriving (Eq, Ord)
- deriving newtype (Functor, Foldable)
-
-pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
-pattern ZSR = ShR ZR
-
-pattern (:$:)
- :: forall {n1} {i}.
- forall n. (n + 1 ~ n1)
- => i -> ShR n i -> ShR n1 i
-pattern i :$: sh <- ShR (unconsListR -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: (ShR sh) = ShR (i ::: sh)
-infixr 3 :$:
-
-{-# COMPLETE ZSR, (:$:) #-}
-
-type IShR n = ShR n Int
-
-instance Show i => Show (ShR n i) where
- showsPrec _ (ShR l) = showListR shows l
-
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ListR n i) where
- type Item (ListR n i) = i
- fromList = go (SNat @n)
- where
- go :: SNat n' -> [i] -> ListR n' i
- go SZ [] = ZR
- go (SS n) (i : is) = i ::: go n is
- go _ _ = error "IsList(ListR): Mismatched list length"
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (IxR n i) where
- type Item (IxR n i) = i
- fromList = IxR . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length is checked at runtime.
-instance KnownNat n => IsList (ShR n i) where
- type Item (ShR n i) = i
- fromList = ShR . IsList.fromList
- toList = Foldable.toList
-
-
-type role ListS nominal representational
-type ListS :: [Nat] -> (Nat -> Type) -> Type
-data ListS sh f where
- ZS :: ListS '[] f
- (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
-deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
-deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
-infixr 3 ::$
-
-instance (forall n. Show (f n)) => Show (ListS sh f) where
- showsPrec _ = showListS shows
-
-data UnconsListSRes f sh1 =
- forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
-unconsListS :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
-unconsListS (x ::$ sh') = Just (UnconsListSRes sh' x)
-unconsListS ZS = Nothing
-
-fmapListS :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
-fmapListS _ ZS = ZS
-fmapListS f (x ::$ xs) = f x ::$ fmapListS f xs
-
-foldListS :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
-foldListS _ ZS = mempty
-foldListS f (x ::$ xs) = f x <> foldListS f xs
-
-showListS :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
-showListS f l = showString "[" . go "" l . showString "]"
- where
- go :: String -> ListS sh' f -> ShowS
- go _ ZS = id
- go prefix (x ::$ xs) = showString prefix . f x . go "," xs
-
-listSToList :: ListS sh (Const i) -> [i]
-listSToList ZS = []
-listSToList (Const i ::$ is) = i : listSToList is
-
-
--- | An index into a shape-typed array.
---
--- For convenience, this contains regular 'Int's instead of bounded integers
--- (traditionally called \"@Fin@\"). Note that because the shape of a
--- shape-typed array is known statically, you can also retrieve the array shape
--- from a 'KnownShape' dictionary.
-type role IxS nominal representational
-type IxS :: [Nat] -> Type -> Type
-newtype IxS sh i = IxS (ListS sh (Const i))
- deriving (Eq, Ord)
-
-pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
-pattern ZIS = IxS ZS
-
-pattern (:.$)
- :: forall {sh1} {i}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
- => i -> IxS sh i -> IxS sh1 i
-pattern i :.$ shl <- IxS (unconsListS -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
- where i :.$ IxS shl = IxS (Const i ::$ shl)
-infixr 3 :.$
-
-{-# COMPLETE ZIS, (:.$) #-}
-
-type IIxS sh = IxS sh Int
-
-instance Show i => Show (IxS sh i) where
- showsPrec _ (IxS l) = showListS (\(Const i) -> shows i) l
-
-instance Functor (IxS sh) where
- fmap f (IxS l) = IxS (fmapListS (Const . f . getConst) l)
-
-instance Foldable (IxS sh) where
- foldMap f (IxS l) = foldListS (f . getConst) l
-
-
--- | The shape of a shape-typed array given as a list of 'SNat' values.
-type role ShS nominal
-type ShS :: [Nat] -> Type
-newtype ShS sh = ShS (ListS sh SNat)
- deriving (Eq, Ord)
-
-pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
-pattern ZSS = ShS ZS
-
-pattern (:$$)
- :: forall {sh1}.
- forall n sh. (KnownNat n, n : sh ~ sh1)
- => SNat n -> ShS sh -> ShS sh1
-pattern i :$$ shl <- ShS (unconsListS -> Just (UnconsListSRes (ShS -> shl) i))
- where i :$$ ShS shl = ShS (i ::$ shl)
-
-infixr 3 :$$
-
-{-# COMPLETE ZSS, (:$$) #-}
-
-instance Show (ShS sh) where
- showsPrec _ (ShS l) = showListS (shows . fromSNat) l
-
-lengthShS :: ShS sh -> Int
-lengthShS (ShS l) = getSum (foldListS (\_ -> Sum 1) l)
-
-shSToList :: ShS sh -> [Int]
-shSToList ZSS = []
-shSToList (sn :$$ sh) = fromSNat' sn : shSToList sh
-
-
--- | Untyped: length is checked at runtime.
-instance KnownShS sh => IsList (ListS sh (Const i)) where
- type Item (ListS sh (Const i)) = i
- fromList topl = go (knownShS @sh) topl
- where
- go :: ShS sh' -> [i] -> ListS sh' (Const i)
- go ZSS [] = ZS
- go (_ :$$ sh) (i : is) = Const i ::$ go sh is
- go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
- ++ show (lengthShS (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = listSToList
-
--- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
-instance KnownShS sh => IsList (IxS sh i) where
- type Item (IxS sh i) = i
- fromList = IxS . IsList.fromList
- toList = Foldable.toList
-
--- | Untyped: length and values are checked at runtime.
-instance KnownShS sh => IsList (ShS sh) where
- type Item (ShS sh) = Int
- fromList topl = ShS (go (knownShS @sh) topl)
- where
- go :: ShS sh' -> [Int] -> ListS sh' SNat
- go ZSS [] = ZS
- go (sn :$$ sh) (i : is)
- | i == fromSNat' sn = sn ::$ go sh is
- | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
- ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
- go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
- ++ show (lengthShS (knownShS @sh)) ++ ", list has length "
- ++ show (length topl) ++ ")"
- toList = shSToList
-
-
--- | Wrapper type used as a tag to attach instances on. The instances on arrays
--- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
--- of scalars; this means that if @orthotope@ supports an element type @T@ that
--- this library does not (directly), it may just work if you use an array of
--- @'Primitive' T@ instead.
-newtype Primitive a = Primitive a
-
--- | Element types that are primitive; arrays of these types are just a newtype
--- wrapper over an array.
-class Storable a => PrimElt a where
- fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
- toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
-
- default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
- fromPrimitive = coerce
-
- default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
- toPrimitive = coerce
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-instance PrimElt Int
-instance PrimElt Int64
-instance PrimElt Int32
-instance PrimElt CInt
-instance PrimElt Float
-instance PrimElt Double
-instance PrimElt ()
-
-
--- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
--- over product-typed elements using a data family so that the full array is
--- always in struct-of-arrays format.
---
--- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
--- dimension permutations (e.g. 'mtranspose') are typically free.
---
--- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
--- class.
-type Mixed :: [Maybe Nat] -> Type -> Type
-data family Mixed sh a
--- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
--- that you're not supposed to see. In particular, you might see (nonempty)
--- sizes of the elements of an empty array, which is information that should
--- ostensibly not exist; the full array is still empty.
-
-data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
- deriving (Show, Eq, Generic)
-
--- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
-deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a))
-
-instance NFData a => NFData (Mixed sh (Primitive a))
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic)
-newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector)
--- etc.
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int)
-deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64)
-deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32)
-deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt)
-deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float)
-deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double)
-deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ())
-
-data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
-deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
-instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b))
--- etc., larger tuples (perhaps use generics to allow arbitrary product types)
-
-data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
-deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
-instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a))
-
-
--- | Internal helper data family mirroring 'Mixed' that consists of mutable
--- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
-data family MixedVecs s sh a
-
-newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
-newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
-newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
-newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
-newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
-newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
-newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
--- etc.
-
-data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
--- etc.
-
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
-
-
--- | Tree giving the shape of every array component.
-type family ShapeTree a where
- ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
- ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a)
- ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
- ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
-
- -- to avoid having to list all of the primitive types:
- ShapeTree _ = ()
-
-
--- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
--- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
--- a@; see the documentation for 'Primitive' for more details.
-class Elt a where
- -- ====== PUBLIC METHODS ====== --
-
- mshape :: Mixed sh a -> IShX sh
- mindex :: Mixed sh a -> IIxX sh -> a
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
- mscalar :: a -> Mixed '[] a
-
- -- | All arrays in the list, even subarrays inside @a@, must have the same
- -- shape; if they do not, a runtime error will be thrown. See the
- -- documentation of 'mgenerate' for more information about this restriction.
- -- Furthermore, the length of the list must correspond with @n@: if @n@ is
- -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
- -- thrown.
- --
- -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
- mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
-
- mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
-
- -- | Note: this library makes no particular guarantees about the shapes of
- -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
- -- full 'XArray' and as such you can distinguish different empty arrays by
- -- the "shapes" of their elements. This information is meaningless, so you
- -- should not use it.
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a
-
- -- | See the documentation for 'mlift'.
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
-
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
-
- mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
- => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
-
- -- ====== PRIVATE METHODS ====== --
-
- mshapeTree :: a -> ShapeTree a
-
- mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
-
- mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
-
- mshowShapeTree :: Proxy a -> ShapeTree a -> String
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-
- -- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-
-
--- | Element types for which we have evidence of the (static part of the) shape
--- in a type class constraint. Compare the instance contexts of the instances
--- of this class with those of 'Elt': some instances have an additional
--- "known-shape" constraint.
---
--- This class is (currently) only required for 'mgenerate' / 'rgenerate' /
--- 'sgenerate'.
-class Elt a => KnownElt a where
- -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IShX sh -> Mixed sh a
-
- -- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents.
- mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
-
- mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
-
-
--- Arrays of scalars are basically just arrays of scalars.
-instance Storable a => Elt (Primitive a) where
- mshape (M_Primitive sh _) = sh
- mindex (M_Primitive _ a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
- mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
- mfromListOuter l@(arr1 :| _) =
- let sh = SUnknown (length l) :$% mshape arr1
- in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
- mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
- mlift ssh2 f (M_Primitive _ a)
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- , let result = f ZKX a
- = M_Primitive (X.shape ssh2 result) result
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
- mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
- | Refl <- lemAppNil @sh1
- , Refl <- lemAppNil @sh2
- , Refl <- lemAppNil @sh3
- , let result = f ZKX a b
- = M_Primitive (X.shape ssh3 result) result
-
- mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
- mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
- let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
- in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
-
- mtranspose perm (M_Primitive sh arr) =
- M_Primitive (shxPermutePrefix perm sh)
- (X.transpose (ssxFromShape sh) perm arr)
-
- mshapeTree _ = ()
- mshapeTreeEq _ () () = True
- mshapeTreeEmpty _ () = False
- mshowShapeTree _ () = "()"
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
-
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
- :: forall sh' sh s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
- let arrsh = X.shape (ssxFromShape sh') arr
- offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
- VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
-
- mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving via Primitive Int instance Elt Int
-deriving via Primitive Int64 instance Elt Int64
-deriving via Primitive Int32 instance Elt Int32
-deriving via Primitive CInt instance Elt CInt
-deriving via Primitive Double instance Elt Double
-deriving via Primitive Float instance Elt Float
-deriving via Primitive () instance Elt ()
-
-instance Storable a => KnownElt (Primitive a) where
- memptyArray sh = M_Primitive sh (X.empty sh)
- mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
- mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
-
--- [PRIMITIVE ELEMENT TYPES LIST]
-deriving via Primitive Int instance KnownElt Int
-deriving via Primitive Int64 instance KnownElt Int64
-deriving via Primitive Int32 instance KnownElt Int32
-deriving via Primitive CInt instance KnownElt CInt
-deriving via Primitive Double instance KnownElt Double
-deriving via Primitive Float instance KnownElt Float
-deriving via Primitive () instance KnownElt ()
-
--- Arrays of pairs are pairs of arrays.
-instance (Elt a, Elt b) => Elt (a, b) where
- mshape (M_Tup2 a _) = mshape a
- mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
- mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
- mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
- mfromListOuter l =
- M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
- (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
- mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
- mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
- mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
-
- mcast ssh1 sh2 psh' (M_Tup2 a b) =
- M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
-
- mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
-
- mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
- mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
- mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
- mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
- mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
-
-instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
- memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
- mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
- mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
-
--- Arrays of arrays are just arrays, but with more dimensions.
-instance Elt a => Elt (Mixed sh' a) where
- -- TODO: this is quadratic in the nesting depth because it repeatedly
- -- truncates the shape vector to one a little shorter. Fix with a
- -- moverlongShape method, a prefix of which is mshape.
- mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
- mshape (M_Nest sh arr)
- = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
-
- mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
- mindex (M_Nest _ arr) i = mindexPartial arr i
-
- mindexPartial :: forall sh1 sh2.
- Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- mindexPartial (M_Nest sh arr) i
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
-
- mscalar = M_Nest ZSX
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
- mfromListOuter l@(arr :| _) =
- M_Nest (SUnknown (length l) :$% mshape arr)
- (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
-
- mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
- mlift ssh2 f (M_Nest sh1 arr) =
- let result = mlift (ssxAppend ssh2 ssh') f' arr
- (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
- in M_Nest sh2 result
- where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
-
- f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
- f' sshT
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- = f (ssxAppend ssh' sshT)
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
- mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
- let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
- (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
- in M_Nest sh3 result
- where
- ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
-
- f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
- f' sshT
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
- , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
- = f (ssxAppend ssh' sshT)
-
- mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
- => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
- mcast ssh1 sh2 _ (M_Nest sh1T arr)
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
- , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
- = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
- in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
-
- mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
- => Perm is -> Mixed sh (Mixed sh' a)
- -> Mixed (PermutePrefix is sh) (Mixed sh' a)
- mtranspose perm (M_Nest sh arr)
- | let sh' = shxDropSh @sh @sh' (mshape arr) sh
- , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
- , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
- , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
- , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
- = M_Nest (shxPermutePrefix perm sh)
- (mtranspose perm arr)
-
- mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
- mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
- | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
-
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
-
-instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
- memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
-
- mvecsUnsafeNew sh example
- | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
- where
- sh' = mshape example
-
- mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
-
-
--- | Create an array given a size and a function that computes the element at a
--- given index.
---
--- __WARNING__: It is required that every @a@ returned by the argument to
--- 'mgenerate' has the same shape. For example, the following will throw a
--- runtime error:
---
--- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
--- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
--- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
--- > ...
---
--- because the size of the inner 'mgenerate' is not always the same (it depends
--- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
--- the entire hierarchy (after distributing out tuples) must be a rectangular
--- array. The type of 'mgenerate' allows this requirement to be broken very
--- easily, hence the runtime check.
-mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
-mgenerate sh f = case shxEnum sh of
- [] -> memptyArray sh
- firstidx : restidxs ->
- let firstelem = f (ixxZero' sh)
- shapetree = mshapeTree firstelem
- in if mshapeTreeEmpty (Proxy @a) shapetree
- then memptyArray sh
- else runST $ do
- vecs <- mvecsUnsafeNew sh firstelem
- mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this array copying inefficient. Should improve this.
- forM_ restidxs $ \idx -> do
- let val = f idx
- when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
- error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
- mvecsWrite sh idx val vecs
- mvecsFreeze sh vecs
-
-msumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
-msumOuter1P (M_Primitive (n :$% sh) arr) =
- let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
- in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
-
-msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Mixed (n : sh) a -> Mixed sh a
-msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
-
-mappend :: forall n m sh a. Elt a
- => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
-mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
- where
- sn :$% sh = mshape arr1
- sm :$% _ = mshape arr2
- ssh = ssxFromShape sh
- snm :: SMayNat () SNat (AddMaybe n m)
- snm = case (sn, sm) of
- (SUnknown{}, _) -> SUnknown ()
- (SKnown{}, SUnknown{}) -> SUnknown ()
- (SKnown n, SKnown m) -> SKnown (snatPlus n m)
-
- f :: forall sh' b. Storable b
- => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
- f ssh' = X.append (ssxAppend ssh ssh')
-
-mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
-mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
-
-mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
-mfromVector sh v = fromPrimitive (mfromVectorP sh v)
-
-mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
-mtoVectorP (M_Primitive _ v) = X.toVector v
-
-mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
-mtoVector arr = mtoVectorP (toPrimitive arr)
-
-mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
-mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
-
-mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromList1Prim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mtoList1 :: Elt a => Mixed '[n] a -> [a]
-mtoList1 = map munScalar . mtoListOuter
-
-mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
-mfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
-mfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
-
-munScalar :: Elt a => Mixed '[] a -> a
-munScalar arr = mindex arr ZIX
-
-mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
- -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
-mrerankP ssh sh2 f (M_Primitive sh arr) =
- let sh1 = shxDropSSX sh ssh
- in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
- (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
- (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
- arr)
-
--- | See the caveats at @X.rerank@.
-mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => StaticShX sh -> IShX sh2
- -> (Mixed sh1 a -> Mixed sh2 b)
- -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
-mrerank ssh sh2 f (toPrimitive -> arr) =
- fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
-
-mreplicate :: forall sh sh' a. Elt a
- => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
-mreplicate sh arr =
- let ssh' = ssxFromShape (mshape arr)
- in mlift (ssxAppend (ssxFromShape sh) ssh')
- (\(sshT :: StaticShX shT) ->
- case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
- Refl -> X.replicate sh (ssxAppend ssh' sshT))
- arr
-
-mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
-mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
-
-mreplicateScal :: forall sh a. PrimElt a
- => IShX sh -> a -> Mixed sh a
-mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
-
-mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
-mslice i n arr =
- let _ :$% sh = mshape arr
- in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
-
-msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
-msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
-
-mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
-mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
-
-mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
-mreshape sh' arr =
- mlift (ssxFromShape sh')
- (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
- arr
-
-miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
-miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-
-masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
-masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
-
-masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
-masXArrayPrim = masXArrayPrimP . toPrimitive
-
-mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
-mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
-
-mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
-mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
-
-mliftPrim :: PrimElt a
- => (a -> a)
- -> Mixed sh a -> Mixed sh a
-mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
-
-mliftPrim2 :: PrimElt a
- => (a -> a -> a)
- -> Mixed sh a -> Mixed sh a -> Mixed sh a
-mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
- fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
-
-mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
-mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (srankSh sh) arr))
-
-mliftNumElt2 :: PrimElt a
- => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
- -> Mixed sh a -> Mixed sh a -> Mixed sh a
-mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
- | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (srankSh sh1) arr1 arr2))
- | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
-
-instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
- (+) = mliftNumElt2 numEltAdd
- (-) = mliftNumElt2 numEltSub
- (*) = mliftNumElt2 numEltMul
- negate = mliftNumElt1 numEltNeg
- abs = mliftNumElt1 numEltAbs
- signum = mliftNumElt1 numEltSignum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
-
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
- recip = mliftNumElt1 floatEltRecip
- (/) = mliftNumElt2 floatEltDiv
-
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
- exp = mliftNumElt1 floatEltExp
- log = mliftNumElt1 floatEltLog
- sqrt = mliftNumElt1 floatEltSqrt
-
- (**) = mliftNumElt2 floatEltPow
- logBase = mliftNumElt2 floatEltLogbase
-
- sin = mliftNumElt1 floatEltSin
- cos = mliftNumElt1 floatEltCos
- tan = mliftNumElt1 floatEltTan
- asin = mliftNumElt1 floatEltAsin
- acos = mliftNumElt1 floatEltAcos
- atan = mliftNumElt1 floatEltAtan
- sinh = mliftNumElt1 floatEltSinh
- cosh = mliftNumElt1 floatEltCosh
- tanh = mliftNumElt1 floatEltTanh
- asinh = mliftNumElt1 floatEltAsinh
- acosh = mliftNumElt1 floatEltAcosh
- atanh = mliftNumElt1 floatEltAtanh
- log1p = mliftNumElt1 floatEltLog1p
- expm1 = mliftNumElt1 floatEltExpm1
- log1pexp = mliftNumElt1 floatEltLog1pexp
- log1mexp = mliftNumElt1 floatEltLog1mexp
-
-mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
-mtoRanked arr
- | Refl <- lemAppNil @sh
- , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat))
- , Refl <- lemRankReplicate (srankSh (mshape arr))
- = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
- where
- convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
- convSh ZSX = ZSX
- convSh (smn :$% (sh :: IShX sh'T))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
- = SUnknown (fromSMayNat' smn) :$% convSh sh
-
-mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
- => Mixed sh a -> ShS sh' -> Shaped sh' a
-mcastToShaped arr targetsh
- | Refl <- lemAppNil @sh
- , Refl <- lemAppNil @(MapJust sh')
- , Refl <- lemRankMapJust targetsh
- = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
-
-
--- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'Nat'.
---
--- Valid elements of a ranked arrays are described by the 'Elt' type class.
--- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
--- supported (and are represented as a single, flattened, struct-of-arrays
--- array internally).
---
--- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: Nat -> Type -> Type
-newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
-deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
-deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a)
-deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a)
-
--- | A shape-typed array: the full shape of the array (the sizes of its
--- dimensions) is represented on the type level as a list of 'Nat's. Note that
--- these are "GHC.TypeLits" naturals, because we do not need induction over
--- them and we want very large arrays to be possible.
---
--- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
--- and 'Shaped' itself is again an instance of 'Elt' as well.
---
--- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
-type Shaped :: [Nat] -> Type -> Type
-newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
-deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
-deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a)
-deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a)
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
-deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
-newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a))
-deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a))
-
-newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
-newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
-
-
--- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
--- these instances allow them to also be used as elements of arrays, thus
--- making them first-class in the API.
-instance Elt a => Elt (Ranked n a) where
- mshape (M_Ranked arr) = mshape arr
- mindex (M_Ranked arr) i = Ranked (mindex arr i)
-
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i =
- coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
- mindexPartial arr i
-
- mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
-
- mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
- mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
-
- mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
- mtoListOuter (M_Ranked arr) =
- coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift ssh2 f (M_Ranked arr) =
- coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
- mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
- coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
- mlift2 ssh3 f arr1 arr2
-
- mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
-
- mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
-
- mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s.
- IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
-instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
- memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
- memptyArray i
- | Dict <- lemKnownReplicate (SNat @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArray i
-
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownReplicate (SNat @n)
- = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
-
--- sshapeKnown :: ShS sh -> Dict KnownShape sh
--- sshapeKnown ZSS = Dict
--- sshapeKnown (SNat :$$ sh) | Dict <- sshapeKnown sh = Dict
-
-lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2
- -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemCommMapJustApp ZSS _ = Refl
-lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl
-
-instance Elt a => Elt (Shaped sh a) where
- mshape (M_Shaped arr) = mshape arr
- mindex (M_Shaped arr) i = Shaped (mindex arr i)
-
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
-
- mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
-
- mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
- mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
-
- mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
- mtoListOuter (M_Shaped arr)
- = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
-
- mlift :: forall sh1 sh2.
- StaticShX sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift ssh2 f (M_Shaped arr) =
- coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift ssh2 f arr
-
- mlift2 :: forall sh1 sh2 sh3.
- StaticShX sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
- mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
- coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
- mlift2 ssh3 f arr1 arr2
-
- mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
-
- mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
-
- mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
-
- mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
-
- mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t
-
- mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
-
- mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs =
- mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
- mvecsWritePartial :: forall sh1 sh2 s.
- IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs =
- mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
-
- mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs =
- coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
--- | Evidence for the static part of a shape. This pops up only when you are
--- polymorphic in the element type of an array.
-type KnownShS :: [Nat] -> Constraint
-class KnownShS sh where knownShS :: ShS sh
-instance KnownShS '[] where knownShS = ZSS
-instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
-
-lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
-lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
- where
- go :: ShS sh' -> StaticShX (MapJust sh')
- go ZSS = ZKX
- go (n :$$ sh) = SKnown n :!% go sh
-
-instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
- memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
- memptyArray i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArray i
-
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsNewEmpty _
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
-
-
--- ====== API OF RANKED ARRAYS ====== --
-
-arithPromoteRanked :: forall n a. PrimElt a
- => (forall sh. Mixed sh a -> Mixed sh a)
- -> Ranked n a -> Ranked n a
-arithPromoteRanked = coerce
-
-arithPromoteRanked2 :: forall n a. PrimElt a
- => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a)
- -> Ranked n a -> Ranked n a -> Ranked n a
-arithPromoteRanked2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Ranked n a) where
- (+) = arithPromoteRanked2 (+)
- (-) = arithPromoteRanked2 (-)
- (*) = arithPromoteRanked2 (*)
- negate = arithPromoteRanked negate
- abs = arithPromoteRanked abs
- signum = arithPromoteRanked signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
-
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
- recip = arithPromoteRanked recip
- (/) = arithPromoteRanked2 (/)
-
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
- exp = arithPromoteRanked exp
- log = arithPromoteRanked log
- sqrt = arithPromoteRanked sqrt
- (**) = arithPromoteRanked2 (**)
- logBase = arithPromoteRanked2 logBase
- sin = arithPromoteRanked sin
- cos = arithPromoteRanked cos
- tan = arithPromoteRanked tan
- asin = arithPromoteRanked asin
- acos = arithPromoteRanked acos
- atan = arithPromoteRanked atan
- sinh = arithPromoteRanked sinh
- cosh = arithPromoteRanked cosh
- tanh = arithPromoteRanked tanh
- asinh = arithPromoteRanked asinh
- acosh = arithPromoteRanked acosh
- atanh = arithPromoteRanked atanh
- log1p = arithPromoteRanked GHC.Float.log1p
- expm1 = arithPromoteRanked GHC.Float.expm1
- log1pexp = arithPromoteRanked GHC.Float.log1pexp
- log1mexp = arithPromoteRanked GHC.Float.log1mexp
-
-zeroIxR :: SNat n -> IIxR n
-zeroIxR SZ = ZIR
-zeroIxR (SS n) = 0 :.: zeroIxR n
-
-ixCvtXR :: IIxX sh -> IIxR (Rank sh)
-ixCvtXR ZIX = ZIR
-ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
-
-shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
-shCvtXR' ZSX =
- castWith (subst2 (unsafeCoerce Refl :: 0 :~: n))
- ZSR
-shCvtXR' (n :$% (idx :: IShX sh))
- | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
- castWith (subst2 (lem1 @sh Refl))
- (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
- where
- lem1 :: forall sh' n' k.
- k : sh' :~: Replicate n' Nothing
- -> Rank sh' + 1 :~: n'
- lem1 Refl = unsafeCoerce Refl
-
- lem2 :: k : sh :~: Replicate n Nothing
- -> sh :~: Replicate (Rank sh) Nothing
- lem2 Refl = unsafeCoerce Refl
-
-ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
-ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: (idx :: IxR m Int)) =
- castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (n :.% ixCvtRX idx)
-
-shCvtRX :: IShR n -> IShX (Replicate n Nothing)
-shCvtRX ZSR = ZSX
-shCvtRX (n :$: (idx :: ShR m Int)) =
- castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
- (SUnknown n :$% shCvtRX idx)
-
-shapeSizeR :: IShR n -> Int
-shapeSizeR ZSR = 1
-shapeSizeR (n :$: sh) = n * shapeSizeR sh
-
-
-rshape :: forall n a. Elt a => Ranked n a -> IShR n
-rshape (Ranked arr) = shCvtXR' (mshape arr)
-
-rindex :: Elt a => Ranked n a -> IIxR n -> a
-rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-
-snatFromListR :: ListR n i -> SNat n
-snatFromListR ZR = SNat
-snatFromListR (_ ::: (l :: ListR n i)) | SNat <- snatFromListR l, Dict <- knownNatSucc @n = SNat
-
-snatFromIxR :: IxR n i -> SNat n
-snatFromIxR (IxR sh) = snatFromListR sh
-
-snatFromShR :: ShR n i -> SNat n
-snatFromShR (ShR sh) = snatFromListR sh
-
-rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
-rindexPartial (Ranked arr) idx =
- Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (castWith (subst2 (lemReplicatePlusApp (snatFromIxR idx) (Proxy @m) (Proxy @Nothing))) arr)
- (ixCvtRX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
-rgenerate sh f
- | sn@SNat <- snatFromShR sh
- , Dict <- lemKnownReplicate sn
- , Refl <- lemRankReplicate sn
- = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-
--- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. Elt a
- => SNat n2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a
-rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
-
--- | See the documentation of 'mlift2'.
-rlift2 :: forall n1 n2 n3 a. Elt a
- => SNat n3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
-rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
-
-rsumOuter1P :: forall n a.
- (Storable a, NumElt a)
- => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
-rsumOuter1P (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (msumOuter1P arr)
-
-rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
- => Ranked (n + 1) a -> Ranked n a
-rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
-
-applyPermR :: forall i n. [Int] -> ListR n i -> ListR n i
-applyPermR = \perm sh ->
- listrFromList perm $ \sperm ->
- case (snatFromListR sperm, snatFromListR sh) of
- (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
- LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
- GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
- ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
- where
- listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
- listrSplitAt SZ sh = (ZR, sh)
- listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
- listrSplitAt SS{} ZR = error "m' + 1 <= 0"
-
- applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
- applyPermRFull _ ZR _ = ZR
- applyPermRFull sm@SNat (i ::: perm) l =
- TypeNats.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
- case cmpNat (SNat @(idx + 1)) sm of
- LTI -> listrIndex si l ::: applyPermRFull sm perm l
- EQI -> listrIndex si l ::: applyPermRFull sm perm l
- GTI -> error "applyPermR: Index in permutation out of range"
-
-applyPermIxR :: forall n i. [Int] -> IxR n i -> IxR n i
-applyPermIxR = coerce (applyPermR @i)
-
-applyPermShR :: forall n i. [Int] -> ShR n i -> ShR n i
-applyPermShR = coerce (applyPermR @i)
-
-rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a
-rtranspose perm arr
- | sn@SNat <- snatFromShR (rshape arr)
- , Dict <- lemKnownReplicate sn
- , length perm <= fromIntegral (natVal (Proxy @n))
- = rlift sn
- (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
- arr
- | otherwise
- = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
-
-rappend :: forall n a. Elt a
- => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
-rappend arr1 arr2
- | sn@SNat <- snatFromShR (rshape arr1)
- , Dict <- lemKnownReplicate sn
- , Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
- arr1 arr2
-
-rscalar :: Elt a => a -> Ranked 0 a
-rscalar x = Ranked (mscalar x)
-
-rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
-rfromVectorP sh v
- | Dict <- lemKnownReplicate (snatFromShR sh)
- = Ranked (mfromVectorP (shCvtRX sh) v)
-
-rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
-rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
-
-rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
-rtoVectorP = coerce mtoVectorP
-
-rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
-rtoVector = coerce mtoVector
-
-rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
-rfromListOuter l
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
-
-rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
-rfromList1 l = Ranked (mfromList1 l)
-
-rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
-rfromList1Prim l = Ranked (mfromList1Prim l)
-
-rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
-rtoListOuter (Ranked arr)
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
-
-rtoList1 :: Elt a => Ranked 1 a -> [a]
-rtoList1 = map runScalar . rtoListOuter
-
-rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
-rfromListPrim l =
- let ssh = SUnknown () :!% ZKX
- xarr = X.fromList1 ssh l
- in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
-
-rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
-rfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
-
-rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
-rfromOrthotope sn arr
- | Refl <- lemRankReplicate sn
- = let xarr = XArray arr
- in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
-
-runScalar :: Elt a => Ranked 0 a -> a
-runScalar arr = rindex arr ZIR
-
-rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
- => SNat n -> IShR n2
- -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
- -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
-rrerankP sn sh2 f (Ranked arr)
- | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
- , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
- = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
- (\a -> let Ranked r = f (Ranked a) in r)
- arr)
-
--- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
--- input array, then there is no way to deduce the full shape of the output
--- array (more precisely, the @n2@ part): that could only come from calling
--- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
--- this case; we choose to fill the @n2@ part of the output shape with zeros.
---
--- For example, if:
---
--- @
--- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
--- f :: Ranked 2 Int -> Ranked 3 Float
--- @
---
--- then:
---
--- @
--- rrerank _ _ _ f arr :: Ranked 5 Float
--- @
---
--- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
--- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
--- to return an array with shape all-0 here (it probably didn't), but there is
--- no better number to put here absent a subarray of the input to pass to @f@.
-rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
- => SNat n -> IShR n2
- -> (Ranked n1 a -> Ranked n2 b)
- -> Ranked (n + n1) a -> Ranked (n + n2) b
-rrerank ssh sh2 f (rtoPrimitive -> arr) =
- rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
-
-rreplicate :: forall n m a. Elt a
- => IShR n -> Ranked m a -> Ranked (n + m) a
-rreplicate sh (Ranked arr)
- | Refl <- lemReplicatePlusApp (snatFromShR sh) (Proxy @m) (Proxy @(Nothing @Nat))
- = Ranked (mreplicate (shCvtRX sh) arr)
-
-rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
-rreplicateScalP sh x
- | Dict <- lemKnownReplicate (snatFromShR sh)
- = Ranked (mreplicateScalP (shCvtRX sh) x)
-
-rreplicateScal :: forall n a. PrimElt a
- => IShR n -> a -> Ranked n a
-rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
-
-rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
-rslice i n arr
- | Refl <- lemReplicateSucc @(Nothing @Nat) @n
- = rlift (snatFromShR (rshape arr))
- (\_ -> X.sliceU i n)
- arr
-
-rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
-rrev1 arr =
- rlift (snatFromShR (rshape arr))
- (\(_ :: StaticShX sh') ->
- case lemReplicateSucc @(Nothing @Nat) @n of
- Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
- arr
-
-rreshape :: forall n n' a. Elt a
- => IShR n' -> Ranked n a -> Ranked n' a
-rreshape sh' rarr@(Ranked arr)
- | Dict <- lemKnownReplicate (snatFromShR (rshape rarr))
- , Dict <- lemKnownReplicate (snatFromShR sh')
- = Ranked (mreshape (shCvtRX sh') arr)
-
-riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
-riota n = TypeNats.withSomeSNat (fromIntegral n) $ mtoRanked . miota
-
-rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
-rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr)
-
-rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
-rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr)
-
-rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
-rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-
-rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
-rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
-
-rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a
-rcastToShaped (Ranked arr) targetsh
- | Refl <- lemRankReplicate (srankSh (shCvtSX targetsh))
- , Refl <- lemRankMapJust targetsh
- = mcastToShaped arr targetsh
-
-rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
-rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
-
-rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
-rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
-
-
--- ====== API OF SHAPED ARRAYS ====== --
-
-arithPromoteShaped :: forall sh a. PrimElt a
- => (forall shx. Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a
-arithPromoteShaped = coerce
-
-arithPromoteShaped2 :: forall sh a. PrimElt a
- => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a)
- -> Shaped sh a -> Shaped sh a -> Shaped sh a
-arithPromoteShaped2 = coerce
-
-instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
- (+) = arithPromoteShaped2 (+)
- (-) = arithPromoteShaped2 (-)
- (*) = arithPromoteShaped2 (*)
- negate = arithPromoteShaped negate
- abs = arithPromoteShaped abs
- signum = arithPromoteShaped signum
- fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
-
-instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where
- fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
- recip = arithPromoteShaped recip
- (/) = arithPromoteShaped2 (/)
-
-instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where
- pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
- exp = arithPromoteShaped exp
- log = arithPromoteShaped log
- sqrt = arithPromoteShaped sqrt
- (**) = arithPromoteShaped2 (**)
- logBase = arithPromoteShaped2 logBase
- sin = arithPromoteShaped sin
- cos = arithPromoteShaped cos
- tan = arithPromoteShaped tan
- asin = arithPromoteShaped asin
- acos = arithPromoteShaped acos
- atan = arithPromoteShaped atan
- sinh = arithPromoteShaped sinh
- cosh = arithPromoteShaped cosh
- tanh = arithPromoteShaped tanh
- asinh = arithPromoteShaped asinh
- acosh = arithPromoteShaped acosh
- atanh = arithPromoteShaped atanh
- log1p = arithPromoteShaped GHC.Float.log1p
- expm1 = arithPromoteShaped GHC.Float.expm1
- log1pexp = arithPromoteShaped GHC.Float.log1pexp
- log1mexp = arithPromoteShaped GHC.Float.log1mexp
-
-zeroIxS :: ShS sh -> IIxS sh
-zeroIxS ZSS = ZIS
-zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh
-
-ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
-ixCvtXS ZSS ZIX = ZIS
-ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
-
-type family Tail l where
- Tail (_ : xs) = xs
-
-shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
-shCvtXS' ZSX = castWith (subst1 (unsafeCoerce Refl :: '[] :~: sh)) ZSS
-shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
- castWith (subst1 (lem Refl)) $
- n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerce Refl :: mjshT :~: MapJust (Tail sh)))
- idx)
- where
- lem :: forall sh1 sh' n.
- Just n : sh1 :~: MapJust sh'
- -> n : Tail sh' :~: sh'
- lem Refl = unsafeCoerce Refl
-shCvtXS' (SUnknown _ :$% _) = error "impossible"
-
-ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
-ixCvtSX ZIS = ZIX
-ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
-
-shCvtSX :: ShS sh -> IShX (MapJust sh)
-shCvtSX ZSS = ZSX
-shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
-
-shapeSizeS :: ShS sh -> Int
-shapeSizeS ZSS = 1
-shapeSizeS (n :$$ sh) = fromSNat' n * shapeSizeS sh
-
-
-sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
-sshape (Shaped arr) = shCvtXS' (mshape arr)
-
-sindex :: Elt a => Shaped sh a -> IIxS sh -> a
-sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-
-shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
-shsTakeIx _ _ ZIS = ZSS
-shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
-
-sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
-sindexPartial sarr@(Shaped arr) idx =
- Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (castWith (subst2 (lemCommMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
- (ixCvtSX idx))
-
--- | __WARNING__: All values returned from the function must have equal shape.
--- See the documentation of 'mgenerate' for more details.
-sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
-sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
-
--- | See the documentation of 'mlift'.
-slift :: forall sh1 sh2 a. Elt a
- => ShS sh2
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a
-slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
-
--- | See the documentation of 'mlift'.
-slift2 :: forall sh1 sh2 sh3 a. Elt a
- => ShS sh3
- -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
-slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
-
-ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
- => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
-ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
-
-ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
-
-lemCommMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
-lemCommMapJustTakeLen PNil _ = Refl
-lemCommMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
-lemCommMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty"
-
-lemCommMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
-lemCommMapJustDropLen PNil _ = Refl
-lemCommMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
-lemCommMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty"
-
-lemCommMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
-lemCommMapJustIndex SZ (_ :$$ _) = Refl
-lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
- | Refl <- lemCommMapJustIndex i sh
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
- , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = Refl
-lemCommMapJustIndex _ ZSS = error "Index of empty"
-
-lemCommMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
-lemCommMapJustPermute PNil _ = Refl
-lemCommMapJustPermute (i `PCons` is) sh
- | Refl <- lemCommMapJustPermute is sh
- , Refl <- lemCommMapJustIndex i sh
- = Refl
-
-listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
-listsAppend ZS idx' = idx'
-listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
-
-listsTakeLen :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
-listsTakeLen PNil _ = ZS
-listsTakeLen (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLen is sh
-listsTakeLen (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsDropLen :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
-listsDropLen PNil sh = sh
-listsDropLen (_ `PCons` is) (_ ::$ sh) = listsDropLen is sh
-listsDropLen (_ `PCons` _) ZS = error "Permutation longer than shape"
-
-listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
-listsPermute PNil _ = ZS
-listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = listsIndex (Proxy @is') (Proxy @sh) i sh (listsPermute is sh)
-
-listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> ListS (Permute is shT) f -> ListS (Index i sh : Permute is shT) f
-listsIndex _ _ SZ (n ::$ _) rest = n ::$ rest
-listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) rest
- | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
- = listsIndex p pT i sh rest
-listsIndex _ _ _ ZS _ = error "Index into empty shape"
-
-shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
-shsTakeLen = coerce (listsTakeLen @SNat)
-
-shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
-shsPermute = coerce (listsPermute @SNat)
-
-shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (Permute is shT) -> ShS (Index i sh : Permute is shT)
-shsIndex pis pshT = coerce (listsIndex @SNat pis pshT)
-
-applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
-applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLen perm sh)) (listsDropLen perm sh)
-
-applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
-applyPermIxS = coerce (applyPermS @(Const i))
-
-applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
-applyPermShS = coerce (applyPermS @SNat)
-
-stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
- => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
-stranspose perm sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- , Refl <- lemCommMapJustTakeLen perm (sshape sarr)
- , Refl <- lemCommMapJustDropLen perm (sshape sarr)
- , Refl <- lemCommMapJustPermute perm (shsTakeLen perm (sshape sarr))
- , Refl <- lemCommMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
- = Shaped (mtranspose perm arr)
-
-sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
-sappend = coerce mappend
-
-sscalar :: Elt a => a -> Shaped '[] a
-sscalar x = Shaped (mscalar x)
-
-sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
-sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
-
-sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
-sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
-
-stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
-stoVectorP = coerce mtoVectorP
-
-stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
-stoVector = coerce mtoVector
-
-sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
-sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))
-
-sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
-sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1
-
-sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
-sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim
-
-stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
-stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
-
-stoList1 :: Elt a => Shaped '[n] a -> [a]
-stoList1 = map sunScalar . stoListOuter
-
-sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
-sfromListPrim sn l
- | Refl <- lemAppNil @'[Just n]
- = let ssh = SUnknown () :!% ZKX
- xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
- in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
-
-sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
-sfromListPrimLinear sh l =
- let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
- in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
-
-sunScalar :: Elt a => Shaped '[] a -> a
-sunScalar arr = sindex arr ZIS
-
-srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
- -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
-srerankP sh sh2 f sarr@(Shaped arr)
- | Refl <- lemCommMapJustApp sh (Proxy @sh1)
- , Refl <- lemCommMapJustApp sh (Proxy @sh2)
- = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
- (shCvtSX sh2)
- (\a -> let Shaped r = f (Shaped a) in r)
- arr)
-
-srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
- => ShS sh -> ShS sh2
- -> (Shaped sh1 a -> Shaped sh2 b)
- -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
-srerank sh sh2 f (stoPrimitive -> arr) =
- sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
-
-sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
-sreplicate sh (Shaped arr)
- | Refl <- lemCommMapJustApp sh (Proxy @sh')
- = Shaped (mreplicate (shCvtSX sh) arr)
-
-sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
-sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
-
-sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
-sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
-
-sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
-sslice i n@SNat arr =
- let _ :$$ sh = sshape arr
- in slift (n :$$ sh) (\_ -> X.slice i n) arr
-
-srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
-srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
-
-sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
-sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
-
-siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
-siota sn = Shaped (miota sn)
-
-sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
-sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr)
-
-sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
-sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
-
-sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
-sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
-
-sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
-sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
-
-stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a
-stoRanked sarr@(Shaped arr)
- | Refl <- lemRankMapJust (sshape sarr)
- = mtoRanked arr
-
-sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
-sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
-
-stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
-stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs
new file mode 100644
index 0000000..c4fe066
--- /dev/null
+++ b/src/Data/Array/Nested/Lemmas.hs
@@ -0,0 +1,59 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+module Data.Array.Nested.Lemmas where
+
+import Data.Proxy
+import Data.Type.Equality
+import GHC.TypeLits
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Nested.Shape
+
+
+lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh
+lemRankMapJust ZSS = Refl
+lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl
+
+lemMapJustApp :: ShS sh1 -> Proxy sh2
+ -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
+lemMapJustApp ZSS _ = Refl
+lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl
+
+lemMapJustTakeLen :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh)
+lemMapJustTakeLen PNil _ = Refl
+lemMapJustTakeLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustTakeLen is sh = Refl
+lemMapJustTakeLen (_ `PCons` _) ZSS = error "TakeLen of empty"
+
+lemMapJustDropLen :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh)
+lemMapJustDropLen PNil _ = Refl
+lemMapJustDropLen (_ `PCons` is) (_ :$$ sh) | Refl <- lemMapJustDropLen is sh = Refl
+lemMapJustDropLen (_ `PCons` _) ZSS = error "DropLen of empty"
+
+lemMapJustIndex :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh)
+lemMapJustIndex SZ (_ :$$ _) = Refl
+lemMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemMapJustIndex i sh
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemMapJustIndex _ ZSS = error "Index of empty"
+
+lemMapJustPermute :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh)
+lemMapJustPermute PNil _ = Refl
+lemMapJustPermute (i `PCons` is) sh
+ | Refl <- lemMapJustPermute is sh
+ , Refl <- lemMapJustIndex i sh
+ = Refl
+
+lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh)
+lemKnownMapJust _ = lemKnownShX (go (knownShS @sh))
+ where
+ go :: ShS sh' -> StaticShX (MapJust sh')
+ go ZSS = ZKX
+ go (n :$$ sh) = SKnown n :!% go sh
diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs
new file mode 100644
index 0000000..84e16b3
--- /dev/null
+++ b/src/Data/Array/Nested/Mixed.hs
@@ -0,0 +1,741 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DefaultSignatures #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+module Data.Array.Nested.Mixed where
+
+import Control.DeepSeq (NFData)
+import Control.Monad (forM_, when)
+import Control.Monad.ST
+import Data.Array.RankedS qualified as S
+import Data.Coerce
+import Data.Foldable (toList)
+import Data.Int
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty(..))
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Data.Vector.Storable.Mutable qualified as VSM
+import Foreign.C.Types (CInt)
+import Foreign.Storable (Storable)
+import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
+import GHC.Generics (Generic)
+import GHC.TypeLits
+
+import Data.Array.Mixed (XArray(..))
+import Data.Array.Mixed qualified as X
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Lemmas
+
+
+-- Invariant in the API
+-- ====================
+--
+-- In the underlying XArray, there is some shape for elements of an empty
+-- array. For example, for this array:
+--
+-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float)
+-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR
+--
+-- the two underlying XArrays have a shape, and those shapes might be anything.
+-- The invariant is that these element shapes are unobservable in the API.
+-- (This is possible because you ought to not be able to get to such an element
+-- without indexing out of bounds.)
+--
+-- Note, though, that the converse situation may arise: the outer array might
+-- be nonempty but then the inner arrays might. This is fine, an invariant only
+-- applies if the _outer_ array is empty.
+--
+-- TODO: can we enforce that the elements of an empty (nested) array have
+-- all-zero shape?
+-- -> no, because mlift and also any kind of internals probing from outsiders
+
+
+-- Primitive element types
+-- =======================
+--
+-- There are a few primitive element types; arrays containing elements of such
+-- type are a newtype over an XArray, which it itself a newtype over a Vector.
+-- Unfortunately, the setup of the library requires us to list these primitive
+-- element types multiple times; to aid in extending the list, all these lists
+-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
+
+
+-- | Wrapper type used as a tag to attach instances on. The instances on arrays
+-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
+-- of scalars; this means that if @orthotope@ supports an element type @T@ that
+-- this library does not (directly), it may just work if you use an array of
+-- @'Primitive' T@ instead.
+newtype Primitive a = Primitive a
+
+-- | Element types that are primitive; arrays of these types are just a newtype
+-- wrapper over an array.
+class Storable a => PrimElt a where
+ fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a
+ toPrimitive :: Mixed sh a -> Mixed sh (Primitive a)
+
+ default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a
+ fromPrimitive = coerce
+
+ default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a)
+ toPrimitive = coerce
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+instance PrimElt Int
+instance PrimElt Int64
+instance PrimElt Int32
+instance PrimElt CInt
+instance PrimElt Float
+instance PrimElt Double
+instance PrimElt ()
+
+
+-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
+-- over product-typed elements using a data family so that the full array is
+-- always in struct-of-arrays format.
+--
+-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
+-- dimension permutations (e.g. 'mtranspose') are typically free.
+--
+-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type
+-- class.
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes
+-- that you're not supposed to see. In particular, you might see (nonempty)
+-- sizes of the elements of an empty array, which is information that should
+-- ostensibly not exist; the full array is still empty.
+
+data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a)
+ deriving (Show, Eq, Generic)
+
+-- | Only on scalars, because lexicographical ordering is strange on multi-dimensional arrays.
+deriving instance (Ord a, Storable a) => Ord (Mixed '[] (Primitive a))
+
+instance NFData a => NFData (Mixed sh (Primitive a))
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Show, Eq, Generic)
+newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Show, Eq, Generic) -- no content, orthotope optimises this (via Vector)
+-- etc.
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving instance Ord (Mixed '[] Int) ; instance NFData (Mixed sh Int)
+deriving instance Ord (Mixed '[] Int64) ; instance NFData (Mixed sh Int64)
+deriving instance Ord (Mixed '[] Int32) ; instance NFData (Mixed sh Int32)
+deriving instance Ord (Mixed '[] CInt) ; instance NFData (Mixed sh CInt)
+deriving instance Ord (Mixed '[] Float) ; instance NFData (Mixed sh Float)
+deriving instance Ord (Mixed '[] Double) ; instance NFData (Mixed sh Double)
+deriving instance Ord (Mixed '[] ()) ; instance NFData (Mixed sh ())
+
+data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic)
+deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b))
+instance (NFData (Mixed sh a), NFData (Mixed sh b)) => NFData (Mixed sh (a, b))
+-- etc., larger tuples (perhaps use generics to allow arbitrary product types)
+
+data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic)
+deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a))
+instance NFData (Mixed (sh1 ++ sh2) a) => NFData (Mixed sh1 (Mixed sh2 a))
+
+
+-- | Internal helper data family mirroring 'Mixed' that consists of mutable
+-- vectors instead of 'XArray's.
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a)
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int)
+newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64)
+newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32)
+newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt)
+newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double)
+newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float)
+newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this
+-- etc.
+
+data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b)
+-- etc.
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a)
+
+
+mliftNumElt1 :: PrimElt a => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a) -> Mixed sh a -> Mixed sh a
+mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr))
+
+mliftNumElt2 :: PrimElt a
+ => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) a -> S.Array (Rank sh) a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2))
+ | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2))
+ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
+
+instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
+ (+) = mliftNumElt2 numEltAdd
+ (-) = mliftNumElt2 numEltSub
+ (*) = mliftNumElt2 numEltMul
+ negate = mliftNumElt1 numEltNeg
+ abs = mliftNumElt1 numEltAbs
+ signum = mliftNumElt1 numEltSignum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit mreplicate"
+
+instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Mixed sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
+ recip = mliftNumElt1 floatEltRecip
+ (/) = mliftNumElt2 floatEltDiv
+
+instance (FloatElt a, NumElt a, PrimElt a) => Floating (Mixed sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
+ exp = mliftNumElt1 floatEltExp
+ log = mliftNumElt1 floatEltLog
+ sqrt = mliftNumElt1 floatEltSqrt
+
+ (**) = mliftNumElt2 floatEltPow
+ logBase = mliftNumElt2 floatEltLogbase
+
+ sin = mliftNumElt1 floatEltSin
+ cos = mliftNumElt1 floatEltCos
+ tan = mliftNumElt1 floatEltTan
+ asin = mliftNumElt1 floatEltAsin
+ acos = mliftNumElt1 floatEltAcos
+ atan = mliftNumElt1 floatEltAtan
+ sinh = mliftNumElt1 floatEltSinh
+ cosh = mliftNumElt1 floatEltCosh
+ tanh = mliftNumElt1 floatEltTanh
+ asinh = mliftNumElt1 floatEltAsinh
+ acosh = mliftNumElt1 floatEltAcosh
+ atanh = mliftNumElt1 floatEltAtanh
+ log1p = mliftNumElt1 floatEltLog1p
+ expm1 = mliftNumElt1 floatEltExpm1
+ log1pexp = mliftNumElt1 floatEltLog1pexp
+ log1mexp = mliftNumElt1 floatEltLog1mexp
+
+
+-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
+-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive'
+-- a@; see the documentation for 'Primitive' for more details.
+class Elt a where
+ -- ====== PUBLIC METHODS ====== --
+
+ mshape :: Mixed sh a -> IShX sh
+ mindex :: Mixed sh a -> IIxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a
+ mscalar :: a -> Mixed '[] a
+
+ -- | All arrays in the list, even subarrays inside @a@, must have the same
+ -- shape; if they do not, a runtime error will be thrown. See the
+ -- documentation of 'mgenerate' for more information about this restriction.
+ -- Furthermore, the length of the list must correspond with @n@: if @n@ is
+ -- @Just m@ and @m@ does not equal the length of the list, a runtime error is
+ -- thrown.
+ --
+ -- Consider also 'mfromListPrim', which can avoid intermediate arrays.
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a
+
+ mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a]
+
+ -- | Note: this library makes no particular guarantees about the shapes of
+ -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the
+ -- full 'XArray' and as such you can distinguish different empty arrays by
+ -- the "shapes" of their elements. This information is meaningless, so you
+ -- should not use it.
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
+ -- | See the documentation for 'mlift'.
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a
+
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a
+
+ -- ====== PRIVATE METHODS ====== --
+
+ -- | Tree giving the shape of every array component.
+ type ShapeTree a
+
+ mshapeTree :: a -> ShapeTree a
+
+ mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool
+
+ mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool
+
+ mshowShapeTree :: Proxy a -> ShapeTree a -> String
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+
+-- | Element types for which we have evidence of the (static part of the) shape
+-- in a type class constraint. Compare the instance contexts of the instances
+-- of this class with those of 'Elt': some instances have an additional
+-- "known-shape" constraint.
+--
+-- This class is (currently) only required for 'mgenerate' / 'rgenerate' /
+-- 'sgenerate'.
+class Elt a => KnownElt a where
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArray :: IShX sh -> Mixed sh a
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents.
+ mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a)
+
+ mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a)
+
+
+-- Arrays of scalars are basically just arrays of scalars.
+instance Storable a => Elt (Primitive a) where
+ mshape (M_Primitive sh _) = sh
+ mindex (M_Primitive _ a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx sh i) (X.indexPartial a i)
+ mscalar (Primitive x) = M_Primitive ZSX (X.scalar x)
+ mfromListOuter l@(arr1 :| _) =
+ let sh = SUnknown (length l) :$% mshape arr1
+ in M_Primitive sh (X.fromListOuter (ssxFromShape sh) (map (\(M_Primitive _ a) -> a) (toList l)))
+ mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift ssh2 f (M_Primitive _ a)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , let result = f ZKX a
+ = M_Primitive (X.shape ssh2 result) result
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a)
+ mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b)
+ | Refl <- lemAppNil @sh1
+ , Refl <- lemAppNil @sh2
+ , Refl <- lemAppNil @sh3
+ , let result = f ZKX a b
+ = M_Primitive (X.shape ssh3 result) result
+
+ mcast :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a)
+ mcast ssh1 sh2 _ (M_Primitive sh1' arr) =
+ let (_, sh') = shxSplitApp (Proxy @sh') ssh1 sh1'
+ in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShape sh') arr)
+
+ mtranspose perm (M_Primitive sh arr) =
+ M_Primitive (shxPermutePrefix perm sh)
+ (X.transpose (ssxFromShape sh) perm arr)
+
+ type ShapeTree (Primitive a) = ()
+ mshapeTree _ = ()
+ mshapeTreeEq _ () () = True
+ mshapeTreeEmpty _ () = False
+ mshowShapeTree _ () = "()"
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x
+
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do
+ let arrsh = X.shape (ssxFromShape sh') arr
+ offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh))
+ VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr)
+
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Int instance Elt Int
+deriving via Primitive Int64 instance Elt Int64
+deriving via Primitive Int32 instance Elt Int32
+deriving via Primitive CInt instance Elt CInt
+deriving via Primitive Double instance Elt Double
+deriving via Primitive Float instance Elt Float
+deriving via Primitive () instance Elt ()
+
+instance Storable a => KnownElt (Primitive a) where
+ memptyArray sh = M_Primitive sh (X.empty sh)
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh)
+ mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0
+
+-- [PRIMITIVE ELEMENT TYPES LIST]
+deriving via Primitive Int instance KnownElt Int
+deriving via Primitive Int64 instance KnownElt Int64
+deriving via Primitive Int32 instance KnownElt Int32
+deriving via Primitive CInt instance KnownElt CInt
+deriving via Primitive Double instance KnownElt Double
+deriving via Primitive Float instance KnownElt Float
+deriving via Primitive () instance KnownElt ()
+
+-- Arrays of pairs are pairs of arrays.
+instance (Elt a, Elt b) => Elt (a, b) where
+ mshape (M_Tup2 a _) = mshape a
+ mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y)
+ mfromListOuter l =
+ M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l))
+ (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l))
+ mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b)
+ mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b)
+ mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y)
+
+ mcast ssh1 sh2 psh' (M_Tup2 a b) =
+ M_Tup2 (mcast ssh1 sh2 psh' a) (mcast ssh1 sh2 psh' b)
+
+ mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b)
+
+ type ShapeTree (a, b) = (ShapeTree a, ShapeTree b)
+ mshapeTree (x, y) = (mshapeTree x, mshapeTree y)
+ mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2'
+ mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2
+ mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")"
+ mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
+ mvecsWrite sh i x a
+ mvecsWrite sh i y b
+ mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartial sh i x a
+ mvecsWritePartial sh i y b
+ mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+
+instance (KnownElt a, KnownElt b) => KnownElt (a, b) where
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b)
+
+-- Arrays of arrays are just arrays, but with more dimensions.
+instance Elt a => Elt (Mixed sh' a) where
+ -- TODO: this is quadratic in the nesting depth because it repeatedly
+ -- truncates the shape vector to one a little shorter. Fix with a
+ -- moverlongShape method, a prefix of which is mshape.
+ mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh
+ mshape (M_Nest sh arr)
+ = fst (shxSplitApp (Proxy @sh') (ssxFromShape sh) (mshape arr))
+
+ mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a
+ mindex (M_Nest _ arr) i = mindexPartial arr i
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest sh arr) i
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (shxDropIx sh i) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ mscalar = M_Nest ZSX
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a)
+ mfromListOuter l@(arr :| _) =
+ M_Nest (SUnknown (length l) :$% mshape arr)
+ (mfromListOuter ((\(M_Nest _ a) -> a) <$> l))
+
+ mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift ssh2 f (M_Nest sh1 arr) =
+ let result = mlift (ssxAppend ssh2 ssh') f' arr
+ (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result)
+ in M_Nest sh2 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a)
+ mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) =
+ let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2
+ (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result)
+ in M_Nest sh3 result
+ where
+ ssh' = ssxFromShape (snd (shxSplitApp (Proxy @sh') (ssxFromShape sh1) (mshape arr1)))
+
+ f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b
+ f' sshT
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT)
+ , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT)
+ = f (ssxAppend ssh' sshT)
+
+ mcast :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2
+ => StaticShX sh1 -> IShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a)
+ mcast ssh1 sh2 _ (M_Nest sh1T arr)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh')
+ , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh')
+ = let (_, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T
+ in M_Nest (shxAppend sh2 shT) (mcast ssh1 sh2 (Proxy @(shT ++ sh')) arr)
+
+ mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh)
+ => Perm is -> Mixed sh (Mixed sh' a)
+ -> Mixed (PermutePrefix is sh) (Mixed sh' a)
+ mtranspose perm (M_Nest sh arr)
+ | let sh' = shxDropSh @sh @sh' (mshape arr) sh
+ , Refl <- lemRankApp (ssxFromShape sh) (ssxFromShape sh')
+ , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh'))
+ , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh')
+ , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh')
+ = M_Nest (shxPermutePrefix perm sh)
+ (mtranspose perm arr)
+
+ type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a)
+
+ mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a)
+ mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShape (mshape arr)))))
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs)
+ | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs
+
+instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where
+ memptyArray sh = M_Nest sh (memptyArray (shxAppend sh (shxCompleteZeros (knownShX @sh'))))
+
+ mvecsUnsafeNew sh example
+ | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a))
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShape sh')))
+ where
+ sh' = mshape example
+
+ mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a)
+
+
+-- | Create an array given a size and a function that computes the element at a
+-- given index.
+--
+-- __WARNING__: It is required that every @a@ returned by the argument to
+-- 'mgenerate' has the same shape. For example, the following will throw a
+-- runtime error:
+--
+-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double)
+-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) ->
+-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) ->
+-- > ...
+--
+-- because the size of the inner 'mgenerate' is not always the same (it depends
+-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so
+-- the entire hierarchy (after distributing out tuples) must be a rectangular
+-- array. The type of 'mgenerate' allows this requirement to be broken very
+-- easily, hence the runtime check.
+mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a
+mgenerate sh f = case shxEnum sh of
+ [] -> memptyArray sh
+ firstidx : restidxs ->
+ let firstelem = f (ixxZero' sh)
+ shapetree = mshapeTree firstelem
+ in if mshapeTreeEmpty (Proxy @a) shapetree
+ then memptyArray sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ -- TODO: This is likely fine if @a@ is big, but if @a@ is a
+ -- scalar this array copying inefficient. Should improve this.
+ forM_ restidxs $ \idx -> do
+ let val = f idx
+ when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $
+ error "Data.Array.Nested mgenerate: generated values do not have equal shapes"
+ mvecsWrite sh idx val vecs
+ mvecsFreeze sh vecs
+
+msumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a)
+msumOuter1P (M_Primitive (n :$% sh) arr) =
+ let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX
+ in M_Primitive sh (X.sumOuter nssh (ssxFromShape sh) arr)
+
+msumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Mixed (n : sh) a -> Mixed sh a
+msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive
+
+mappend :: forall n m sh a. Elt a
+ => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a
+mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2
+ where
+ sn :$% sh = mshape arr1
+ sm :$% _ = mshape arr2
+ ssh = ssxFromShape sh
+ snm :: SMayNat () SNat (AddMaybe n m)
+ snm = case (sn, sm) of
+ (SUnknown{}, _) -> SUnknown ()
+ (SKnown{}, SUnknown{}) -> SUnknown ()
+ (SKnown n, SKnown m) -> SKnown (snatPlus n m)
+
+ f :: forall sh' b. Storable b
+ => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b
+ f ssh' = X.append (ssxAppend ssh ssh')
+
+mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a)
+mfromVectorP sh v = M_Primitive sh (X.fromVector sh v)
+
+mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a
+mfromVector sh v = fromPrimitive (mfromVectorP sh v)
+
+mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a
+mtoVectorP (M_Primitive _ v) = X.toVector v
+
+mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a
+mtoVector arr = mtoVectorP (toPrimitive arr)
+
+mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a
+mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise?
+
+mfromList1Prim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromList1Prim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mtoList1 :: Elt a => Mixed '[n] a -> [a]
+mtoList1 = map munScalar . mtoListOuter
+
+mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a
+mfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a
+mfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr)
+
+munScalar :: Elt a => Mixed '[] a -> a
+munScalar arr = mindex arr ZIX
+
+mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b))
+ -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b)
+mrerankP ssh sh2 f (M_Primitive sh arr) =
+ let sh1 = shxDropSSX sh ssh
+ in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) sh ssh) sh2)
+ (X.rerank ssh (ssxFromShape sh1) (ssxFromShape sh2)
+ (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r)
+ arr)
+
+-- | See the caveats at @X.rerank@.
+mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => StaticShX sh -> IShX sh2
+ -> (Mixed sh1 a -> Mixed sh2 b)
+ -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b
+mrerank ssh sh2 f (toPrimitive -> arr) =
+ fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr
+
+mreplicate :: forall sh sh' a. Elt a
+ => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a
+mreplicate sh arr =
+ let ssh' = ssxFromShape (mshape arr)
+ in mlift (ssxAppend (ssxFromShape sh) ssh')
+ (\(sshT :: StaticShX shT) ->
+ case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of
+ Refl -> X.replicate sh (ssxAppend ssh' sshT))
+ arr
+
+mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a)
+mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x)
+
+mreplicateScal :: forall sh a. PrimElt a
+ => IShX sh -> a -> Mixed sh a
+mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x)
+
+mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+mslice i n arr =
+ let _ :$% sh = mshape arr
+ in mlift (SKnown n :!% ssxFromShape sh) (\_ -> X.slice i n) arr
+
+msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceU i n arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.sliceU i n) arr
+
+mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a
+mrev1 arr = mlift (ssxFromShape (mshape arr)) (\_ -> X.rev1) arr
+
+mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a
+mreshape sh' arr =
+ mlift (ssxFromShape sh')
+ (\sshIn -> X.reshapePartial (ssxFromShape (mshape arr)) sshIn sh')
+ arr
+
+miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a
+miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
+
+masXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a)
+masXArrayPrimP (M_Primitive sh arr) = (sh, arr)
+
+masXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a)
+masXArrayPrim = masXArrayPrimP . toPrimitive
+
+mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr
+
+mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a
+mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP
+
+mliftPrim :: PrimElt a
+ => (a -> a)
+ -> Mixed sh a -> Mixed sh a
+mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr))
+
+mliftPrim2 :: PrimElt a
+ => (a -> a -> a)
+ -> Mixed sh a -> Mixed sh a -> Mixed sh a
+mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) =
+ fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2))
diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs
new file mode 100644
index 0000000..c2f9405
--- /dev/null
+++ b/src/Data/Array/Nested/Ranked.hs
@@ -0,0 +1,446 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Ranked where
+
+import Prelude hiding (mappend)
+
+import Control.DeepSeq (NFData)
+import Control.Monad.ST
+import Data.Array.RankedS qualified as S
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Mixed (XArray(..))
+import Data.Array.Mixed qualified as X
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Shape
+
+
+-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
+-- represented on the type level as a 'Nat'.
+--
+-- Valid elements of a ranked arrays are described by the 'Elt' type class.
+-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
+-- supported (and are represented as a single, flattened, struct-of-arrays
+-- array internally).
+--
+-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
+deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a)
+deriving instance Ord (Mixed '[] a) => Ord (Ranked 0 a)
+deriving instance NFData (Mixed (Replicate n Nothing) a) => NFData (Ranked n a)
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
+deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a))
+
+newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
+
+-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
+-- these instances allow them to also be used as elements of arrays, thus
+-- making them first-class in the API.
+instance Elt a => Elt (Ranked n a) where
+ mshape (M_Ranked arr) = mshape arr
+ mindex (M_Ranked arr) i = Ranked (mindex arr i)
+
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a)
+ mindexPartial (M_Ranked arr) i =
+ coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
+ mindexPartial arr i
+
+ mscalar (Ranked x) = M_Ranked (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a)
+ mfromListOuter l = M_Ranked (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)]
+ mtoListOuter (M_Ranked arr) =
+ coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift ssh2 f (M_Ranked arr) =
+ coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a)
+ mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) =
+ coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mcast ssh1 sh2 psh' (M_Ranked arr) = M_Ranked (mcast ssh1 sh2 psh' arr)
+
+ mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr)
+
+ type ShapeTree (Ranked n a) = (IShR n, ShapeTree a)
+
+ mshapeTree (Ranked arr) = first shCvtXR' (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
+ mvecsWrite sh idx (Ranked arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh sh' s.
+ IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a)
+ -> MixedVecs s (sh ++ sh') (Ranked n a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh' (Ranked n a))
+ @(Mixed sh' (Mixed (Replicate n Nothing) a))
+ arr)
+ (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
+ @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
+ vecs)
+
+ mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
+ @(Mixed sh (Ranked n a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh (Ranked n a))
+ @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
+ vecs)
+
+instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where
+ memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a)
+ memptyArray i
+ | Dict <- lemKnownReplicate (SNat @n)
+ = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
+ memptyArray i
+
+ mvecsUnsafeNew idx (Ranked arr)
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownReplicate (SNat @n)
+ = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a))
+
+
+arithPromoteRanked :: forall n a. PrimElt a
+ => (forall sh. Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a
+arithPromoteRanked = coerce
+
+arithPromoteRanked2 :: forall n a. PrimElt a
+ => (forall sh. Mixed sh a -> Mixed sh a -> Mixed sh a)
+ -> Ranked n a -> Ranked n a -> Ranked n a
+arithPromoteRanked2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Ranked n a) where
+ (+) = arithPromoteRanked2 (+)
+ (-) = arithPromoteRanked2 (-)
+ (*) = arithPromoteRanked2 (*)
+ negate = arithPromoteRanked negate
+ abs = arithPromoteRanked abs
+ signum = arithPromoteRanked signum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit rreplicateScal"
+
+instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Ranked n a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit rreplicateScal"
+ recip = arithPromoteRanked recip
+ (/) = arithPromoteRanked2 (/)
+
+instance (FloatElt a, NumElt a, PrimElt a) => Floating (Ranked n a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit rreplicateScal"
+ exp = arithPromoteRanked exp
+ log = arithPromoteRanked log
+ sqrt = arithPromoteRanked sqrt
+ (**) = arithPromoteRanked2 (**)
+ logBase = arithPromoteRanked2 logBase
+ sin = arithPromoteRanked sin
+ cos = arithPromoteRanked cos
+ tan = arithPromoteRanked tan
+ asin = arithPromoteRanked asin
+ acos = arithPromoteRanked acos
+ atan = arithPromoteRanked atan
+ sinh = arithPromoteRanked sinh
+ cosh = arithPromoteRanked cosh
+ tanh = arithPromoteRanked tanh
+ asinh = arithPromoteRanked asinh
+ acosh = arithPromoteRanked acosh
+ atanh = arithPromoteRanked atanh
+ log1p = arithPromoteRanked GHC.Float.log1p
+ expm1 = arithPromoteRanked GHC.Float.expm1
+ log1pexp = arithPromoteRanked GHC.Float.log1pexp
+ log1mexp = arithPromoteRanked GHC.Float.log1mexp
+
+
+rshape :: forall n a. Elt a => Ranked n a -> IShR n
+rshape (Ranked arr) = shCvtXR' (mshape arr)
+
+rindex :: Elt a => Ranked n a -> IIxR n -> a
+rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
+
+rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a
+rindexPartial (Ranked arr) idx =
+ Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
+ (castWith (subst2 (lemReplicatePlusApp (ixrToSNat idx) (Proxy @m) (Proxy @Nothing))) arr)
+ (ixCvtRX idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a
+rgenerate sh f
+ | sn@SNat <- shrToSNat sh
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemRankReplicate sn
+ = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
+
+-- | See the documentation of 'mlift'.
+rlift :: forall n1 n2 a. Elt a
+ => SNat n2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr)
+
+-- | See the documentation of 'mlift2'.
+rlift2 :: forall n1 n2 n3 a. Elt a
+ => SNat n3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a
+rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2)
+
+rsumOuter1P :: forall n a.
+ (Storable a, NumElt a)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
+rsumOuter1P (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (msumOuter1P arr)
+
+rsumOuter1 :: forall n a. (NumElt a, PrimElt a)
+ => Ranked (n + 1) a -> Ranked n a
+rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive
+
+rtranspose :: forall n a. Elt a => [Int] -> Ranked n a -> Ranked n a
+rtranspose perm arr
+ | sn@SNat <- shrToSNat (rshape arr)
+ , Dict <- lemKnownReplicate sn
+ , length perm <= fromIntegral (natVal (Proxy @n))
+ = rlift sn
+ (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm)
+ arr
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
+
+rappend :: forall n a. Elt a
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend arr1 arr2
+ | sn@SNat <- shrToSNat (rshape arr1)
+ , Dict <- lemKnownReplicate sn
+ , Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
+ arr1 arr2
+
+rscalar :: Elt a => a -> Ranked 0 a
+rscalar x = Ranked (mscalar x)
+
+rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP sh v
+ | Dict <- lemKnownReplicate (shrToSNat sh)
+ = Ranked (mfromVectorP (shCvtRX sh) v)
+
+rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a
+rfromVector sh v = rfromPrimitive (rfromVectorP sh v)
+
+rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
+rtoVectorP = coerce mtoVectorP
+
+rtoVector :: PrimElt a => Ranked n a -> VS.Vector a
+rtoVector = coerce mtoVector
+
+rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a
+rfromListOuter l
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a)))
+
+rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a
+rfromList1 l = Ranked (mfromList1 l)
+
+rfromList1Prim :: PrimElt a => [a] -> Ranked 1 a
+rfromList1Prim l = Ranked (mfromList1Prim l)
+
+rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoListOuter (Ranked arr)
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr)
+
+rtoList1 :: Elt a => Ranked 1 a -> [a]
+rtoList1 = map runScalar . rtoListOuter
+
+rfromListPrim :: PrimElt a => [a] -> Ranked 1 a
+rfromListPrim l =
+ let ssh = SUnknown () :!% ZKX
+ xarr = X.fromList1 ssh l
+ in Ranked $ fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr
+
+rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a
+rfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Ranked $ fromPrimitive $ M_Primitive (shCvtRX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtRX sh) xarr)
+
+rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a
+rfromOrthotope sn arr
+ | Refl <- lemRankReplicate sn
+ = let xarr = XArray arr
+ in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr))
+
+runScalar :: Elt a => Ranked 0 a -> a
+runScalar arr = rindex arr ZIR
+
+rrerankP :: forall n1 n2 n a b. (Storable a, Storable b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b))
+ -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b)
+rrerankP sn sh2 f (Ranked arr)
+ | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat))
+ , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat))
+ = Ranked (mrerankP (ssxFromSNat sn) (shCvtRX sh2)
+ (\a -> let Ranked r = f (Ranked a) in r)
+ arr)
+
+-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the
+-- input array, then there is no way to deduce the full shape of the output
+-- array (more precisely, the @n2@ part): that could only come from calling
+-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in
+-- this case; we choose to fill the @n2@ part of the output shape with zeros.
+--
+-- For example, if:
+--
+-- @
+-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21]
+-- f :: Ranked 2 Int -> Ranked 3 Float
+-- @
+--
+-- then:
+--
+-- @
+-- rrerank _ _ _ f arr :: Ranked 5 Float
+-- @
+--
+-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the
+-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended
+-- to return an array with shape all-0 here (it probably didn't), but there is
+-- no better number to put here absent a subarray of the input to pass to @f@.
+rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b)
+ => SNat n -> IShR n2
+ -> (Ranked n1 a -> Ranked n2 b)
+ -> Ranked (n + n1) a -> Ranked (n + n2) b
+rrerank ssh sh2 f (rtoPrimitive -> arr) =
+ rfromPrimitive $ rrerankP ssh sh2 (rtoPrimitive . f . rfromPrimitive) arr
+
+rreplicate :: forall n m a. Elt a
+ => IShR n -> Ranked m a -> Ranked (n + m) a
+rreplicate sh (Ranked arr)
+ | Refl <- lemReplicatePlusApp (shrToSNat sh) (Proxy @m) (Proxy @(Nothing @Nat))
+ = Ranked (mreplicate (shCvtRX sh) arr)
+
+rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a)
+rreplicateScalP sh x
+ | Dict <- lemKnownReplicate (shrToSNat sh)
+ = Ranked (mreplicateScalP (shCvtRX sh) x)
+
+rreplicateScal :: forall n a. PrimElt a
+ => IShR n -> a -> Ranked n a
+rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x)
+
+rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
+rslice i n arr
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n
+ = rlift (shrToSNat (rshape arr))
+ (\_ -> X.sliceU i n)
+ arr
+
+rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 arr =
+ rlift (shrToSNat (rshape arr))
+ (\(_ :: StaticShX sh') ->
+ case lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh'))
+ arr
+
+rreshape :: forall n n' a. Elt a
+ => IShR n' -> Ranked n a -> Ranked n' a
+rreshape sh' rarr@(Ranked arr)
+ | Dict <- lemKnownReplicate (shrToSNat (rshape rarr))
+ , Dict <- lemKnownReplicate (shrToSNat sh')
+ = Ranked (mreshape (shCvtRX sh') arr)
+
+riota :: (Enum a, PrimElt a, Elt a) => Int -> Ranked 1 a
+riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota
+
+rasXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a)
+rasXArrayPrimP (Ranked arr) = first shCvtXR' (masXArrayPrimP arr)
+
+rasXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a)
+rasXArrayPrim (Ranked arr) = first shCvtXR' (masXArrayPrim arr)
+
+rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShape (X.shape (ssxFromSNat sn) arr)) arr)
+
+rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a
+rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr)
+
+rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a)
+rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr)
+
+mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a
+mtoRanked arr
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(Replicate (Rank sh) (Nothing @Nat))
+ , Refl <- lemRankReplicate (shxRank (mshape arr))
+ = Ranked (mcast (ssxFromShape (mshape arr)) (convSh (mshape arr)) (Proxy @'[]) arr)
+ where
+ convSh :: IShX sh' -> IShX (Replicate (Rank sh') Nothing)
+ convSh ZSX = ZSX
+ convSh (smn :$% (sh :: IShX sh'T))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(Rank sh'T)
+ = SUnknown (fromSMayNat' smn) :$% convSh sh
diff --git a/src/Data/Array/Nested/Shape.hs b/src/Data/Array/Nested/Shape.hs
new file mode 100644
index 0000000..774b4bd
--- /dev/null
+++ b/src/Data/Array/Nested/Shape.hs
@@ -0,0 +1,467 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RoleAnnotations #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Shape where
+
+import Data.Array.Mixed.Types
+import Data.Coerce (coerce)
+import Data.Foldable qualified as Foldable
+import Data.Functor.Const
+import Data.Kind (Type, Constraint)
+import Data.Monoid (Sum(..))
+import Data.Proxy
+import Data.Type.Equality
+import GHC.IsList (IsList)
+import GHC.IsList qualified as IsList
+import GHC.TypeLits
+import GHC.TypeNats qualified as TN
+
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+
+
+type role ListR nominal representational
+type ListR :: Nat -> Type -> Type
+data ListR n i where
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
+deriving instance Eq i => Eq (ListR n i)
+deriving instance Ord i => Ord (ListR n i)
+deriving instance Functor (ListR n)
+deriving instance Foldable (ListR n)
+infixr 3 :::
+
+instance Show i => Show (ListR n i) where
+ showsPrec _ = listrShow shows
+
+data UnconsListRRes i n1 =
+ forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i
+listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
+listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
+listrUncons ZR = Nothing
+
+listrShow :: forall sh i. (i -> ShowS) -> ListR sh i -> ShowS
+listrShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListR sh' i -> ShowS
+ go _ ZR = id
+ go prefix (x ::: xs) = showString prefix . f x . go "," xs
+
+listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i
+listrAppend ZR sh = sh
+listrAppend (x ::: xs) sh = x ::: listrAppend xs sh
+
+listrFromList :: [i] -> (forall n. ListR n i -> r) -> r
+listrFromList [] k = k ZR
+listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l)
+
+listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i
+listrIndex SZ (x ::: _) = x
+listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs
+listrIndex _ ZR = error "k + 1 <= 0"
+
+listrToSNat :: ListR n i -> SNat n
+listrToSNat ZR = SNat
+listrToSNat (_ ::: (l :: ListR n i)) | SNat <- listrToSNat l, Dict <- lemKnownNatSucc @n = SNat
+
+listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i
+listrPermutePrefix = \perm sh ->
+ listrFromList perm $ \sperm ->
+ case (listrToSNat sperm, listrToSNat sh) of
+ (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of
+ LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post
+ GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")"
+ ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")"
+ where
+ listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i)
+ listrSplitAt SZ sh = (ZR, sh)
+ listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh)
+ listrSplitAt SS{} ZR = error "m' + 1 <= 0"
+
+ applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i
+ applyPermRFull _ ZR _ = ZR
+ applyPermRFull sm@SNat (i ::: perm) l =
+ TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) ->
+ case cmpNat (SNat @(idx + 1)) sm of
+ LTI -> listrIndex si l ::: applyPermRFull sm perm l
+ EQI -> listrIndex si l ::: applyPermRFull sm perm l
+ GTI -> error "listrPermutePrefix: Index in permutation out of range"
+
+
+-- | An index into a rank-typed array.
+type role IxR nominal representational
+type IxR :: Nat -> Type -> Type
+newtype IxR n i = IxR (ListR n i)
+ deriving (Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
+pattern ZIR = IxR ZR
+
+pattern (:.:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> IxR n i -> IxR n1 i
+pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i))
+ where i :.: IxR sh = IxR (i ::: sh)
+infixr 3 :.:
+
+{-# COMPLETE ZIR, (:.:) #-}
+
+type IIxR n = IxR n Int
+
+instance Show i => Show (IxR n i) where
+ showsPrec _ (IxR l) = listrShow shows l
+
+ixrZero :: SNat n -> IIxR n
+ixrZero SZ = ZIR
+ixrZero (SS n) = 0 :.: ixrZero n
+
+ixCvtXR :: IIxX sh -> IIxR (Rank sh)
+ixCvtXR ZIX = ZIR
+ixCvtXR (n :.% idx) = n :.: ixCvtXR idx
+
+ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
+ixCvtRX ZIR = ZIX
+ixCvtRX (n :.: (idx :: IxR m Int)) =
+ castWith (subst2 @IxX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (n :.% ixCvtRX idx)
+
+ixrToSNat :: IxR n i -> SNat n
+ixrToSNat (IxR sh) = listrToSNat sh
+
+ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i
+ixrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+type role ShR nominal representational
+type ShR :: Nat -> Type -> Type
+newtype ShR n i = ShR (ListR n i)
+ deriving (Eq, Ord)
+ deriving newtype (Functor, Foldable)
+
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
+pattern ZSR = ShR ZR
+
+pattern (:$:)
+ :: forall {n1} {i}.
+ forall n. (n + 1 ~ n1)
+ => i -> ShR n i -> ShR n1 i
+pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
+ where i :$: (ShR sh) = ShR (i ::: sh)
+infixr 3 :$:
+
+{-# COMPLETE ZSR, (:$:) #-}
+
+type IShR n = ShR n Int
+
+instance Show i => Show (ShR n i) where
+ showsPrec _ (ShR l) = listrShow shows l
+
+shCvtXR' :: forall n. IShX (Replicate n Nothing) -> IShR n
+shCvtXR' ZSX =
+ castWith (subst2 (unsafeCoerceRefl :: 0 :~: n))
+ ZSR
+shCvtXR' (n :$% (idx :: IShX sh))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @(n - 1) =
+ castWith (subst2 (lem1 @sh Refl))
+ (fromSMayNat' n :$: shCvtXR' (castWith (subst2 (lem2 Refl)) idx))
+ where
+ lem1 :: forall sh' n' k.
+ k : sh' :~: Replicate n' Nothing
+ -> Rank sh' + 1 :~: n'
+ lem1 Refl = unsafeCoerceRefl
+
+ lem2 :: k : sh :~: Replicate n Nothing
+ -> sh :~: Replicate (Rank sh) Nothing
+ lem2 Refl = unsafeCoerceRefl
+
+shCvtRX :: IShR n -> IShX (Replicate n Nothing)
+shCvtRX ZSR = ZSX
+shCvtRX (n :$: (idx :: ShR m Int)) =
+ castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
+ (SUnknown n :$% shCvtRX idx)
+
+-- | The number of elements in an array described by this shape.
+shrSize :: IShR n -> Int
+shrSize ZSR = 1
+shrSize (n :$: sh) = n * shrSize sh
+
+shrToSNat :: ShR n i -> SNat n
+shrToSNat (ShR sh) = listrToSNat sh
+
+shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i
+shrPermutePrefix = coerce (listrPermutePrefix @i)
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ListR n i) where
+ type Item (ListR n i) = i
+ fromList = go (SNat @n)
+ where
+ go :: SNat n' -> [i] -> ListR n' i
+ go SZ [] = ZR
+ go (SS n) (i : is) = i ::: go n is
+ go _ _ = error "IsList(ListR): Mismatched list length"
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (IxR n i) where
+ type Item (IxR n i) = i
+ fromList = IxR . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length is checked at runtime.
+instance KnownNat n => IsList (ShR n i) where
+ type Item (ShR n i) = i
+ fromList = ShR . IsList.fromList
+ toList = Foldable.toList
+
+
+type role ListS nominal representational
+type ListS :: [Nat] -> (Nat -> Type) -> Type
+data ListS sh f where
+ ZS :: ListS '[] f
+ -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity
+ (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f
+deriving instance (forall n. Eq (f n)) => Eq (ListS sh f)
+deriving instance (forall n. Ord (f n)) => Ord (ListS sh f)
+infixr 3 ::$
+
+instance (forall n. Show (f n)) => Show (ListS sh f) where
+ showsPrec _ = listsShow shows
+
+data UnconsListSRes f sh1 =
+ forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n)
+listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
+listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
+listsUncons ZS = Nothing
+
+listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
+listsFmap _ ZS = ZS
+listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs
+
+listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m
+listsFold _ ZS = mempty
+listsFold f (x ::$ xs) = f x <> listsFold f xs
+
+listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS
+listsShow f l = showString "[" . go "" l . showString "]"
+ where
+ go :: String -> ListS sh' f -> ShowS
+ go _ ZS = id
+ go prefix (x ::$ xs) = showString prefix . f x . go "," xs
+
+listsToList :: ListS sh (Const i) -> [i]
+listsToList ZS = []
+listsToList (Const i ::$ is) = i : listsToList is
+
+listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f
+listsAppend ZS idx' = idx'
+listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx'
+
+listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f
+listsTakeLenPerm PNil _ = ZS
+listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh
+listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f
+listsDropLenPerm PNil sh = sh
+listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh
+listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape"
+
+listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f
+listsPermute PNil _ = ZS
+listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) =
+ case listsIndex (Proxy @is') (Proxy @sh) i sh of
+ (item, SNat) -> item ::$ listsPermute is sh
+
+-- TODO: remove this SNat when the KnownNat constaint in ListS is removed
+listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh))
+listsIndex _ _ SZ (n ::$ _) = (n, SNat)
+listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f))
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = listsIndex p pT i sh
+listsIndex _ _ _ ZS = error "Index into empty shape"
+
+shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh)
+shsTakeLen = coerce (listsTakeLenPerm @SNat)
+
+shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh)
+shsPermute = coerce (listsPermute @SNat)
+
+shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh)
+shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh)))
+
+applyPermS :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f
+applyPermS perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh)
+
+applyPermIxS :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i
+applyPermIxS = coerce (applyPermS @(Const i))
+
+applyPermShS :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh)
+applyPermShS = coerce (applyPermS @SNat)
+
+
+-- | An index into a shape-typed array.
+--
+-- For convenience, this contains regular 'Int's instead of bounded integers
+-- (traditionally called \"@Fin@\"). Note that because the shape of a
+-- shape-typed array is known statically, you can also retrieve the array shape
+-- from a 'KnownShape' dictionary.
+type role IxS nominal representational
+type IxS :: [Nat] -> Type -> Type
+newtype IxS sh i = IxS (ListS sh (Const i))
+ deriving (Eq, Ord)
+
+pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i
+pattern ZIS = IxS ZS
+
+pattern (:.$)
+ :: forall {sh1} {i}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => i -> IxS sh i -> IxS sh1 i
+pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i)))
+ where i :.$ IxS shl = IxS (Const i ::$ shl)
+infixr 3 :.$
+
+{-# COMPLETE ZIS, (:.$) #-}
+
+type IIxS sh = IxS sh Int
+
+instance Show i => Show (IxS sh i) where
+ showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l
+
+instance Functor (IxS sh) where
+ fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l)
+
+instance Foldable (IxS sh) where
+ foldMap f (IxS l) = listsFold (f . getConst) l
+
+ixsZero :: ShS sh -> IIxS sh
+ixsZero ZSS = ZIS
+ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh
+
+ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh
+ixCvtXS ZSS ZIX = ZIS
+ixCvtXS (_ :$$ sh) (n :.% idx) = n :.$ ixCvtXS sh idx
+
+ixCvtSX :: IIxS sh -> IIxX (MapJust sh)
+ixCvtSX ZIS = ZIX
+ixCvtSX (n :.$ sh) = n :.% ixCvtSX sh
+
+
+-- | The shape of a shape-typed array given as a list of 'SNat' values.
+type role ShS nominal
+type ShS :: [Nat] -> Type
+newtype ShS sh = ShS (ListS sh SNat)
+ deriving (Eq, Ord)
+
+pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh
+pattern ZSS = ShS ZS
+
+pattern (:$$)
+ :: forall {sh1}.
+ forall n sh. (KnownNat n, n : sh ~ sh1)
+ => SNat n -> ShS sh -> ShS sh1
+pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i))
+ where i :$$ ShS shl = ShS (i ::$ shl)
+
+infixr 3 :$$
+
+{-# COMPLETE ZSS, (:$$) #-}
+
+instance Show (ShS sh) where
+ showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
+
+shsLength :: ShS sh -> Int
+shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)
+
+shsToList :: ShS sh -> [Int]
+shsToList ZSS = []
+shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh
+
+shCvtXS' :: forall sh. IShX (MapJust sh) -> ShS sh
+shCvtXS' ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS
+shCvtXS' (SKnown n@SNat :$% (idx :: IShX mjshT)) =
+ castWith (subst1 (lem Refl)) $
+ n :$$ shCvtXS' @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh)))
+ idx)
+ where
+ lem :: forall sh1 sh' n.
+ Just n : sh1 :~: MapJust sh'
+ -> n : Tail sh' :~: sh'
+ lem Refl = unsafeCoerceRefl
+shCvtXS' (SUnknown _ :$% _) = error "impossible"
+
+shCvtSX :: ShS sh -> IShX (MapJust sh)
+shCvtSX ZSS = ZSX
+shCvtSX (n :$$ sh) = SKnown n :$% shCvtSX sh
+
+shsSize :: ShS sh -> Int
+shsSize ZSS = 1
+shsSize (n :$$ sh) = fromSNat' n * shsSize sh
+
+-- | Evidence for the static part of a shape. This pops up only when you are
+-- polymorphic in the element type of an array.
+type KnownShS :: [Nat] -> Constraint
+class KnownShS sh where knownShS :: ShS sh
+instance KnownShS '[] where knownShS = ZSS
+instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS
+
+
+-- | Untyped: length is checked at runtime.
+instance KnownShS sh => IsList (ListS sh (Const i)) where
+ type Item (ListS sh (Const i)) = i
+ fromList topl = go (knownShS @sh) topl
+ where
+ go :: ShS sh' -> [i] -> ListS sh' (Const i)
+ go ZSS [] = ZS
+ go (_ :$$ sh) (i : is) = Const i ::$ go sh is
+ go _ _ = error $ "IsList(ListS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = listsToList
+
+-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__.
+instance KnownShS sh => IsList (IxS sh i) where
+ type Item (IxS sh i) = i
+ fromList = IxS . IsList.fromList
+ toList = Foldable.toList
+
+-- | Untyped: length and values are checked at runtime.
+instance KnownShS sh => IsList (ShS sh) where
+ type Item (ShS sh) = Int
+ fromList topl = ShS (go (knownShS @sh) topl)
+ where
+ go :: ShS sh' -> [Int] -> ListS sh' SNat
+ go ZSS [] = ZS
+ go (sn :$$ sh) (i : is)
+ | i == fromSNat' sn = sn ::$ go sh is
+ | otherwise = error $ "IsList(ShS): Value does not match typing (type says "
+ ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")"
+ go _ _ = error $ "IsList(ShS): Mismatched list length (type says "
+ ++ show (shsLength (knownShS @sh)) ++ ", list has length "
+ ++ show (length topl) ++ ")"
+ toList = shsToList
diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs
new file mode 100644
index 0000000..934433e
--- /dev/null
+++ b/src/Data/Array/Nested/Shaped.hs
@@ -0,0 +1,379 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Nested.Shaped where
+
+import Prelude hiding (mappend)
+
+import Control.DeepSeq (NFData)
+import Control.Monad.ST
+import Data.Bifunctor (first)
+import Data.Coerce (coerce)
+import Data.Kind (Type)
+import Data.List.NonEmpty (NonEmpty)
+import Data.Proxy
+import Data.Type.Equality
+import Data.Vector.Storable qualified as VS
+import Foreign.Storable (Storable)
+import GHC.Float qualified (log1p, expm1, log1pexp, log1mexp)
+import GHC.TypeLits
+
+import Data.Array.Mixed (XArray)
+import Data.Array.Mixed qualified as X
+import Data.Array.Mixed.Internal.Arith
+import Data.Array.Mixed.Lemmas
+import Data.Array.Mixed.Permutation
+import Data.Array.Mixed.Shape
+import Data.Array.Mixed.Types
+import Data.Array.Nested.Lemmas
+import Data.Array.Nested.Mixed
+import Data.Array.Nested.Shape
+
+
+-- | A shape-typed array: the full shape of the array (the sizes of its
+-- dimensions) is represented on the type level as a list of 'Nat's. Note that
+-- these are "GHC.TypeLits" naturals, because we do not need induction over
+-- them and we want very large arrays to be possible.
+--
+-- Like for 'Ranked', the valid elements are described by the 'Elt' type class,
+-- and 'Shaped' itself is again an instance of 'Elt' as well.
+--
+-- 'Shaped' is a newtype around a 'Mixed' of 'Just's.
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a)
+deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a)
+deriving instance Ord (Mixed '[] a) => Ord (Shaped '[] a)
+deriving instance NFData (Mixed (MapJust sh) a) => NFData (Shaped sh a)
+
+-- just unwrap the newtype and defer to the general instance for nested arrays
+newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))
+deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a))
+
+newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
+
+instance Elt a => Elt (Shaped sh a) where
+ mshape (M_Shaped arr) = mshape arr
+ mindex (M_Shaped arr) i = Shaped (mindex arr i)
+
+ mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ mindexPartial (M_Shaped arr) i =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mindexPartial arr i
+
+ mscalar (Shaped x) = M_Shaped (M_Nest ZSX x)
+
+ mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a)
+ mfromListOuter l = M_Shaped (mfromListOuter (coerce l))
+
+ mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)]
+ mtoListOuter (M_Shaped arr)
+ = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr)
+
+ mlift :: forall sh1 sh2.
+ StaticShX sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
+ mlift ssh2 f (M_Shaped arr) =
+ coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
+ mlift ssh2 f arr
+
+ mlift2 :: forall sh1 sh2 sh3.
+ StaticShX sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b)
+ -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a)
+ mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) =
+ coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $
+ mlift2 ssh3 f arr1 arr2
+
+ mcast ssh1 sh2 psh' (M_Shaped arr) = M_Shaped (mcast ssh1 sh2 psh' arr)
+
+ mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr)
+
+ type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a)
+
+ mshapeTree (Shaped arr) = first shCvtXS' (mshapeTree arr)
+
+ mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2
+
+ mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t
+
+ mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")"
+
+ mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
+ mvecsWrite sh idx (Shaped arr) vecs =
+ mvecsWrite sh idx arr
+ (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsWritePartial :: forall sh1 sh2 s.
+ IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a)
+ -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
+ -> ST s ()
+ mvecsWritePartial sh idx arr vecs =
+ mvecsWritePartial sh idx
+ (coerce @(Mixed sh2 (Shaped sh a))
+ @(Mixed sh2 (Mixed (MapJust sh) a))
+ arr)
+ (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
+ @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
+ vecs)
+
+ mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
+ mvecsFreeze sh vecs =
+ coerce @(Mixed sh' (Mixed (MapJust sh) a))
+ @(Mixed sh' (Shaped sh a))
+ <$> mvecsFreeze sh
+ (coerce @(MixedVecs s sh' (Shaped sh a))
+ @(MixedVecs s sh' (Mixed (MapJust sh) a))
+ vecs)
+
+instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where
+ memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a)
+ memptyArray i
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
+ memptyArray i
+
+ mvecsUnsafeNew idx (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsUnsafeNew idx arr
+
+ mvecsNewEmpty _
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a))
+
+
+arithPromoteShaped :: forall sh a. PrimElt a
+ => (forall shx. Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a
+arithPromoteShaped = coerce
+
+arithPromoteShaped2 :: forall sh a. PrimElt a
+ => (forall shx. Mixed shx a -> Mixed shx a -> Mixed shx a)
+ -> Shaped sh a -> Shaped sh a -> Shaped sh a
+arithPromoteShaped2 = coerce
+
+instance (NumElt a, PrimElt a) => Num (Shaped sh a) where
+ (+) = arithPromoteShaped2 (+)
+ (-) = arithPromoteShaped2 (-)
+ (*) = arithPromoteShaped2 (*)
+ negate = arithPromoteShaped negate
+ abs = arithPromoteShaped abs
+ signum = arithPromoteShaped signum
+ fromInteger _ = error "Data.Array.Nested.fromIntegral: No singletons available, use explicit sreplicateScal"
+
+instance (FloatElt a, NumElt a, PrimElt a) => Fractional (Shaped sh a) where
+ fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal"
+ recip = arithPromoteShaped recip
+ (/) = arithPromoteShaped2 (/)
+
+instance (FloatElt a, NumElt a, PrimElt a) => Floating (Shaped sh a) where
+ pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal"
+ exp = arithPromoteShaped exp
+ log = arithPromoteShaped log
+ sqrt = arithPromoteShaped sqrt
+ (**) = arithPromoteShaped2 (**)
+ logBase = arithPromoteShaped2 logBase
+ sin = arithPromoteShaped sin
+ cos = arithPromoteShaped cos
+ tan = arithPromoteShaped tan
+ asin = arithPromoteShaped asin
+ acos = arithPromoteShaped acos
+ atan = arithPromoteShaped atan
+ sinh = arithPromoteShaped sinh
+ cosh = arithPromoteShaped cosh
+ tanh = arithPromoteShaped tanh
+ asinh = arithPromoteShaped asinh
+ acosh = arithPromoteShaped acosh
+ atanh = arithPromoteShaped atanh
+ log1p = arithPromoteShaped GHC.Float.log1p
+ expm1 = arithPromoteShaped GHC.Float.expm1
+ log1pexp = arithPromoteShaped GHC.Float.log1pexp
+ log1mexp = arithPromoteShaped GHC.Float.log1mexp
+
+
+sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh
+sshape (Shaped arr) = shCvtXS' (mshape arr)
+
+sindex :: Elt a => Shaped sh a -> IIxS sh -> a
+sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
+
+shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh
+shsTakeIx _ _ ZIS = ZSS
+shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx
+
+sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
+sindexPartial sarr@(Shaped arr) idx =
+ Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
+ (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr)
+ (ixCvtSX idx))
+
+-- | __WARNING__: All values returned from the function must have equal shape.
+-- See the documentation of 'mgenerate' for more details.
+sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a
+sgenerate sh f = Shaped (mgenerate (shCvtSX sh) (f . ixCvtXS sh))
+
+-- | See the documentation of 'mlift'.
+slift :: forall sh1 sh2 a. Elt a
+ => ShS sh2
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a
+slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShape (shCvtSX sh2)) f arr)
+
+-- | See the documentation of 'mlift'.
+slift2 :: forall sh1 sh2 sh3 a. Elt a
+ => ShS sh3
+ -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b)
+ -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a
+slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShape (shCvtSX sh3)) f arr1 arr2)
+
+ssumOuter1P :: forall sh n a. (Storable a, NumElt a)
+ => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a)
+ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr)
+
+ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a)
+ => Shaped (n : sh) a -> Shaped sh a
+ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive
+
+stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a)
+ => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a
+stranspose perm sarr@(Shaped arr)
+ | Refl <- lemRankMapJust (sshape sarr)
+ , Refl <- lemMapJustTakeLen perm (sshape sarr)
+ , Refl <- lemMapJustDropLen perm (sshape sarr)
+ , Refl <- lemMapJustPermute perm (shsTakeLen perm (sshape sarr))
+ , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh))
+ = Shaped (mtranspose perm arr)
+
+sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a
+sappend = coerce mappend
+
+sscalar :: Elt a => a -> Shaped '[] a
+sscalar x = Shaped (mscalar x)
+
+sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a)
+sfromVectorP sh v = Shaped (mfromVectorP (shCvtSX sh) v)
+
+sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a
+sfromVector sh v = sfromPrimitive (sfromVectorP sh v)
+
+stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a
+stoVectorP = coerce mtoVectorP
+
+stoVector :: PrimElt a => Shaped sh a -> VS.Vector a
+stoVector = coerce mtoVector
+
+sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a
+sfromListOuter sn l = Shaped (mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy $ mfromListOuter (coerce l))
+
+sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a
+sfromList1 sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1
+
+sfromList1Prim :: (PrimElt a, Elt a) => SNat n -> [a] -> Shaped '[n] a
+sfromList1Prim sn = Shaped . mcast (SUnknown () :!% ZKX) (SKnown sn :$% ZSX) Proxy . mfromList1Prim
+
+stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a]
+stoListOuter (Shaped arr) = coerce (mtoListOuter arr)
+
+stoList1 :: Elt a => Shaped '[n] a -> [a]
+stoList1 = map sunScalar . stoListOuter
+
+sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a
+sfromListPrim sn l
+ | Refl <- lemAppNil @'[Just n]
+ = let ssh = SUnknown () :!% ZKX
+ xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l)
+ in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr
+
+sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a
+sfromListPrimLinear sh l =
+ let M_Primitive _ xarr = toPrimitive (mfromListPrim l)
+ in Shaped $ fromPrimitive $ M_Primitive (shCvtSX sh) (X.reshape (SUnknown () :!% ZKX) (shCvtSX sh) xarr)
+
+sunScalar :: Elt a => Shaped '[] a -> a
+sunScalar arr = sindex arr ZIS
+
+srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b))
+ -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b)
+srerankP sh sh2 f sarr@(Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh1)
+ , Refl <- lemMapJustApp sh (Proxy @sh2)
+ = Shaped (mrerankP (ssxFromShape (shxTakeSSX (Proxy @(MapJust sh1)) (shCvtSX (sshape sarr)) (ssxFromShape (shCvtSX sh))))
+ (shCvtSX sh2)
+ (\a -> let Shaped r = f (Shaped a) in r)
+ arr)
+
+srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b)
+ => ShS sh -> ShS sh2
+ -> (Shaped sh1 a -> Shaped sh2 b)
+ -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b
+srerank sh sh2 f (stoPrimitive -> arr) =
+ sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr
+
+sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a
+sreplicate sh (Shaped arr)
+ | Refl <- lemMapJustApp sh (Proxy @sh')
+ = Shaped (mreplicate (shCvtSX sh) arr)
+
+sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a)
+sreplicateScalP sh x = Shaped (mreplicateScalP (shCvtSX sh) x)
+
+sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a
+sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x)
+
+sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n@SNat arr =
+ let _ :$$ sh = sshape arr
+ in slift (n :$$ sh) (\_ -> X.slice i n) arr
+
+srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a
+srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr
+
+sreshape :: Elt a => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr) = Shaped (mreshape (shCvtSX sh') arr)
+
+siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a
+siota sn = Shaped (miota sn)
+
+sasXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrimP (Shaped arr) = first shCvtXS' (masXArrayPrimP arr)
+
+sasXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a)
+sasXArrayPrim (Shaped arr) = first shCvtXS' (masXArrayPrim arr)
+
+sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShape (shCvtSX sh)) arr)
+
+sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShape (shCvtSX sh)) arr)
+
+sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a
+sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr)
+
+stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a)
+stoPrimitive (Shaped arr) = Shaped (toPrimitive arr)
+
+mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh')
+ => Mixed sh a -> ShS sh' -> Shaped sh' a
+mcastToShaped arr targetsh
+ | Refl <- lemAppNil @sh
+ , Refl <- lemAppNil @(MapJust sh')
+ , Refl <- lemRankMapJust targetsh
+ = Shaped (mcast (ssxFromShape (mshape arr)) (shCvtSX targetsh) (Proxy @'[]) arr)
diff --git a/test/Gen.hs b/test/Gen.hs
index 9652963..3e879b9 100644
--- a/test/Gen.hs
+++ b/test/Gen.hs
@@ -18,7 +18,6 @@ import Foreign
import GHC.TypeLits
import qualified GHC.TypeNats as TN
-import Data.Array.Mixed
import Data.Array.Mixed.Types
import Data.Array.Nested
diff --git a/test/Tests/C.hs b/test/Tests/C.hs
index 1041b2a..53955dc 100644
--- a/test/Tests/C.hs
+++ b/test/Tests/C.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
@@ -9,20 +10,21 @@
module Tests.C where
import Control.Monad
-import qualified Data.Array.RankedS as OR
+import Data.Array.RankedS qualified as OR
import Data.Foldable (toList)
import Data.Type.Equality
import Foreign
import GHC.TypeLits
-import qualified Data.Array.Mixed as X
+import Data.Array.Mixed qualified as X
+import Data.Array.Mixed.Lemmas
import Data.Array.Nested
-import qualified Data.Array.Nested.Internal as I
+import Data.Array.Nested.Mixed
import Hedgehog
import Hedgehog.Internal.Property (forAllT)
-import qualified Hedgehog.Gen as Gen
-import qualified Hedgehog.Range as Range
+import Hedgehog.Gen qualified as Gen
+import Hedgehog.Range qualified as Range
import Test.Tasty
import Test.Tasty.Hedgehog
@@ -46,8 +48,8 @@ tests = testGroup "C"
(\w -> fromIntegral w / fromIntegral (maxBound :: Word64))
let rarr = rfromOrthotope inrank arr
-- annotateShow rarr
- Refl <- return $ I.lemRankReplicate outrank
- let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
+ Refl <- return $ lemRankReplicate outrank
+ let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
let rhs = orSumOuter1 outrank arr
-- annotateShow lhs
-- annotateShow rhs
@@ -66,8 +68,8 @@ tests = testGroup "C"
-- traceM ("sh: " ++ show sh ++ " -> " ++ show (product sh))
let arr = OR.fromList @Double @(n + 1) (toList sh) []
let rarr = rfromOrthotope inrank arr
- Refl <- return $ I.lemRankReplicate outrank
- let Ranked (I.M_Double (I.M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
+ Refl <- return $ lemRankReplicate outrank
+ let Ranked (M_Double (M_Primitive _ (X.XArray lhs))) = rsumOuter1 rarr
OR.toList lhs === []
]
]