aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs246
1 files changed, 143 insertions, 103 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 47027cb..94b7cdf 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -1,6 +1,10 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
@@ -32,6 +36,9 @@ pattern GHC_SNat :: () => KnownNat n => SNat n
pattern GHC_SNat = SNat
{-# COMPLETE GHC_SNat #-}
+fromSNat' :: SNat n -> Int
+fromSNat' = fromIntegral . fromSNat
+
-- | Type-level list append.
type family l1 ++ l2 where
@@ -44,9 +51,6 @@ lemAppNil = unsafeCoerce Refl
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
--- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype
--- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)?
-
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
ZIX :: IxX '[] i
@@ -55,31 +59,48 @@ data IxX sh i where
deriving instance Show i => Show (IxX sh i)
deriving instance Eq i => Eq (IxX sh i)
deriving instance Ord i => Ord (IxX sh i)
+deriving instance Functor (IxX sh)
+deriving instance Foldable (IxX sh)
infixr 3 :.@
infixr 3 :.?
type IIxX sh = IxX sh Int
--- | The part of a shape that is statically known.
-type StaticShapeX :: [Maybe Nat] -> Type
-data StaticShapeX sh where
- ZSX :: StaticShapeX '[]
- (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
- (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
-deriving instance Show (StaticShapeX sh)
+type ShX :: [Maybe Nat] -> Type -> Type
+data ShX sh i where
+ ZSX :: ShX '[] i
+ (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i
+ (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i
+deriving instance Show i => Show (ShX sh i)
+deriving instance Eq i => Eq (ShX sh i)
+deriving instance Ord i => Ord (ShX sh i)
+deriving instance Functor (ShX sh)
+deriving instance Foldable (ShX sh)
infixr 3 :$@
infixr 3 :$?
+type IShX sh = ShX sh Int
+
+-- | The part of a shape that is statically known.
+type StaticShX :: [Maybe Nat] -> Type
+data StaticShX sh where
+ ZKSX :: StaticShX '[]
+ (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh)
+ (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh)
+deriving instance Show (StaticShX sh)
+infixr 3 :!$@
+infixr 3 :!$?
+
-- | Evidence for the static part of a shape.
type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
- knownShapeX :: StaticShapeX sh
+ knownShapeX :: StaticShX sh
instance KnownShapeX '[] where
- knownShapeX = ZSX
+ knownShapeX = ZKSX
instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
- knownShapeX = natSing :$@ knownShapeX
+ knownShapeX = natSing :!$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
- knownShapeX = () :$? knownShapeX
+ knownShapeX = () :!$? knownShapeX
type family Rank sh where
Rank '[] = Z
@@ -89,138 +110,149 @@ type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
deriving (Show)
-zeroIxX :: StaticShapeX sh -> IIxX sh
-zeroIxX ZSX = ZIX
-zeroIxX (_ :$@ ssh) = 0 :.@ zeroIxX ssh
-zeroIxX (_ :$? ssh) = 0 :.? zeroIxX ssh
+zeroIxX :: StaticShX sh -> IIxX sh
+zeroIxX ZKSX = ZIX
+zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh
+zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh
+
+zeroIxX' :: IShX sh -> IIxX sh
+zeroIxX' ZSX = ZIX
+zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh
+zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh
-zeroIxX' :: IIxX sh -> IIxX sh
-zeroIxX' ZIX = ZIX
-zeroIxX' (_ :.@ sh) = 0 :.@ zeroIxX' sh
-zeroIxX' (_ :.? sh) = 0 :.? zeroIxX' sh
+-- This is a weird operation, so it has a long name
+completeShXzeros :: StaticShX sh -> IShX sh
+completeShXzeros ZKSX = ZSX
+completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh
+completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend ZIX idx' = idx'
ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'
ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx'
+shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh')
+shAppend ZSX sh' = sh'
+shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh'
+shAppend (n :$? sh) sh' = n :$? shAppend sh sh'
+
ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
ixDrop sh ZIX = sh
ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx
ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx
-ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
-ssxAppend ZSX sh' = sh'
-ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh'
-ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh'
+ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
+ssxAppend ZKSX sh' = sh'
+ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh'
+ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh'
-shapeSize :: IIxX sh -> Int
-shapeSize ZIX = 1
-shapeSize (n :.@ sh) = n * shapeSize sh
-shapeSize (n :.? sh) = n * shapeSize sh
+shapeSize :: IShX sh -> Int
+shapeSize ZSX = 1
+shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh
+shapeSize (n :$? sh) = n * shapeSize sh
-- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh)
-ssxToShape' ZSX = Just ZIX
-ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) :.@) <$> ssxToShape' sh
-ssxToShape' (_ :$? _) = Nothing
+ssxToShape' :: StaticShX sh -> Maybe (IShX sh)
+ssxToShape' ZKSX = Just ZSX
+ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh
+ssxToShape' (_ :!$? _) = Nothing
-fromLinearIdx :: IIxX sh -> Int -> IIxX sh
+fromLinearIdx :: IShX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
_ -> error $ "fromLinearIdx: out of range (" ++ show i ++
" in array of shape " ++ show sh ++ ")"
where
-- returns (index in subarray, remaining index in enclosing array)
- go :: IIxX sh -> Int -> (IIxX sh, Int)
- go ZIX i = (ZIX, i)
- go (n :.@ sh) i =
+ go :: IShX sh -> Int -> (IIxX sh, Int)
+ go ZSX i = (ZIX, i)
+ go (n :$@ sh) i =
let (idx, i') = go sh i
- (upi, locali) = i' `quotRem` n
+ (upi, locali) = i' `quotRem` fromSNat' n
in (locali :.@ idx, upi)
- go (n :.? sh) i =
+ go (n :$? sh) i =
let (idx, i') = go sh i
(upi, locali) = i' `quotRem` n
in (locali :.? idx, upi)
-toLinearIdx :: IIxX sh -> IIxX sh -> Int
+toLinearIdx :: IShX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
where
-- returns (index in subarray, size of subarray)
- go :: IIxX sh -> IIxX sh -> (Int, Int)
- go ZIX ZIX = (0, 1)
- go (n :.@ sh) (i :.@ ix) =
+ go :: IShX sh -> IIxX sh -> (Int, Int)
+ go ZSX ZIX = (0, 1)
+ go (n :$@ sh) (i :.@ ix) =
let (lidx, sz) = go sh ix
- in (sz * i + lidx, n * sz)
- go (n :.? sh) (i :.? ix) =
+ in (sz * i + lidx, fromSNat' n * sz)
+ go (n :$? sh) (i :.? ix) =
let (lidx, sz) = go sh ix
in (sz * i + lidx, n * sz)
-enumShape :: IIxX sh -> [IIxX sh]
+enumShape :: IShX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
where
- go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a]
- go ZIX f = (f ZIX :)
- go (n :.@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. n-1]]
- go (n :.? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]
-
-shapeLshape :: IIxX sh -> S.ShapeL
-shapeLshape ZIX = []
-shapeLshape (n :.@ sh) = n : shapeLshape sh
-shapeLshape (n :.? sh) = n : shapeLshape sh
-
-ssxLength :: StaticShapeX sh -> Int
-ssxLength ZSX = 0
-ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
-ssxLength (_ :$? ssh) = 1 + ssxLength ssh
-
-ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
-ssxIotaFrom _ ZSX = []
-ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
-ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh
-
-lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
+ go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a]
+ go ZSX f = (f ZIX :)
+ go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]]
+ go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]
+
+shapeLshape :: IShX sh -> S.ShapeL
+shapeLshape ZSX = []
+shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh
+shapeLshape (n :$? sh) = n : shapeLshape sh
+
+ssxLength :: StaticShX sh -> Int
+ssxLength ZKSX = 0
+ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh
+ssxLength (_ :!$? ssh) = 1 + ssxLength ssh
+
+ssxIotaFrom :: Int -> StaticShX sh -> [Int]
+ssxIotaFrom _ ZKSX = []
+ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh
+
+lemRankApp :: StaticShX sh1 -> StaticShX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
-lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
+lemRankAppComm :: StaticShX sh1 -> StaticShX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank ZIX = Dict
-lemKnownINatRank (_ :.@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ :.? sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh)
+lemKnownINatRank ZSX = Dict
+lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX ZSX = Dict
-lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh)
+lemKnownINatRankSSX ZKSX = Dict
+lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
+lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
-lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
-lemKnownShapeX ZSX = Dict
-lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh
+lemKnownShapeX ZKSX = Dict
+lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict
-lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
-lemAppKnownShapeX ZSX ssh' = lemKnownShapeX ssh'
-lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh'
+lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
+lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh'
+lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
= Dict
-lemAppKnownShapeX (() :$? ssh) ssh'
+lemAppKnownShapeX (() :!$? ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
= Dict
-shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh
+shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh
shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
- go :: StaticShapeX sh' -> [Int] -> IIxX sh'
- go ZSX [] = ZIX
- go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) :.@ go ssh l
- go (() :$? ssh) (n : l) = n :.? go ssh l
+ go :: StaticShX sh' -> [Int] -> IShX sh'
+ go ZKSX [] = ZSX
+ go (n :!$@ ssh) (_ : l) = n :$@ go ssh l
+ go (() :!$? ssh) (n : l) = n :$? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a
+fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
@@ -235,16 +267,16 @@ scalar = XArray . S.scalar
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
-constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a
+constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a
constant sh x
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.constant (shapeLshape sh) x)
-generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a
+generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
--- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a)
+-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownINatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
@@ -274,7 +306,7 @@ append (XArray a) (XArray b)
rerank :: forall sh sh1 sh2 a b.
(Storable a, Storable b)
- => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
@@ -294,14 +326,14 @@ rerank ssh ssh1 ssh2 f (XArray arr)
rerankTop :: forall sh sh1 sh2 a b.
(Storable a, Storable b)
- => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh
+ => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh
-> (XArray sh1 a -> XArray sh2 b)
-> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh
rerank2 :: forall sh sh1 sh2 a b c.
(Storable a, Storable b, Storable c)
- => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ => StaticShX sh -> StaticShX sh1 -> StaticShX sh2
-> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
-> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
@@ -327,7 +359,7 @@ transpose perm (XArray arr)
= XArray (S.transpose perm arr)
transpose2 :: forall sh1 sh2 a.
- StaticShapeX sh1 -> StaticShapeX sh2
+ StaticShX sh1 -> StaticShX sh2
-> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
| Refl <- lemRankApp ssh1 ssh2
@@ -344,24 +376,24 @@ sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr
sumInner :: forall sh sh' a. (Storable a, Num a)
- => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
+ => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
| Refl <- lemAppNil @sh
- = rerank ssh ssh' ZSX (scalar . sumFull)
+ = rerank ssh ssh' ZKSX (scalar . sumFull)
sumOuter :: forall sh sh' a. (Storable a, Num a)
- => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
+ => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh'
| Refl <- lemAppNil @sh
= sumInner ssh' ssh . transpose2 ssh ssh'
fromList :: forall n sh a. Storable a
- => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
+ => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList ssh l
| Dict <- lemKnownINatRankSSX ssh
, Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
= case ssh of
- m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) ->
+ m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) ->
error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++
"does not match the type (" ++ show (natVal m) ++ ")"
_ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))
@@ -369,5 +401,13 @@ fromList ssh l
toList :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList (XArray arr) = coerce (ORB.toList (S.unravel arr))
+-- | Throws if the given shape is not, in fact, empty.
+empty :: forall sh a. Storable a => IShX sh -> XArray sh a
+empty sh
+ | Dict <- lemKnownINatRank sh
+ , Dict <- knownNatFromINat (Proxy @(Rank sh))
+ = XArray (S.constant (shapeLshape sh)
+ (error "Data.Array.Mixed.empty: shape was not empty"))
+
slice :: [(Int, Int)] -> XArray sh a -> XArray sh a
slice ivs (XArray arr) = XArray (S.slice ivs arr)