aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Mixed.hs266
-rw-r--r--src/Data/Array/Nested.hs19
-rw-r--r--src/Data/Array/Nested/Internal.hs381
3 files changed, 505 insertions, 161 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 0351beb..ce18431 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
@@ -7,12 +8,16 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
@@ -21,14 +26,18 @@ import qualified Data.Array.Ranked as ORB
import Data.Coerce
import Data.Kind
import Data.Proxy
+import Data.Type.Bool
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
+import GHC.TypeError
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
-import Data.INat
+-- | Evidence for the constraint @c a@.
+data Dict c a where
+ Dict :: c a => Dict c a
-- | The 'SNat' pattern synonym is complete, but it doesn't have a
-- @COMPLETE@ pragma. This copy of it does.
@@ -39,6 +48,28 @@ pattern GHC_SNat = SNat
fromSNat' :: SNat n -> Int
fromSNat' = fromIntegral . fromSNat
+pattern SZ :: () => (n ~ 0) => SNat n
+pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl)
+ where SZ = SNat
+
+pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1
+pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl))
+ where SS = snatSucc
+
+{-# COMPLETE SZ, SS #-}
+
+snatSucc :: SNat n -> SNat (n + 1)
+snatSucc SNat = SNat
+
+data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1)
+snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1)
+snatPred snp1 =
+ withKnownNat snp1 $
+ case cmpNat (Proxy @1) (Proxy @np1) of
+ LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl)
+ GTI -> Nothing
+
-- | Type-level list append.
type family l1 ++ l2 where
@@ -51,6 +82,11 @@ lemAppNil = unsafeCoerce Refl
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
ZIX :: IxX '[] i
@@ -103,11 +139,11 @@ instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :!$? knownShapeX
type family Rank sh where
- Rank '[] = Z
- Rank (_ : sh) = S (Rank sh)
+ Rank '[] = 0
+ Rank (_ : sh) = 1 + Rank sh
type XArray :: [Maybe Nat] -> Type -> Type
-newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
+newtype XArray sh a = XArray (S.Array (Rank sh) a)
deriving (Show)
zeroIxX :: StaticShX sh -> IIxX sh
@@ -157,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX
ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
ssxToShape' (_ :!$? _) = Nothing
+lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a
+lemReplicateSucc = unsafeCoerce Refl
+
+ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing)
+ssxReplicate SZ = ZKSX
+ssxReplicate (SS (n :: SNat n'))
+ | Refl <- lemReplicateSucc @(Nothing @Nat) @n'
+ = () :!$? ssxReplicate n
+
fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
@@ -211,23 +256,28 @@ ssxIotaFrom _ ZKSX = []
ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
+staticShapeFrom :: IShX sh -> StaticShX sh
+staticShapeFrom ZSX = ZKSX
+staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
+staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh
+
lemRankApp :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank ZSX = Dict
-lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict
-lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX ZKSX = Dict
-lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX ZKSX = Dict
+lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
lemKnownShapeX ZKSX = Dict
@@ -254,8 +304,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.fromVector (shapeLshape sh) v)
toVector :: Storable a => XArray sh a -> VS.Vector a
@@ -269,15 +318,14 @@ unScalar (XArray a) = S.unScalar a
constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh) x)
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
--- generateM sh f | Dict <- lemKnownINatRank sh =
+-- generateM sh f | Dict <- lemKnownNatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
@@ -300,8 +348,7 @@ type family AddMaybe n m where
append :: forall n m sh a. (KnownShapeX sh, Storable a)
=> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append (XArray a) (XArray b)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
= XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
@@ -310,21 +357,18 @@ rerank :: forall sh sh1 sh2 a b.
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
+ = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a -> unXArray (f (XArray a)))
arr)
where
unXArray (XArray a) = a
-rerankTop :: forall sh sh1 sh2 a b.
+rerankTop :: forall sh1 sh2 sh a b.
(Storable a, Storable b)
=> StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
-> (XArray sh1 a -> XArray sh2 b)
@@ -337,26 +381,135 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
+ = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
unXArray (XArray a) = a
+type family Elem x l where
+ Elem x '[] = 'False
+ Elem x (x : _) = 'True
+ Elem x (_ : ys) = Elem x ys
+
+type family AllElem' as bs where
+ AllElem' '[] bs = 'True
+ AllElem' (a : as) bs = Elem a bs && AllElem' as bs
+
+type AllElem as bs = Assert (AllElem' as bs)
+ (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs))
+
+type family Count i n where
+ Count n n = '[]
+ Count i n = i : Count (i + 1) n
+
+type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as)
+
+type family Index i sh where
+ Index 0 (n : sh) = n
+ Index i (_ : sh) = Index (i - 1) sh
+
+type family Permute is sh where
+ Permute '[] sh = '[]
+ Permute (i : is) sh = Index i sh : Permute is sh
+
+type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh
+
+data HList f list where
+ HNil :: HList f '[]
+ HCons :: f a -> HList f l -> HList f (a : l)
+infixr 5 `HCons`
+
+foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m
+foldHList _ HNil = mempty
+foldHList f (x `HCons` l) = f x <> foldHList f l
+
+class KnownNatList l where makeNatList :: HList SNat l
+instance KnownNatList '[] where makeNatList = HNil
+instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList
+
+type family TakeLen ref l where
+ TakeLen '[] l = '[]
+ TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs
+
+type family DropLen ref l where
+ DropLen '[] l = l
+ DropLen (_ : ref) (_ : xs) = DropLen ref xs
+
+lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is
+lemRankPermute _ HNil = Refl
+lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl
+
+lemRankDropLen :: forall is sh. (Rank is <= Rank sh)
+ => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is
+lemRankDropLen ZKSX HNil = Refl
+lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl
+lemRankDropLen (_ :!$@ _) HNil = Refl
+lemRankDropLen (_ :!$? _) HNil = Refl
+lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0"
+
+lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l
+lemIndexSucc _ _ _ = unsafeCoerce Refl
+
+ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh)
+ssxTakeLen HNil _ = ZKSX
+ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh
+ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh
+ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh)
+ssxDropLen HNil sh = sh
+ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh
+ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape"
+
+ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh)
+ssxPermute HNil _ = ZKSX
+ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh)
+
+ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT)
+ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest
+ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest
+ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest
+ | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh')
+ = ssxIndex p pT i sh rest
+ssxIndex _ _ _ ZKSX _ = error "Index into empty shape"
+
-- | The list argument gives indices into the original dimension list.
-transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh)
+ => HList SNat is
+ -> XArray sh a
+ -> XArray (PermutePrefix is sh) a
transpose perm (XArray arr)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh))
+ , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm
+ , Refl <- lemRankDropLen (knownShapeX @sh) perm
+ = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int]
+ in XArray (S.transpose perm' arr)
+
+-- | The list argument gives indices into the original dimension list.
+--
+-- The permutation (the list) must have length <= @n@. If it is longer, this
+-- function throws.
+transposeUntyped :: forall n sh a.
+ SNat n -> StaticShX sh -> [Int]
+ -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a
+transposeUntyped sn ssh perm (XArray arr)
+ | length perm <= fromSNat' sn
+ , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh)
= XArray (S.transpose perm arr)
+ | otherwise
+ = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type"
transpose2 :: forall sh1 sh2 a.
StaticShX sh1 -> StaticShX sh2
@@ -364,10 +517,8 @@ transpose2 :: forall sh1 sh2 a.
transpose2 ssh1 ssh2 (XArray arr)
| Refl <- lemRankApp ssh1 ssh2
, Refl <- lemRankApp ssh2 ssh1
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2)))
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1)))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
@@ -390,13 +541,12 @@ sumOuter ssh ssh'
fromList1 :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList1 ssh l
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
+ | Dict <- lemKnownNatRankSSX ssh
= case ssh of
m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
"does not match the type (" ++ show (natVal m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
+ _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
@@ -404,13 +554,29 @@ toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh)
(error "Data.Array.Mixed.empty: shape was not empty"))
-slice :: [(Int, Int)] -> XArray sh a -> XArray sh a
-slice ivs (XArray arr) = XArray (S.slice ivs arr)
+slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a
+slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr)
+
+sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a
+sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr)
rev1 :: XArray (n : sh) a -> XArray (n : sh) a
rev1 (XArray arr) = XArray (S.rev [0] arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
+reshape ssh1 sh2 (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh1
+ , Dict <- lemKnownNatRank sh2
+ = XArray (S.reshape (shapeLshape sh2) arr)
+
+-- | Throws if the given array and the target shape do not have the same number of elements.
+reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
+reshapePartial ssh1 ssh' sh2 (XArray arr)
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
+ , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
+ = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr)
diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs
index ec5f0b5..4b455da 100644
--- a/src/Data/Array/Nested.hs
+++ b/src/Data/Array/Nested.hs
@@ -9,9 +9,11 @@ module Data.Array.Nested (
rshape, rindex, rindexPartial, rgenerate, rsumOuter1,
rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar,
rconstant, rfromList, rfromList1, rtoList, rtoList1,
- rslice, rrev1,
+ rslice, rrev1, rreshape,
-- ** Lifting orthotope operations to 'Ranked' arrays
rlift,
+ -- ** Conversions
+ rasXArrayPrim, rfromXArrayPrim,
-- * Shaped arrays
Shaped,
@@ -21,33 +23,36 @@ module Data.Array.Nested (
sshape, sindex, sindexPartial, sgenerate, ssumOuter1,
stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar,
sconstant, sfromList, sfromList1, stoList, stoList1,
- sslice, srev1,
+ sslice, srev1, sreshape,
-- ** Lifting orthotope operations to 'Shaped' arrays
slift,
+ -- ** Conversions
+ sasXArrayPrim, sfromXArrayPrim,
-- * Mixed arrays
Mixed,
IxX(..), IIxX,
KnownShapeX(..), StaticShX(..),
mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar,
- mconstant, mfromList, mtoList, mslice, mrev1,
+ mconstant, mfromList, mtoList, mslice, mrev1, mreshape,
+ -- ** Conversions
+ masXArrayPrim, mfromXArrayPrim,
-- * Array elements
Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2),
PrimElt,
Primitive(..),
- -- * Inductive natural numbers
- module Data.INat,
-
-- * Further utilities / re-exports
type (++),
Storable,
+ HList,
+ Permutation,
+ makeNatList,
) where
import Prelude hiding (mappend)
import Data.Array.Mixed
import Data.Array.Nested.Internal
-import Data.INat
import Foreign.Storable
diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs
index 350eb6f..7bd6565 100644
--- a/src/Data/Array/Nested/Internal.hs
+++ b/src/Data/Array/Nested/Internal.hs
@@ -20,6 +20,7 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-|
@@ -27,9 +28,42 @@ TODO:
* We should be more consistent in whether functions take a 'StaticShX'
argument or a 'KnownShapeX' constraint.
-* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point
- being that we need to do induction over the former, but the latter need to be
- able to get large.
+* Mikolaj wants these:
+
+ About your wishlist of operations: these are already there
+
+ OR.index
+ OR.append
+ OR.transpose
+
+ These can be easily lifted from the definition for XArray (5min work):
+
+ OR.scalar
+ OR.unScalar
+ OR.constant
+
+ These should not be hard:
+
+ OR.fromList
+ ORB.toList . OR.unravel
+ OR.ravel . ORB.fromList
+ OR.slice
+ OR.rev
+ OR.reshape
+
+ though it's a bit unfortunate that we end up needing toList. Looking in
+ horde-ad I see that you seem to need them to do certain operations in Haskell
+ that orthotope doesn't support?
+
+ For this one we'll need to see to what extent you really need it, and what API
+ you'd need precisely:
+
+ OR.rerank
+
+ and for these we have an API design question:
+
+ OR.toVector
+ OR.fromVector
-}
@@ -52,9 +86,8 @@ import qualified Data.Vector.Storable.Mutable as VSM
import Foreign.Storable (Storable)
import GHC.TypeLits
-import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat)
+import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate)
import qualified Data.Array.Mixed as X
-import Data.INat
-- Invariant in the API
@@ -90,35 +123,60 @@ import Data.INat
-- have been marked with [PRIMITIVE ELEMENT TYPES LIST].
-type family Replicate n a where
- Replicate Z a = '[]
- Replicate (S n) a = a : Replicate n a
-
type family MapJust l where
MapJust '[] = '[]
MapJust (x : xs) = Just x : MapJust xs
-lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n))
+
+-- Stupid things that the type checker should be able to figure out in-line, but can't
+
+subst1 :: forall f a b. a :~: b -> f a :~: f b
+subst1 Refl = Refl
+
+subst2 :: forall f c a b. a :~: b -> f a c :~: f b c
+subst2 Refl = Refl
+
+lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l
+lemAppLeft _ Refl = Refl
+
+knownNatSucc :: KnownNat n => Dict KnownNat (n + 1)
+knownNatSucc = Dict
+
+
+lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
+lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
where
- go :: SINat m -> StaticShX (Replicate m Nothing)
+ go :: SNat m -> StaticShX (Replicate m Nothing)
go SZ = ZKSX
- go (SS n) = () :!$? go n
+ go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n
-lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = go (inatSing @n)
+lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
+lemRankReplicate _ = go (natSing @n)
where
- go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
+ go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
go SZ = Refl
- go (SS n) | Refl <- go n = Refl
+ go (SS (n :: SNat nm1))
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1
+ , Refl <- go n
+ = Refl
-lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a
- -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (inatSing @n)
+lemRankMapJust :: forall sh. KnownShape sh => Proxy sh -> X.Rank (MapJust sh) :~: X.Rank sh
+lemRankMapJust _ = go (knownShape @sh)
where
- go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a
+ go :: forall sh'. ShS sh' -> X.Rank (MapJust sh') :~: X.Rank sh'
+ go ZSS = Refl
+ go (_ :$$ sh') | Refl <- go sh' = Refl
+
+lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
+ -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
+lemReplicatePlusApp _ _ _ = go (natSing @n)
+ where
+ go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
go SZ = Refl
- go (SS n) | Refl <- go n = Refl
+ go (SS (n :: SNat n'm1))
+ | Refl <- X.lemReplicateSucc @a @n'm1
+ , Refl <- go n
+ = sym (X.lemReplicateSucc @a @(n'm1 + m))
shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh')
shAppSplit _ ZKSX idx = (ZSX, idx)
@@ -494,10 +552,12 @@ mgenerate sh f = case X.enumShape sh of
mvecsWrite sh idx val vecs
mvecsFreeze sh vecs
-mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a
-mtranspose perm =
- mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh')
- (X.transpose perm))
+mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a
+mtranspose perm
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh)))
+ = mlift $ \(Proxy @sh') ->
+ X.rerankTop (knownShapeX @sh) (knownShapeX @(X.PermutePrefix is sh)) (knownShapeX @sh')
+ (X.transpose perm)
mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a)
=> Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a
@@ -534,12 +594,32 @@ mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a)
=> IShX sh -> a -> Mixed sh a
mconstant sh x = fromPrimitive (mconstantP sh x)
-mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a
-mslice ivs = mlift $ \_ -> X.slice ivs
+mslice :: (KnownShapeX sh, Elt a) => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a
+mslice i n = withKnownNat n $ mlift $ \_ -> X.slice i n
+
+msliceU :: (KnownShapeX sh, Elt a) => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a
+msliceU i n = mlift $ \_ -> X.sliceU i n
mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a
mrev1 = mlift $ \_ -> X.rev1
+mreshape :: forall sh sh' a. (KnownShapeX sh, KnownShapeX sh', Elt a)
+ => IShX sh' -> Mixed sh a -> Mixed sh' a
+mreshape sh' = mlift $ \(_ :: Proxy shIn) ->
+ X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh'
+
+masXArrayPrimP :: Mixed sh (Primitive a) -> XArray sh a
+masXArrayPrimP (M_Primitive arr) = arr
+
+masXArrayPrim :: PrimElt a => Mixed sh a -> XArray sh a
+masXArrayPrim = masXArrayPrimP . toPrimitive
+
+mfromXArrayPrimP :: XArray sh a -> Mixed sh (Primitive a)
+mfromXArrayPrimP = M_Primitive
+
+mfromXArrayPrim :: PrimElt a => XArray sh a -> Mixed sh a
+mfromXArrayPrim = fromPrimitive . mfromXArrayPrimP
+
mliftPrim :: (KnownShapeX sh, Storable a)
=> (a -> a)
-> Mixed sh (Primitive a) -> Mixed sh (Primitive a)
@@ -570,18 +650,15 @@ deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed s
-- | A rank-typed array: the number of dimensions of the array (its /rank/) is
--- represented on the type level as a 'INat'.
+-- represented on the type level as a 'Nat'.
--
-- Valid elements of a ranked arrays are described by the 'Elt' type class.
-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are
-- supported (and are represented as a single, flattened, struct-of-arrays
-- array internally).
--
--- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a
--- type-level natural that supports induction.
---
-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's.
-type Ranked :: INat -> Type -> Type
+type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a)
@@ -611,7 +688,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe
-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
-- these instances allow them to also be used as elements of arrays, thus
-- making them first-class in the API.
-instance (Elt a, KnownINat n) => Elt (Ranked n a) where
+instance (Elt a, KnownNat n) => Elt (Ranked n a) where
mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
@@ -732,13 +809,10 @@ lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
go ZSS = ZKSX
go (n :$$ sh) = n :!$@ go sh
-lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
+lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2
-> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustPlusApp _ _ = go (knownShape @sh1)
- where
- go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
- go ZSS = Refl
- go (_ :$$ sh) | Refl <- go sh = Refl
+lemCommMapJustApp ZSS _ = Refl
+lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl
instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where
mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
@@ -843,37 +917,37 @@ rewriteMixed Refl x = x
-- ====== API OF RANKED ARRAYS ====== --
-arithPromoteRanked :: forall n a. KnownINat n
+arithPromoteRanked :: forall n a. KnownNat n
=> (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a
arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce
-arithPromoteRanked2 :: forall n a. KnownINat n
+arithPromoteRanked2 :: forall n a. KnownNat n
=> (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a)
-> Ranked n a -> Ranked n a -> Ranked n a
arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce
-instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
+instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where
(+) = arithPromoteRanked2 (+)
(-) = arithPromoteRanked2 (-)
(*) = arithPromoteRanked2 (*)
negate = arithPromoteRanked negate
abs = arithPromoteRanked abs
signum = arithPromoteRanked signum
- fromInteger n = case inatSing @n of
+ fromInteger n = case natSing @n of
SZ -> Ranked (M_Primitive (X.scalar (fromInteger n)))
- SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
- \Rank non-zero, use explicit mconstant"
+ _ -> error "Data.Array.Nested.fromIntegral(Ranked): \
+ \Rank non-zero, use explicit mconstant"
-- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types)
-deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int)
-deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double)
+deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int)
+deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double)
type role ListR nominal representational
-type ListR :: INat -> Type -> Type
+type ListR :: Nat -> Type -> Type
data ListR n i where
- ZR :: ListR Z i
- (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i
+ ZR :: ListR 0 i
+ (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i
deriving instance Show i => Show (ListR n i)
deriving instance Eq i => Eq (ListR n i)
deriving instance Ord i => Ord (ListR n i)
@@ -887,23 +961,23 @@ listRToList :: ListR n i -> [i]
listRToList ZR = []
listRToList (i ::: is) = i : listRToList is
-knownListR :: ListR n i -> Dict KnownINat n
+knownListR :: ListR n i -> Dict KnownNat n
knownListR ZR = Dict
-knownListR (_ ::: l) | Dict <- knownListR l = Dict
+knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m
-- | An index into a rank-typed array.
type role IxR nominal representational
-type IxR :: INat -> Type -> Type
+type IxR :: Nat -> Type -> Type
newtype IxR n i = IxR (ListR n i)
deriving (Show, Eq, Ord)
deriving newtype (Functor, Foldable)
-pattern ZIR :: forall n i. () => n ~ Z => IxR n i
+pattern ZIR :: forall n i. () => n ~ 0 => IxR n i
pattern ZIR = IxR ZR
pattern (:.:)
:: forall {n1} {i}.
- forall n. (S n ~ n1)
+ forall n. (n + 1 ~ n1)
=> i -> IxR n i -> IxR n1 i
pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
where i :.: IxR sh = IxR (i ::: sh)
@@ -911,30 +985,30 @@ pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i))
infixr 3 :.:
data UnconsIxRRes i n1 =
- forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i
+ forall n. (n + 1 ~ n1) => UnconsIxRRes (IxR n i) i
unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1)
unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i)
unconsIxR (IxR ZR) = Nothing
type IIxR n = IxR n Int
-knownIxR :: IxR n i -> Dict KnownINat n
+knownIxR :: IxR n i -> Dict KnownNat n
knownIxR (IxR sh) = knownListR sh
type role ShR nominal representational
-type ShR :: INat -> Type -> Type
+type ShR :: Nat -> Type -> Type
newtype ShR n i = ShR (ListR n i)
deriving (Show, Eq, Ord)
deriving newtype (Functor, Foldable)
type IShR n = ShR n Int
-pattern ZSR :: forall n i. () => n ~ Z => ShR n i
+pattern ZSR :: forall n i. () => n ~ 0 => ShR n i
pattern ZSR = ShR ZR
pattern (:$:)
:: forall {n1} {i}.
- forall n. (S n ~ n1)
+ forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
where i :$: (ShR sh) = ShR (i ::: sh)
@@ -942,15 +1016,15 @@ pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i))
infixr 3 :$:
data UnconsShRRes i n1 =
- forall n. S n ~ n1 => UnconsShRRes (ShR n i) i
+ forall n. n + 1 ~ n1 => UnconsShRRes (ShR n i) i
unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1)
unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i)
unconsShR (ShR ZR) = Nothing
-knownShR :: ShR n i -> Dict KnownINat n
+knownShR :: ShR n i -> Dict KnownNat n
knownShR (ShR sh) = knownListR sh
-zeroIxR :: SINat n -> IIxR n
+zeroIxR :: SNat n -> IIxR n
zeroIxR SZ = ZIR
zeroIxR (SS n) = 0 :.: zeroIxR n
@@ -966,18 +1040,18 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx
ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing)
ixCvtRX ZIR = ZIX
-ixCvtRX (n :.: idx) = n :.? ixCvtRX idx
+ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx)
shCvtRX :: IShR n -> IShX (Replicate n Nothing)
shCvtRX ZSR = ZSX
-shCvtRX (n :$: idx) = n :$? shCvtRX idx
+shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx)
shapeSizeR :: IShR n -> Int
shapeSizeR ZSR = 1
shapeSizeR (n :$: sh) = n * shapeSizeR sh
-rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n
+rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n
rshape (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
@@ -986,7 +1060,7 @@ rshape (Ranked arr)
rindex :: Elt a => Ranked n a -> IIxR n -> a
rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a
+rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a
rindexPartial (Ranked arr) idx =
Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
(rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
@@ -1002,47 +1076,54 @@ rgenerate sh f
= Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR))
-- | See the documentation of 'mlift'.
-rlift :: forall n1 n2 a. (KnownINat n2, Elt a)
- => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+rlift :: forall n1 n2 a. (KnownNat n2, Elt a)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
-> Ranked n1 a -> Ranked n2 a
rlift f (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n2)
= Ranked (mlift f arr)
rsumOuter1P :: forall n a.
- (Storable a, Num a, KnownINat n)
- => Ranked (S n) (Primitive a) -> Ranked n (Primitive a)
+ (Storable a, Num a, KnownNat n)
+ => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a)
rsumOuter1P (Ranked arr)
| Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
= Ranked
. coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a))
. X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing))
- . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a)
+ . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a)
$ arr
-rsumOuter1 :: forall n a.
- (Storable a, Num a, PrimElt a, KnownINat n)
- => Ranked (S n) a -> Ranked n a
+rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n)
+ => Ranked (n + 1) a -> Ranked n a
rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive
-rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
-rtranspose perm (Ranked arr)
+rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a
+rtranspose perm
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mtranspose perm arr)
-
-rappend :: forall n a. (KnownINat n, Elt a)
- => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a
-rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend
+ , length perm <= fromIntegral (natVal (Proxy @n))
+ = rlift $ \(Proxy @sh') ->
+ X.transposeUntyped (natSing @n) (knownShapeX @sh') perm
+ | otherwise
+ = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array"
+
+rappend :: forall n a. (KnownNat n, Elt a)
+ => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a
+rappend
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mappend @Nothing @Nothing @(Replicate n Nothing))
-rscalar :: Elt a => a -> Ranked I0 a
+rscalar :: Elt a => a -> Ranked 0 a
rscalar x = Ranked (mscalar x)
-rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
+rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a)
rfromVectorP sh v
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mfromVectorP (shCvtRX sh) v)
-rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
+rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a
rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v)
rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a
@@ -1051,37 +1132,63 @@ rtoVectorP = coerce mtoVectorP
rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a
rtoVector = coerce mtoVector
-rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a
+rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a
rfromList1 l
| Dict <- lemKnownReplicate (Proxy @n)
- = Ranked (mfromList1 (coerce l))
+ , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l))
-rfromList :: Elt a => NonEmpty a -> Ranked I1 a
+rfromList :: Elt a => NonEmpty a -> Ranked 1 a
rfromList = Ranked . mfromList1 . fmap mscalar
-rtoList :: Elt a => Ranked (S n) a -> [Ranked n a]
-rtoList (Ranked arr) = coerce (mtoList1 arr)
+rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a]
+rtoList (Ranked arr)
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr)
-rtoList1 :: Elt a => Ranked I1 a -> [a]
+rtoList1 :: Elt a => Ranked 1 a -> [a]
rtoList1 = map runScalar . rtoList
-runScalar :: Elt a => Ranked I0 a -> a
+runScalar :: Elt a => Ranked 0 a -> a
runScalar arr = rindex arr ZIR
-rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
+rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a)
rconstantP sh x
| Dict <- lemKnownReplicate (Proxy @n)
= Ranked (mconstantP (shCvtRX sh) x)
-rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a)
+rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a)
=> IShR n -> a -> Ranked n a
rconstant sh x = coerce fromPrimitive (rconstantP sh x)
-rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a
-rslice ivs = rlift $ \_ -> X.slice ivs
+rslice :: forall n a. (KnownNat n, Elt a) => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a
+rslice i n
+ | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n
+ = rlift $ \_ -> X.sliceU i n
+
+rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a
+rrev1 = rlift $ \(Proxy @sh') ->
+ case X.lemReplicateSucc @(Nothing @Nat) @n of
+ Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')
-rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a
-rrev1 = rlift $ \_ -> X.rev1
+rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a)
+ => IShR n' -> Ranked n a -> Ranked n' a
+rreshape sh' (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ , Dict <- lemKnownReplicate (Proxy @n')
+ = Ranked (mreshape (shCvtRX sh') arr)
+
+rasXArrayPrimP :: Ranked n (Primitive a) -> XArray (Replicate n Nothing) a
+rasXArrayPrimP (Ranked arr) = masXArrayPrimP arr
+
+rasXArrayPrim :: PrimElt a => Ranked n a -> XArray (Replicate n Nothing) a
+rasXArrayPrim (Ranked arr) = masXArrayPrim arr
+
+rfromXArrayPrimP :: XArray (Replicate n Nothing) a -> Ranked n (Primitive a)
+rfromXArrayPrimP = Ranked . mfromXArrayPrimP
+
+rfromXArrayPrim :: PrimElt a => XArray (Replicate n Nothing) a -> Ranked n a
+rfromXArrayPrim = Ranked . mfromXArrayPrim
-- ====== API OF SHAPED ARRAYS ====== --
@@ -1200,7 +1307,7 @@ sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a
sindexPartial (Shaped arr) idx =
Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
+ (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr)
(ixCvtSX idx))
-- | __WARNING__: All values returned from the function must have equal shape.
@@ -1212,7 +1319,7 @@ sgenerate f
-- | See the documentation of 'mlift'.
slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a)
- => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
+ => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
-> Shaped sh1 a -> Shaped sh2 a
slift f (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh2)
@@ -1234,9 +1341,56 @@ ssumOuter1 :: forall sh n a.
=> Shaped (n : sh) a -> Shaped sh a
ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive
-stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a
+lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh)
+lemCommMapJustTakeLen HNil _ = Refl
+lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl
+lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty"
+
+lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh)
+lemCommMapJustDropLen HNil _ = Refl
+lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl
+lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty"
+
+lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh)
+lemCommMapJustIndex SZ (_ :$$ _) = Refl
+lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh'))
+ | Refl <- lemCommMapJustIndex i sh
+ , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh'))
+ , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = Refl
+lemCommMapJustIndex _ ZSS = error "Index of empty"
+
+lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh)
+lemCommMapJustPermute HNil _ = Refl
+lemCommMapJustPermute (i `HCons` is) sh
+ | Refl <- lemCommMapJustPermute is sh
+ , Refl <- lemCommMapJustIndex i sh
+ = Refl
+
+shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh)
+shTakeLen HNil _ = ZSS
+shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh
+shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape"
+
+shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh)
+shPermute HNil _ = ZSS
+shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh)
+
+shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT)
+shIndex _ _ SZ (n :$$ _) rest = n :$$ rest
+shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest
+ | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh')
+ = shIndex p pT i sh rest
+shIndex _ _ _ ZSS _ = error "Index into empty shape"
+
+stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a
stranspose perm (Shaped arr)
| Dict <- lemKnownMapJust (Proxy @sh)
+ , Refl <- lemRankMapJust (Proxy @sh)
+ , Refl <- lemCommMapJustTakeLen perm (knownShape @sh)
+ , Refl <- lemCommMapJustDropLen perm (knownShape @sh)
+ , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh))
+ , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh))
= Shaped (mtranspose perm arr)
sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a)
@@ -1287,8 +1441,27 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a)
=> a -> Shaped sh a
sconstant x = coerce fromPrimitive (sconstantP @sh x)
-sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a
-sslice ivs = slift $ \_ -> X.slice ivs
+sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a
+sslice i n = withKnownNat n $ slift $ \_ -> X.slice i n
srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a
srev1 = slift $ \_ -> X.rev1
+
+sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a)
+ => ShS sh' -> Shaped sh a -> Shaped sh' a
+sreshape sh' (Shaped arr)
+ | Dict <- lemKnownMapJust (Proxy @sh)
+ , Dict <- lemKnownMapJust (Proxy @sh')
+ = Shaped (mreshape (shCvtSX sh') arr)
+
+sasXArrayPrimP :: Shaped sh (Primitive a) -> XArray (MapJust sh) a
+sasXArrayPrimP (Shaped arr) = masXArrayPrimP arr
+
+sasXArrayPrim :: PrimElt a => Shaped sh a -> XArray (MapJust sh) a
+sasXArrayPrim (Shaped arr) = masXArrayPrim arr
+
+sfromXArrayPrimP :: XArray (MapJust sh) a -> Shaped sh (Primitive a)
+sfromXArrayPrimP = Shaped . mfromXArrayPrimP
+
+sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a
+sfromXArrayPrim = Shaped . mfromXArrayPrim