aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs93
1 files changed, 39 insertions, 54 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index d782e9f..672b832 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -13,6 +13,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where
@@ -27,8 +28,10 @@ import Foreign.Storable (Storable)
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
-import Data.INat
+-- | Evidence for the constraint @c a@.
+data Dict c a where
+ Dict :: c a => Dict c a
-- | The 'SNat' pattern synonym is complete, but it doesn't have a
-- @COMPLETE@ pragma. This copy of it does.
@@ -103,11 +106,11 @@ instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
knownShapeX = () :!$? knownShapeX
type family Rank sh where
- Rank '[] = Z
- Rank (_ : sh) = S (Rank sh)
+ Rank '[] = 0
+ Rank (_ : sh) = 1 + Rank sh
type XArray :: [Maybe Nat] -> Type -> Type
-newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
+newtype XArray sh a = XArray (S.Array (Rank sh) a)
deriving (Show)
zeroIxX :: StaticShX sh -> IIxX sh
@@ -217,22 +220,22 @@ staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh
staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh
lemRankApp :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
+ -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
- -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
+ -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1)
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank ZSX = Dict
-lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank ZSX = Dict
+lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict
-lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX ZKSX = Dict
-lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX ZKSX = Dict
+lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
lemKnownShapeX ZKSX = Dict
@@ -259,8 +262,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.fromVector (shapeLshape sh) v)
toVector :: Storable a => XArray sh a -> VS.Vector a
@@ -274,15 +276,14 @@ unScalar (XArray a) = S.unScalar a
constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh) x)
generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
--- generateM sh f | Dict <- lemKnownINatRank sh =
+-- generateM sh f | Dict <- lemKnownNatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
@@ -305,8 +306,7 @@ type family AddMaybe n m where
append :: forall n m sh a. (KnownShapeX sh, Storable a)
=> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append (XArray a) (XArray b)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
= XArray (S.append a b)
rerank :: forall sh sh1 sh2 a b.
@@ -315,15 +315,12 @@ rerank :: forall sh sh1 sh2 a b.
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough
+ = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a -> unXArray (f (XArray a)))
arr)
where
@@ -342,15 +339,12 @@ rerank2 :: forall sh sh1 sh2 a b c.
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
- , Dict <- lemKnownINatRankSSX ssh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- lemKnownNatRankSSX ssh2
, Refl <- lemRankApp ssh ssh1
, Refl <- lemRankApp ssh ssh2
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
- , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
- = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2)
(\a b -> unXArray (f (XArray a) (XArray b)))
arr1 arr2)
where
@@ -359,8 +353,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
-- | The list argument gives indices into the original dimension list.
transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
- | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
= XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
@@ -369,10 +362,8 @@ transpose2 :: forall sh1 sh2 a.
transpose2 ssh1 ssh2 (XArray arr)
| Refl <- lemRankApp ssh1 ssh2
, Refl <- lemRankApp ssh2 ssh1
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2)))
- , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1)
- , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1)))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
, Refl <- lemRankAppComm ssh1 ssh2
, let n1 = ssxLength ssh1
= XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
@@ -395,13 +386,12 @@ sumOuter ssh ssh'
fromList1 :: forall n sh a. Storable a
=> StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList1 ssh l
- | Dict <- lemKnownINatRankSSX ssh
- , Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
+ | Dict <- lemKnownNatRankSSX ssh
= case ssh of
m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
"does not match the type (" ++ show (natVal m) ++ ")"
- _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
+ _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l)))
toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
@@ -409,8 +399,7 @@ toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr))
-- | Throws if the given shape is not, in fact, empty.
empty :: forall sh a. Storable a => IShX sh -> XArray sh a
empty sh
- | Dict <- lemKnownINatRank sh
- , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ | Dict <- lemKnownNatRank sh
= XArray (S.constant (shapeLshape sh)
(error "Data.Array.Mixed.empty: shape was not empty"))
@@ -423,17 +412,13 @@ rev1 (XArray arr) = XArray (S.rev [0] arr)
-- | Throws if the given array and the target shape do not have the same number of elements.
reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a
reshape ssh1 sh2 (XArray arr)
- | Dict <- lemKnownINatRankSSX ssh1
- , Dict <- knownNatFromINat (Proxy @(Rank sh1))
- , Dict <- lemKnownINatRank sh2
- , Dict <- knownNatFromINat (Proxy @(Rank sh2))
+ | Dict <- lemKnownNatRankSSX ssh1
+ , Dict <- lemKnownNatRank sh2
= XArray (S.reshape (shapeLshape sh2) arr)
-- | Throws if the given array and the target shape do not have the same number of elements.
reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a
reshapePartial ssh1 ssh' sh2 (XArray arr)
- | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh')
- , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh')))
- , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
- , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh')))
+ | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh')
+ , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh')
= XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr)