diff options
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Mixed.hs | 266 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 19 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 381 |
3 files changed, 505 insertions, 161 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 0351beb..ce18431 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} @@ -7,12 +8,16 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed where @@ -21,14 +26,18 @@ import qualified Data.Array.Ranked as ORB import Data.Coerce import Data.Kind import Data.Proxy +import Data.Type.Bool import Data.Type.Equality import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) +import GHC.TypeError import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.INat +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a -- | The 'SNat' pattern synonym is complete, but it doesn't have a -- @COMPLETE@ pragma. This copy of it does. @@ -39,6 +48,28 @@ pattern GHC_SNat = SNat fromSNat' :: SNat n -> Int fromSNat' = fromIntegral . fromSNat +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + -- | Type-level list append. type family l1 ++ l2 where @@ -51,6 +82,11 @@ lemAppNil = unsafeCoerce Refl lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + + type IxX :: [Maybe Nat] -> Type -> Type data IxX sh i where ZIX :: IxX '[] i @@ -103,11 +139,11 @@ instance KnownShapeX sh => KnownShapeX (Nothing : sh) where knownShapeX = () :!$? knownShapeX type family Rank sh where - Rank '[] = Z - Rank (_ : sh) = S (Rank sh) + Rank '[] = 0 + Rank (_ : sh) = 1 + Rank sh type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) +newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show) zeroIxX :: StaticShX sh -> IIxX sh @@ -157,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh ssxToShape' (_ :!$? _) = Nothing +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerce Refl + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKSX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = () :!$? ssxReplicate n + fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx @@ -211,23 +256,28 @@ ssxIotaFrom _ ZKSX = [] ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh +staticShapeFrom :: IShX sh -> StaticShX sh +staticShapeFrom ZSX = ZKSX +staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh +staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh + lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZSX = Dict -lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZKSX = Dict -lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKSX = Dict +lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh lemKnownShapeX ZKSX = Dict @@ -254,8 +304,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.fromVector (shapeLshape sh) v) toVector :: Storable a => XArray sh a -> VS.Vector a @@ -269,15 +318,14 @@ unScalar (XArray a) = S.unScalar a constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a constant sh x - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) x) generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) -- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownINatRank sh = +-- generateM sh f | Dict <- lemKnownNatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) @@ -300,8 +348,7 @@ type family AddMaybe n m where append :: forall n m sh a. (KnownShapeX sh, Storable a) => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append (XArray a) (XArray b) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) = XArray (S.append a b) rerank :: forall sh sh1 sh2 a b. @@ -310,21 +357,18 @@ rerank :: forall sh sh1 sh2 a b. -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough + = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) (\a -> unXArray (f (XArray a))) arr) where unXArray (XArray a) = a -rerankTop :: forall sh sh1 sh2 a b. +rerankTop :: forall sh1 sh2 sh a b. (Storable a, Storable b) => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh -> (XArray sh1 a -> XArray sh2 b) @@ -337,26 +381,135 @@ rerank2 :: forall sh sh1 sh2 a b c. -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough + = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) (\a b -> unXArray (f (XArray a) (XArray b))) arr1 arr2) where unXArray (XArray a) = a +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +data HList f list where + HNil :: HList f '[] + HCons :: f a -> HList f l -> HList f (a : l) +infixr 5 `HCons` + +foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m +foldHList _ HNil = mempty +foldHList f (x `HCons` l) = f x <> foldHList f l + +class KnownNatList l where makeNatList :: HList SNat l +instance KnownNatList '[] where makeNatList = HNil +instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ HNil = Refl +lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKSX HNil = Refl +lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$@ _) HNil = Refl +lemRankDropLen (_ :!$? _) HNil = Refl +lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerce Refl + +ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen HNil _ = ZKSX +ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh +ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh +ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen HNil sh = sh +ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute HNil _ = ZKSX +ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest +ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest +ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" + -- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) + => HList SNat is + -> XArray sh a + -> XArray (PermutePrefix is sh) a transpose perm (XArray arr) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh)) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen (knownShapeX @sh) perm + = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] + in XArray (S.transpose perm' arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" transpose2 :: forall sh1 sh2 a. StaticShX sh1 -> StaticShX sh2 @@ -364,10 +517,8 @@ transpose2 :: forall sh1 sh2 a. transpose2 ssh1 ssh2 (XArray arr) | Refl <- lemRankApp ssh1 ssh2 , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1) - , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1))) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) @@ -390,13 +541,12 @@ sumOuter ssh ssh' fromList1 :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromList1 ssh l - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) + | Dict <- lemKnownNatRankSSX ssh = case ssh of m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ "does not match the type (" ++ show (natVal m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a] toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) @@ -404,13 +554,29 @@ toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) -- | Throws if the given shape is not, in fact, empty. empty :: forall sh a. Storable a => IShX sh -> XArray sh a empty sh - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) (error "Data.Array.Mixed.empty: shape was not empty")) -slice :: [(Int, Int)] -> XArray sh a -> XArray sh a -slice ivs (XArray arr) = XArray (S.slice ivs arr) +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) rev1 :: XArray (n : sh) a -> XArray (n : sh) a rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 + = XArray (S.reshape (shapeLshape sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') + = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index ec5f0b5..4b455da 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -9,9 +9,11 @@ module Data.Array.Nested ( rshape, rindex, rindexPartial, rgenerate, rsumOuter1, rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, rconstant, rfromList, rfromList1, rtoList, rtoList1, - rslice, rrev1, + rslice, rrev1, rreshape, -- ** Lifting orthotope operations to 'Ranked' arrays rlift, + -- ** Conversions + rasXArrayPrim, rfromXArrayPrim, -- * Shaped arrays Shaped, @@ -21,33 +23,36 @@ module Data.Array.Nested ( sshape, sindex, sindexPartial, sgenerate, ssumOuter1, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, sconstant, sfromList, sfromList1, stoList, stoList1, - sslice, srev1, + sslice, srev1, sreshape, -- ** Lifting orthotope operations to 'Shaped' arrays slift, + -- ** Conversions + sasXArrayPrim, sfromXArrayPrim, -- * Mixed arrays Mixed, IxX(..), IIxX, KnownShapeX(..), StaticShX(..), mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, - mconstant, mfromList, mtoList, mslice, mrev1, + mconstant, mfromList, mtoList, mslice, mrev1, mreshape, + -- ** Conversions + masXArrayPrim, mfromXArrayPrim, -- * Array elements Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), PrimElt, Primitive(..), - -- * Inductive natural numbers - module Data.INat, - -- * Further utilities / re-exports type (++), Storable, + HList, + Permutation, + makeNatList, ) where import Prelude hiding (mappend) import Data.Array.Mixed import Data.Array.Nested.Internal -import Data.INat import Foreign.Storable diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 350eb6f..7bd6565 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -20,6 +20,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-| @@ -27,9 +28,42 @@ TODO: * We should be more consistent in whether functions take a 'StaticShX' argument or a 'KnownShapeX' constraint. -* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point - being that we need to do induction over the former, but the latter need to be - able to get large. +* Mikolaj wants these: + + About your wishlist of operations: these are already there + + OR.index + OR.append + OR.transpose + + These can be easily lifted from the definition for XArray (5min work): + + OR.scalar + OR.unScalar + OR.constant + + These should not be hard: + + OR.fromList + ORB.toList . OR.unravel + OR.ravel . ORB.fromList + OR.slice + OR.rev + OR.reshape + + though it's a bit unfortunate that we end up needing toList. Looking in + horde-ad I see that you seem to need them to do certain operations in Haskell + that orthotope doesn't support? + + For this one we'll need to see to what extent you really need it, and what API + you'd need precisely: + + OR.rerank + + and for these we have an API design question: + + OR.toVector + OR.fromVector -} @@ -52,9 +86,8 @@ import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat, Dict(..), HList(..), pattern SZ, pattern SS, Replicate) import qualified Data.Array.Mixed as X -import Data.INat -- Invariant in the API @@ -90,35 +123,60 @@ import Data.INat -- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. -type family Replicate n a where - Replicate Z a = '[] - Replicate (S n) a = a : Replicate n a - type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs -lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n)) + +-- Stupid things that the type checker should be able to figure out in-line, but can't + +subst1 :: forall f a b. a :~: b -> f a :~: f b +subst1 Refl = Refl + +subst2 :: forall f c a b. a :~: b -> f a c :~: f b c +subst2 Refl = Refl + +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +knownNatSucc :: KnownNat n => Dict KnownNat (n + 1) +knownNatSucc = Dict + + +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n)) where - go :: SINat m -> StaticShX (Replicate m Nothing) + go :: SNat m -> StaticShX (Replicate m Nothing) go SZ = ZKSX - go (SS n) = () :!$? go n + go (SS (n :: SNat nm1)) | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 = () :!$? go n -lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (inatSing @n) +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = go (natSing @n) where - go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m + go :: forall m. SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m go SZ = Refl - go (SS n) | Refl <- go n = Refl + go (SS (n :: SNat nm1)) + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @nm1 + , Refl <- go n + = Refl -lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (inatSing @n) +lemRankMapJust :: forall sh. KnownShape sh => Proxy sh -> X.Rank (MapJust sh) :~: X.Rank sh +lemRankMapJust _ = go (knownShape @sh) where - go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a + go :: forall sh'. ShS sh' -> X.Rank (MapJust sh') :~: X.Rank sh' + go ZSS = Refl + go (_ :$$ sh') | Refl <- go sh' = Refl + +lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp _ _ _ = go (natSing @n) + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a go SZ = Refl - go (SS n) | Refl <- go n = Refl + go (SS (n :: SNat n'm1)) + | Refl <- X.lemReplicateSucc @a @n'm1 + , Refl <- go n + = sym (X.lemReplicateSucc @a @(n'm1 + m)) shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') shAppSplit _ ZKSX idx = (ZSX, idx) @@ -494,10 +552,12 @@ mgenerate sh f = case X.enumShape sh of mvecsWrite sh idx val vecs mvecsFreeze sh vecs -mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a -mtranspose perm = - mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') - (X.transpose perm)) +mtranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShapeX sh, Elt a) => HList SNat is -> Mixed sh a -> Mixed (X.PermutePrefix is sh) a +mtranspose perm + | Dict <- X.lemKnownShapeX (X.ssxAppend (X.ssxPermute perm (X.ssxTakeLen perm (knownShapeX @sh))) (X.ssxDropLen perm (knownShapeX @sh))) + = mlift $ \(Proxy @sh') -> + X.rerankTop (knownShapeX @sh) (knownShapeX @(X.PermutePrefix is sh)) (knownShapeX @sh') + (X.transpose perm) mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a) => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a @@ -534,12 +594,32 @@ mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> a -> Mixed sh a mconstant sh x = fromPrimitive (mconstantP sh x) -mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a -mslice ivs = mlift $ \_ -> X.slice ivs +mslice :: (KnownShapeX sh, Elt a) => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice i n = withKnownNat n $ mlift $ \_ -> X.slice i n + +msliceU :: (KnownShapeX sh, Elt a) => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU i n = mlift $ \_ -> X.sliceU i n mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a mrev1 = mlift $ \_ -> X.rev1 +mreshape :: forall sh sh' a. (KnownShapeX sh, KnownShapeX sh', Elt a) + => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' = mlift $ \(_ :: Proxy shIn) -> + X.reshapePartial (knownShapeX @sh) (knownShapeX @shIn) sh' + +masXArrayPrimP :: Mixed sh (Primitive a) -> XArray sh a +masXArrayPrimP (M_Primitive arr) = arr + +masXArrayPrim :: PrimElt a => Mixed sh a -> XArray sh a +masXArrayPrim = masXArrayPrimP . toPrimitive + +mfromXArrayPrimP :: XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP = M_Primitive + +mfromXArrayPrim :: PrimElt a => XArray sh a -> Mixed sh a +mfromXArrayPrim = fromPrimitive . mfromXArrayPrimP + mliftPrim :: (KnownShapeX sh, Storable a) => (a -> a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) @@ -570,18 +650,15 @@ deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed s -- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'INat'. +-- represented on the type level as a 'Nat'. -- -- Valid elements of a ranked arrays are described by the 'Elt' type class. -- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are -- supported (and are represented as a single, flattened, struct-of-arrays -- array internally). -- --- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a --- type-level natural that supports induction. --- -- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: INat -> Type -> Type +type Ranked :: Nat -> Type -> Type newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) @@ -611,7 +688,7 @@ newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixe -- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; -- these instances allow them to also be used as elements of arrays, thus -- making them first-class in the API. -instance (Elt a, KnownINat n) => Elt (Ranked n a) where +instance (Elt a, KnownNat n) => Elt (Ranked n a) where mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) @@ -732,13 +809,10 @@ lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) go ZSS = ZKSX go (n :$$ sh) = n :!$@ go sh -lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 +lemCommMapJustApp :: forall sh1 sh2. ShS sh1 -> Proxy sh2 -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustPlusApp _ _ = go (knownShape @sh1) - where - go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ZSS = Refl - go (_ :$$ sh) | Refl <- go sh = Refl +lemCommMapJustApp ZSS _ = Refl +lemCommMapJustApp (_ :$$ sh) p | Refl <- lemCommMapJustApp sh p = Refl instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr @@ -843,37 +917,37 @@ rewriteMixed Refl x = x -- ====== API OF RANKED ARRAYS ====== -- -arithPromoteRanked :: forall n a. KnownINat n +arithPromoteRanked :: forall n a. KnownNat n => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) -> Ranked n a -> Ranked n a arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce -arithPromoteRanked2 :: forall n a. KnownINat n +arithPromoteRanked2 :: forall n a. KnownNat n => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) -> Ranked n a -> Ranked n a -> Ranked n a arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce -instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where +instance (KnownNat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where (+) = arithPromoteRanked2 (+) (-) = arithPromoteRanked2 (-) (*) = arithPromoteRanked2 (*) negate = arithPromoteRanked negate abs = arithPromoteRanked abs signum = arithPromoteRanked signum - fromInteger n = case inatSing @n of + fromInteger n = case natSing @n of SZ -> Ranked (M_Primitive (X.scalar (fromInteger n))) - SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ - \Rank non-zero, use explicit mconstant" + _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ + \Rank non-zero, use explicit mconstant" -- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) -deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) +deriving via Ranked n (Primitive Int) instance KnownNat n => Num (Ranked n Int) +deriving via Ranked n (Primitive Double) instance KnownNat n => Num (Ranked n Double) type role ListR nominal representational -type ListR :: INat -> Type -> Type +type ListR :: Nat -> Type -> Type data ListR n i where - ZR :: ListR Z i - (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i deriving instance Show i => Show (ListR n i) deriving instance Eq i => Eq (ListR n i) deriving instance Ord i => Ord (ListR n i) @@ -887,23 +961,23 @@ listRToList :: ListR n i -> [i] listRToList ZR = [] listRToList (i ::: is) = i : listRToList is -knownListR :: ListR n i -> Dict KnownINat n +knownListR :: ListR n i -> Dict KnownNat n knownListR ZR = Dict -knownListR (_ ::: l) | Dict <- knownListR l = Dict +knownListR (_ ::: (l :: ListR m i)) | Dict <- knownListR l = knownNatSucc @m -- | An index into a rank-typed array. type role IxR nominal representational -type IxR :: INat -> Type -> Type +type IxR :: Nat -> Type -> Type newtype IxR n i = IxR (ListR n i) deriving (Show, Eq, Ord) deriving newtype (Functor, Foldable) -pattern ZIR :: forall n i. () => n ~ Z => IxR n i +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i pattern ZIR = IxR ZR pattern (:.:) :: forall {n1} {i}. - forall n. (S n ~ n1) + forall n. (n + 1 ~ n1) => i -> IxR n i -> IxR n1 i pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) where i :.: IxR sh = IxR (i ::: sh) @@ -911,30 +985,30 @@ pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) infixr 3 :.: data UnconsIxRRes i n1 = - forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i + forall n. (n + 1 ~ n1) => UnconsIxRRes (IxR n i) i unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1) unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i) unconsIxR (IxR ZR) = Nothing type IIxR n = IxR n Int -knownIxR :: IxR n i -> Dict KnownINat n +knownIxR :: IxR n i -> Dict KnownNat n knownIxR (IxR sh) = knownListR sh type role ShR nominal representational -type ShR :: INat -> Type -> Type +type ShR :: Nat -> Type -> Type newtype ShR n i = ShR (ListR n i) deriving (Show, Eq, Ord) deriving newtype (Functor, Foldable) type IShR n = ShR n Int -pattern ZSR :: forall n i. () => n ~ Z => ShR n i +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i pattern ZSR = ShR ZR pattern (:$:) :: forall {n1} {i}. - forall n. (S n ~ n1) + forall n. (n + 1 ~ n1) => i -> ShR n i -> ShR n1 i pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i)) where i :$: (ShR sh) = ShR (i ::: sh) @@ -942,15 +1016,15 @@ pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i)) infixr 3 :$: data UnconsShRRes i n1 = - forall n. S n ~ n1 => UnconsShRRes (ShR n i) i + forall n. n + 1 ~ n1 => UnconsShRRes (ShR n i) i unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1) unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i) unconsShR (ShR ZR) = Nothing -knownShR :: ShR n i -> Dict KnownINat n +knownShR :: ShR n i -> Dict KnownNat n knownShR (ShR sh) = knownListR sh -zeroIxR :: SINat n -> IIxR n +zeroIxR :: SNat n -> IIxR n zeroIxR SZ = ZIR zeroIxR (SS n) = 0 :.: zeroIxR n @@ -966,18 +1040,18 @@ shCvtXR (n :$? idx) = n :$: shCvtXR idx ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX -ixCvtRX (n :.: idx) = n :.? ixCvtRX idx +ixCvtRX (n :.: (idx :: IxR m Int)) = castWith (subst2 @IxX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :.? ixCvtRX idx) shCvtRX :: IShR n -> IShX (Replicate n Nothing) shCvtRX ZSR = ZSX -shCvtRX (n :$: idx) = n :$? shCvtRX idx +shCvtRX (n :$: (idx :: ShR m Int)) = castWith (subst2 @ShX @Int (X.lemReplicateSucc @(Nothing @Nat) @m)) (n :$? shCvtRX idx) shapeSizeR :: IShR n -> Int shapeSizeR ZSR = 1 shapeSizeR (n :$: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n +rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IShR n rshape (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) @@ -986,7 +1060,7 @@ rshape (Ranked arr) rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) -rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a +rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IIxR n -> Ranked m a rindexPartial (Ranked arr) idx = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) @@ -1002,47 +1076,54 @@ rgenerate sh f = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) -- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. (KnownINat n2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) +rlift :: forall n1 n2 a. (KnownNat n2, Elt a) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) -> Ranked n1 a -> Ranked n2 a rlift f (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n2) = Ranked (mlift f arr) rsumOuter1P :: forall n a. - (Storable a, Num a, KnownINat n) - => Ranked (S n) (Primitive a) -> Ranked n (Primitive a) + (Storable a, Num a, KnownNat n) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) rsumOuter1P (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n = Ranked . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) - . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a) + . coerce @(Mixed (Replicate (n + 1) Nothing) (Primitive a)) @(XArray (Replicate (n + 1) Nothing) a) $ arr -rsumOuter1 :: forall n a. - (Storable a, Num a, PrimElt a, KnownINat n) - => Ranked (S n) a -> Ranked n a +rsumOuter1 :: forall n a. (Storable a, Num a, PrimElt a, KnownNat n) + => Ranked (n + 1) a -> Ranked n a rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive -rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a -rtranspose perm (Ranked arr) +rtranspose :: forall n a. (KnownNat n, Elt a) => [Int] -> Ranked n a -> Ranked n a +rtranspose perm | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mtranspose perm arr) - -rappend :: forall n a. (KnownINat n, Elt a) - => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a -rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend + , length perm <= fromIntegral (natVal (Proxy @n)) + = rlift $ \(Proxy @sh') -> + X.transposeUntyped (natSing @n) (knownShapeX @sh') perm + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" + +rappend :: forall n a. (KnownNat n, Elt a) + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend + | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) -rscalar :: Elt a => a -> Ranked I0 a +rscalar :: Elt a => a -> Ranked 0 a rscalar x = Ranked (mscalar x) -rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP :: forall n a. (KnownNat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVectorP sh v | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mfromVectorP (shCvtRX sh) v) -rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a +rfromVector :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v) rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a @@ -1051,37 +1132,63 @@ rtoVectorP = coerce mtoVectorP rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a rtoVector = coerce mtoVector -rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a +rfromList1 :: forall n a. (KnownNat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (n + 1) a rfromList1 l | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromList1 (coerce l)) + , Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromList1 @a @Nothing @(Replicate n Nothing) (coerce l)) -rfromList :: Elt a => NonEmpty a -> Ranked I1 a +rfromList :: Elt a => NonEmpty a -> Ranked 1 a rfromList = Ranked . mfromList1 . fmap mscalar -rtoList :: Elt a => Ranked (S n) a -> [Ranked n a] -rtoList (Ranked arr) = coerce (mtoList1 arr) +rtoList :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoList (Ranked arr) + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoList1 @a @Nothing @(Replicate n Nothing) arr) -rtoList1 :: Elt a => Ranked I1 a -> [a] +rtoList1 :: Elt a => Ranked 1 a -> [a] rtoList1 = map runScalar . rtoList -runScalar :: Elt a => Ranked I0 a -> a +runScalar :: Elt a => Ranked 0 a -> a runScalar arr = rindex arr ZIR -rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. (KnownNat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) rconstantP sh x | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mconstantP (shCvtRX sh) x) -rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a) +rconstant :: forall n a. (KnownNat n, Storable a, PrimElt a) => IShR n -> a -> Ranked n a rconstant sh x = coerce fromPrimitive (rconstantP sh x) -rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a -rslice ivs = rlift $ \_ -> X.slice ivs +rslice :: forall n a. (KnownNat n, Elt a) => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n + | Refl <- X.lemReplicateSucc @(Nothing @Nat) @n + = rlift $ \_ -> X.sliceU i n + +rrev1 :: forall n a. (KnownNat n, Elt a) => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 = rlift $ \(Proxy @sh') -> + case X.lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh') -rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a -rrev1 = rlift $ \_ -> X.rev1 +rreshape :: forall n n' a. (KnownNat n, KnownNat n', Elt a) + => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + , Dict <- lemKnownReplicate (Proxy @n') + = Ranked (mreshape (shCvtRX sh') arr) + +rasXArrayPrimP :: Ranked n (Primitive a) -> XArray (Replicate n Nothing) a +rasXArrayPrimP (Ranked arr) = masXArrayPrimP arr + +rasXArrayPrim :: PrimElt a => Ranked n a -> XArray (Replicate n Nothing) a +rasXArrayPrim (Ranked arr) = masXArrayPrim arr + +rfromXArrayPrimP :: XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP = Ranked . mfromXArrayPrimP + +rfromXArrayPrim :: PrimElt a => XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim = Ranked . mfromXArrayPrim -- ====== API OF SHAPED ARRAYS ====== -- @@ -1200,7 +1307,7 @@ sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a sindexPartial (Shaped arr) idx = Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) + (rewriteMixed (lemCommMapJustApp (knownShape @sh1) (Proxy @sh2)) arr) (ixCvtSX idx)) -- | __WARNING__: All values returned from the function must have equal shape. @@ -1212,7 +1319,7 @@ sgenerate f -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) -> Shaped sh1 a -> Shaped sh2 a slift f (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh2) @@ -1234,9 +1341,56 @@ ssumOuter1 :: forall sh n a. => Shaped (n : sh) a -> Shaped sh a ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive -stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a +lemCommMapJustTakeLen :: HList SNat is -> ShS sh -> X.TakeLen is (MapJust sh) :~: MapJust (X.TakeLen is sh) +lemCommMapJustTakeLen HNil _ = Refl +lemCommMapJustTakeLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustTakeLen is sh = Refl +lemCommMapJustTakeLen (_ `HCons` _) ZSS = error "TakeLen of empty" + +lemCommMapJustDropLen :: HList SNat is -> ShS sh -> X.DropLen is (MapJust sh) :~: MapJust (X.DropLen is sh) +lemCommMapJustDropLen HNil _ = Refl +lemCommMapJustDropLen (_ `HCons` is) (_ :$$ sh) | Refl <- lemCommMapJustDropLen is sh = Refl +lemCommMapJustDropLen (_ `HCons` _) ZSS = error "DropLen of empty" + +lemCommMapJustIndex :: SNat i -> ShS sh -> X.Index i (MapJust sh) :~: Just (X.Index i sh) +lemCommMapJustIndex SZ (_ :$$ _) = Refl +lemCommMapJustIndex (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemCommMapJustIndex i sh + , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemCommMapJustIndex _ ZSS = error "Index of empty" + +lemCommMapJustPermute :: HList SNat is -> ShS sh -> X.Permute is (MapJust sh) :~: MapJust (X.Permute is sh) +lemCommMapJustPermute HNil _ = Refl +lemCommMapJustPermute (i `HCons` is) sh + | Refl <- lemCommMapJustPermute is sh + , Refl <- lemCommMapJustIndex i sh + = Refl + +shTakeLen :: HList SNat is -> ShS sh -> ShS (X.TakeLen is sh) +shTakeLen HNil _ = ZSS +shTakeLen (_ `HCons` is) (n :$$ sh) = n :$$ shTakeLen is sh +shTakeLen (_ `HCons` _) ZSS = error "Permutation longer than shape" + +shPermute :: HList SNat is -> ShS sh -> ShS (X.Permute is sh) +shPermute HNil _ = ZSS +shPermute (i `HCons` (is :: HList SNat is')) (sh :: ShS sh) = shIndex (Proxy @is') (Proxy @sh) i sh (shPermute is sh) + +shIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> ShS (X.Permute is shT) -> ShS (X.Index i sh : X.Permute is shT) +shIndex _ _ SZ (n :$$ _) rest = n :$$ rest +shIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) rest + | Refl <- X.lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = shIndex p pT i sh rest +shIndex _ _ _ ZSS _ = error "Index into empty shape" + +stranspose :: forall is sh a. (X.Permutation is, X.Rank is <= X.Rank sh, KnownShape sh, Elt a) => HList SNat is -> Shaped sh a -> Shaped (X.PermutePrefix is sh) a stranspose perm (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) + , Refl <- lemRankMapJust (Proxy @sh) + , Refl <- lemCommMapJustTakeLen perm (knownShape @sh) + , Refl <- lemCommMapJustDropLen perm (knownShape @sh) + , Refl <- lemCommMapJustPermute perm (shTakeLen perm (knownShape @sh)) + , Refl <- lemCommMapJustApp (shPermute perm (shTakeLen perm (knownShape @sh))) (Proxy @(X.DropLen is sh)) = Shaped (mtranspose perm arr) sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) @@ -1287,8 +1441,27 @@ sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => a -> Shaped sh a sconstant x = coerce fromPrimitive (sconstantP @sh x) -sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a -sslice ivs = slift $ \_ -> X.slice ivs +sslice :: (KnownShape sh, Elt a) => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n = withKnownNat n $ slift $ \_ -> X.slice i n srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a srev1 = slift $ \_ -> X.rev1 + +sreshape :: forall sh sh' a. (KnownShape sh, KnownShape sh', Elt a) + => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + , Dict <- lemKnownMapJust (Proxy @sh') + = Shaped (mreshape (shCvtSX sh') arr) + +sasXArrayPrimP :: Shaped sh (Primitive a) -> XArray (MapJust sh) a +sasXArrayPrimP (Shaped arr) = masXArrayPrimP arr + +sasXArrayPrim :: PrimElt a => Shaped sh a -> XArray (MapJust sh) a +sasXArrayPrim (Shaped arr) = masXArrayPrim arr + +sfromXArrayPrimP :: XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP = Shaped . mfromXArrayPrimP + +sfromXArrayPrim :: PrimElt a => XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim = Shaped . mfromXArrayPrim |