aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-04-21 15:45:47 +0200
committerTom Smeding <tom@tomsmeding.com>2024-04-21 15:45:47 +0200
commit3a82a91be0f1b18f071cdb35526b2b2d0b8e093f (patch)
tree6a65376a17ef4051446952541bab57149c66fca7 /src/Data/Array/Mixed.hs
parent181c02b60445310105be126fd6cc2ee9f5dc2c8a (diff)
Make index types useful for horde-ad by parameterising Int
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs68
1 files changed, 35 insertions, 33 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
index 164f832..8d20583 100644
--- a/src/Data/Array/Mixed.hs
+++ b/src/Data/Array/Mixed.hs
@@ -45,15 +45,17 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl
-type IxX :: [Maybe Nat] -> Type
-data IxX sh where
- IZX :: IxX '[]
- (::@) :: Int -> IxX sh -> IxX (Just n : sh)
- (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
-deriving instance Show (IxX sh)
-deriving instance Eq (IxX sh)
-infixr 5 ::@
-infixr 5 ::?
+type IxX :: Type -> [Maybe Nat] -> Type
+data IxX i sh where
+ IZX :: IxX i '[]
+ (::@) :: forall n sh i. i -> IxX i sh -> IxX i (Just n : sh)
+ (::?) :: forall sh i. i -> IxX i sh -> IxX i (Nothing : sh)
+deriving instance Show i => Show (IxX i sh)
+deriving instance Eq i => Eq (IxX i sh)
+infixr 3 ::@
+infixr 3 ::?
+
+type IIxX = IxX Int
-- | The part of a shape that is statically known.
type StaticShapeX :: [Maybe Nat] -> Type
@@ -62,8 +64,8 @@ data StaticShapeX sh where
(:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
(:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)
-infixr 5 :$@
-infixr 5 :$?
+infixr 3 :$@
+infixr 3 :$?
-- | Evidence for the static part of a shape.
type KnownShapeX :: [Maybe Nat] -> Constraint
@@ -84,22 +86,22 @@ type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
deriving (Show)
-zeroIxX :: StaticShapeX sh -> IxX sh
+zeroIxX :: StaticShapeX sh -> IIxX sh
zeroIxX SZX = IZX
zeroIxX (_ :$@ ssh) = 0 ::@ zeroIxX ssh
zeroIxX (_ :$? ssh) = 0 ::? zeroIxX ssh
-zeroIxX' :: IxX sh -> IxX sh
+zeroIxX' :: IIxX sh -> IIxX sh
zeroIxX' IZX = IZX
zeroIxX' (_ ::@ sh) = 0 ::@ zeroIxX' sh
zeroIxX' (_ ::? sh) = 0 ::? zeroIxX' sh
-ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh')
+ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend IZX idx' = idx'
ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx'
ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx'
-ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh'
+ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
ixDrop sh IZX = sh
ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx
ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx
@@ -109,25 +111,25 @@ ssxAppend SZX sh' = sh'
ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh'
ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh'
-shapeSize :: IxX sh -> Int
+shapeSize :: IIxX sh -> Int
shapeSize IZX = 1
shapeSize (n ::@ sh) = n * shapeSize sh
shapeSize (n ::? sh) = n * shapeSize sh
-- | This may fail if @sh@ has @Nothing@s in it.
-ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh)
+ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh)
ssxToShape' SZX = Just IZX
ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh
ssxToShape' (_ :$? _) = Nothing
-fromLinearIdx :: IxX sh -> Int -> IxX sh
+fromLinearIdx :: IIxX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
(idx, 0) -> idx
_ -> error $ "fromLinearIdx: out of range (" ++ show i ++
" in array of shape " ++ show sh ++ ")"
where
-- returns (index in subarray, remaining index in enclosing array)
- go :: IxX sh -> Int -> (IxX sh, Int)
+ go :: IIxX sh -> Int -> (IIxX sh, Int)
go IZX i = (IZX, i)
go (n ::@ sh) i =
let (idx, i') = go sh i
@@ -138,11 +140,11 @@ fromLinearIdx = \sh i -> case go sh i of
(upi, locali) = i' `quotRem` n
in (locali ::? idx, upi)
-toLinearIdx :: IxX sh -> IxX sh -> Int
+toLinearIdx :: IIxX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
where
-- returns (index in subarray, size of subarray)
- go :: IxX sh -> IxX sh -> (Int, Int)
+ go :: IIxX sh -> IIxX sh -> (Int, Int)
go IZX IZX = (0, 1)
go (n ::@ sh) (i ::@ ix) =
let (lidx, sz) = go sh ix
@@ -151,15 +153,15 @@ toLinearIdx = \sh i -> fst (go sh i)
let (lidx, sz) = go sh ix
in (sz * i + lidx, n * sz)
-enumShape :: IxX sh -> [IxX sh]
+enumShape :: IIxX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
where
- go :: IxX sh -> (IxX sh -> a) -> [a] -> [a]
+ go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a]
go IZX f = (f IZX :)
go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]]
go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]]
-shapeLshape :: IxX sh -> S.ShapeL
+shapeLshape :: IIxX sh -> S.ShapeL
shapeLshape IZX = []
shapeLshape (n ::@ sh) = n : shapeLshape sh
shapeLshape (n ::? sh) = n : shapeLshape sh
@@ -182,7 +184,7 @@ lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
-> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
-lemKnownINatRank :: IxX sh -> Dict KnownINat (Rank sh)
+lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh)
lemKnownINatRank IZX = Dict
lemKnownINatRank (_ ::@ sh) | Dict <- lemKnownINatRank sh = Dict
lemKnownINatRank (_ ::? sh) | Dict <- lemKnownINatRank sh = Dict
@@ -206,16 +208,16 @@ lemAppKnownShapeX (() :$? ssh) ssh'
| Dict <- lemAppKnownShapeX ssh ssh'
= Dict
-shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
+shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh
shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
where
- go :: StaticShapeX sh' -> [Int] -> IxX sh'
+ go :: StaticShapeX sh' -> [Int] -> IIxX sh'
go SZX [] = IZX
go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l
go (() :$? ssh) (n : l) = n ::? go ssh l
go _ _ = error "Invalid shapeL"
-fromVector :: forall sh a. Storable a => IxX sh -> VS.Vector a -> XArray sh a
+fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a
fromVector sh v
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
@@ -230,26 +232,26 @@ scalar = XArray . S.scalar
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a
-constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a
+constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a
constant sh x
| Dict <- lemKnownINatRank sh
, Dict <- knownNatFromINat (Proxy @(Rank sh))
= XArray (S.constant (shapeLshape sh) x)
-generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)
--- generateM :: (Monad m, Storable a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownINatRank sh =
-- XArray . S.fromVector (shapeLshape sh)
-- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)
-indexPartial :: Storable a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) IZX = XArray arr
indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx
indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx
-index :: forall sh a. Storable a => XArray sh a -> IxX sh -> a
+index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
| Refl <- lemAppNil @sh
= let XArray arr' = indexPartial xarr i :: XArray '[] a