aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs108
1 files changed, 54 insertions, 54 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 8d20583..c19fbe5 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -47,20 +47,20 @@ lemAppAssoc _ _ _ = unsafeCoerce Refl
type IxX :: Type -> [Maybe Nat] -> Type
data IxX i sh where
- IZX :: IxX i '[]
- (::@) :: forall n sh i. i -> IxX i sh -> IxX i (Just n : sh)
- (::?) :: forall sh i. i -> IxX i sh -> IxX i (Nothing : sh)
+ ZIX :: IxX i '[]
+ (:.@) :: forall n sh i. i -> IxX i sh -> IxX i (Just n : sh)
+ (:.?) :: forall sh i. i -> IxX i sh -> IxX i (Nothing : sh)
deriving instance Show i => Show (IxX i sh)
deriving instance Eq i => Eq (IxX i sh)
-infixr 3 ::@
-infixr 3 ::?
+infixr 3 :.@
+infixr 3 :.?
type IIxX = IxX Int
-- | The part of a shape that is statically known.
type StaticShapeX :: [Maybe Nat] -> Type
data StaticShapeX sh where
- SZX :: StaticShapeX '[]
+ ZSX :: StaticShapeX '[]
(:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
(:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)
@@ -72,7 +72,7 @@ type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
knownShapeX :: StaticShapeX sh
instance KnownShapeX '[] where
- knownShapeX = SZX
+ knownShapeX = ZSX
instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
knownShapeX = natSing :$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
@@ -87,39 +87,39 @@ newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
deriving (Show)
zeroIxX :: StaticShapeX sh -> IIxX sh
-zeroIxX SZX = IZX
-zeroIxX (_ :$@ ssh) = 0 ::@ zeroIxX ssh
-zeroIxX (_ :$? ssh) = 0 ::? zeroIxX ssh
+zeroIxX ZSX = ZIX
+zeroIxX (_ :$@ ssh) = 0 :.@ zeroIxX ssh
+zeroIxX (_ :$? ssh) = 0 :.? zeroIxX ssh
zeroIxX' :: IIxX sh -> IIxX sh
-zeroIxX' IZX = IZX
-zeroIxX' (_ ::@ sh) = 0 ::@ zeroIxX' sh
-zeroIxX' (_ ::? sh) = 0 ::? zeroIxX' sh
+zeroIxX' ZIX = ZIX
+zeroIxX' (_ :.@ sh) = 0 :.@ zeroIxX' sh
+zeroIxX' (_ :.? sh) = 0 :.? zeroIxX' sh
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
-ixAppend IZX idx' = idx'
-ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx'
-ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx'
+ixAppend ZIX idx' = idx'
+ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'
+ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx'
ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
-ixDrop sh IZX = sh
-ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx
-ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx
+ixDrop sh ZIX = sh
+ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx
+ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx
ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
-ssxAppend SZX sh' = sh'
+ssxAppend ZSX sh' = sh'
ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh'
ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh'
shapeSize :: IIxX sh -> Int
-shapeSize IZX = 1
-shapeSize (n ::@ sh) = n * shapeSize sh
-shapeSize (n ::? sh) = n * shapeSize sh
+shapeSize ZIX = 1
+shapeSize (n :.@ sh) = n * shapeSize sh
+shapeSize (n :.? sh) = n * shapeSize sh
-- | This may fail if @sh@ has @Nothing@s in it.
ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh)
-ssxToShape' SZX = Just IZX
-ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh
+ssxToShape' ZSX = Just ZIX
+ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) :.@) <$> ssxToShape' sh
ssxToShape' (_ :$? _) = Nothing
fromLinearIdx :: IIxX sh -> Int -> IIxX sh
@@ -130,26 +130,26 @@ fromLinearIdx = \sh i -> case go sh i of
where
-- returns (index in subarray, remaining index in enclosing array)
go :: IIxX sh -> Int -> (IIxX sh, Int)
- go IZX i = (IZX, i)
- go (n ::@ sh) i =
+ go ZIX i = (ZIX, i)
+ go (n :.@ sh) i =
let (idx, i') = go sh i
(upi, locali) = i' `quotRem` n
- in (locali ::@ idx, upi)
- go (n ::? sh) i =
+ in (locali :.@ idx, upi)
+ go (n :.? sh) i =
let (idx, i') = go sh i
(upi, locali) = i' `quotRem` n
- in (locali ::? idx, upi)
+ in (locali :.? idx, upi)
toLinearIdx :: IIxX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
where
-- returns (index in subarray, size of subarray)
go :: IIxX sh -> IIxX sh -> (Int, Int)
- go IZX IZX = (0, 1)
- go (n ::@ sh) (i ::@ ix) =
+ go ZIX ZIX = (0, 1)
+ go (n :.@ sh) (i :.@ ix) =
let (lidx, sz) = go sh ix
in (sz * i + lidx, n * sz)
- go (n ::? sh) (i ::? ix) =
+ go (n :.? sh) (i :.? ix) =
let (lidx, sz) = go sh ix
in (sz * i + lidx, n * sz)
@@ -157,22 +157,22 @@ enumShape :: IIxX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
where
go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a]
- go IZX f = (f IZX :)
- go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]]
- go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]]
+ go ZIX f = (f ZIX :)
+ go (n :.@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. n-1]]
+ go (n :.? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]
shapeLshape :: IIxX sh -> S.ShapeL
-shapeLshape IZX = []
-shapeLshape (n ::@ sh) = n : shapeLshape sh
-shapeLshape (n ::? sh) = n : shapeLshape sh
+shapeLshape ZIX = []
+shapeLshape (n :.@ sh) = n : shapeLshape sh
+shapeLshape (n :.? sh) = n : shapeLshape sh
ssxLength :: StaticShapeX sh -> Int
-ssxLength SZX = 0
+ssxLength ZSX = 0
ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
ssxLength (_ :$? ssh) = 1 + ssxLength ssh
ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
-ssxIotaFrom _ SZX = []
+ssxIotaFrom _ ZSX = []
ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh
@@ -185,22 +185,22 @@ lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRank IZX = Dict
-lemKnownINatRank (_ ::@ sh) | Dict <- lemKnownINatRank sh = Dict
-lemKnownINatRank (_ ::? sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownINatRank ZIX = Dict
+lemKnownINatRank (_ :.@ sh) | Dict <- lemKnownINatRank sh = Dict
+lemKnownINatRank (_ :.? sh) | Dict <- lemKnownINatRank sh = Dict
lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh)
-lemKnownINatRankSSX SZX = Dict
+lemKnownINatRankSSX ZSX = Dict
lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
-lemKnownShapeX SZX = Dict
+lemKnownShapeX ZSX = Dict
lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
-lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
+lemAppKnownShapeX ZSX ssh' = lemKnownShapeX ssh'
lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
= Dict
@@ -212,9 +212,9 @@ shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh
shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
go :: StaticShapeX sh' -> [Int] -> IIxX sh'
- go SZX [] = IZX
- go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l
- go (() :$? ssh) (n : l) = n ::? go ssh l
+ go ZSX [] = ZIX
+ go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) :.@ go ssh l
+ go (() :$? ssh) (n : l) = n :.? go ssh l
go _ _ = error "Invalid shapeL"
fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a
@@ -247,9 +247,9 @@ generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
-indexPartial (XArray arr) IZX = XArray arr
-indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx
-indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx
+indexPartial (XArray arr) ZIX = XArray arr
+indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx
+indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
@@ -344,7 +344,7 @@ sumInner :: forall sh sh' a. (Storable a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
| Refl <- lemAppNil @sh
- = rerank ssh ssh' SZX (scalar . sumFull)
+ = rerank ssh ssh' ZSX (scalar . sumFull)
sumOuter :: forall sh sh' a. (Storable a, Num a)
=> StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a