From 3a82a91be0f1b18f071cdb35526b2b2d0b8e093f Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Sun, 21 Apr 2024 15:45:47 +0200 Subject: Make index types useful for horde-ad by parameterising Int --- src/Data/Array/Mixed.hs | 68 +++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 33 deletions(-) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 164f832..8d20583 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -45,15 +45,17 @@ lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl -type IxX :: [Maybe Nat] -> Type -data IxX sh where - IZX :: IxX '[] - (::@) :: Int -> IxX sh -> IxX (Just n : sh) - (::?) :: Int -> IxX sh -> IxX (Nothing : sh) -deriving instance Show (IxX sh) -deriving instance Eq (IxX sh) -infixr 5 ::@ -infixr 5 ::? +type IxX :: Type -> [Maybe Nat] -> Type +data IxX i sh where + IZX :: IxX i '[] + (::@) :: forall n sh i. i -> IxX i sh -> IxX i (Just n : sh) + (::?) :: forall sh i. i -> IxX i sh -> IxX i (Nothing : sh) +deriving instance Show i => Show (IxX i sh) +deriving instance Eq i => Eq (IxX i sh) +infixr 3 ::@ +infixr 3 ::? + +type IIxX = IxX Int -- | The part of a shape that is statically known. type StaticShapeX :: [Maybe Nat] -> Type @@ -62,8 +64,8 @@ data StaticShapeX sh where (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) deriving instance Show (StaticShapeX sh) -infixr 5 :$@ -infixr 5 :$? +infixr 3 :$@ +infixr 3 :$? -- | Evidence for the static part of a shape. type KnownShapeX :: [Maybe Nat] -> Constraint @@ -84,22 +86,22 @@ type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) deriving (Show) -zeroIxX :: StaticShapeX sh -> IxX sh +zeroIxX :: StaticShapeX sh -> IIxX sh zeroIxX SZX = IZX zeroIxX (_ :$@ ssh) = 0 ::@ zeroIxX ssh zeroIxX (_ :$? ssh) = 0 ::? zeroIxX ssh -zeroIxX' :: IxX sh -> IxX sh +zeroIxX' :: IIxX sh -> IIxX sh zeroIxX' IZX = IZX zeroIxX' (_ ::@ sh) = 0 ::@ zeroIxX' sh zeroIxX' (_ ::? sh) = 0 ::? zeroIxX' sh -ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh') +ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend IZX idx' = idx' ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx' ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx' -ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh' +ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' ixDrop sh IZX = sh ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx @@ -109,25 +111,25 @@ ssxAppend SZX sh' = sh' ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh' ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh' -shapeSize :: IxX sh -> Int +shapeSize :: IIxX sh -> Int shapeSize IZX = 1 shapeSize (n ::@ sh) = n * shapeSize sh shapeSize (n ::? sh) = n * shapeSize sh -- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShapeX sh -> Maybe (IxX sh) +ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh) ssxToShape' SZX = Just IZX ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) ::@) <$> ssxToShape' sh ssxToShape' (_ :$? _) = Nothing -fromLinearIdx :: IxX sh -> Int -> IxX sh +fromLinearIdx :: IIxX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")" where -- returns (index in subarray, remaining index in enclosing array) - go :: IxX sh -> Int -> (IxX sh, Int) + go :: IIxX sh -> Int -> (IIxX sh, Int) go IZX i = (IZX, i) go (n ::@ sh) i = let (idx, i') = go sh i @@ -138,11 +140,11 @@ fromLinearIdx = \sh i -> case go sh i of (upi, locali) = i' `quotRem` n in (locali ::? idx, upi) -toLinearIdx :: IxX sh -> IxX sh -> Int +toLinearIdx :: IIxX sh -> IIxX sh -> Int toLinearIdx = \sh i -> fst (go sh i) where -- returns (index in subarray, size of subarray) - go :: IxX sh -> IxX sh -> (Int, Int) + go :: IIxX sh -> IIxX sh -> (Int, Int) go IZX IZX = (0, 1) go (n ::@ sh) (i ::@ ix) = let (lidx, sz) = go sh ix @@ -151,15 +153,15 @@ toLinearIdx = \sh i -> fst (go sh i) let (lidx, sz) = go sh ix in (sz * i + lidx, n * sz) -enumShape :: IxX sh -> [IxX sh] +enumShape :: IIxX sh -> [IIxX sh] enumShape = \sh -> go sh id [] where - go :: IxX sh -> (IxX sh -> a) -> [a] -> [a] + go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a] go IZX f = (f IZX :) go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n-1]] go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n-1]] -shapeLshape :: IxX sh -> S.ShapeL +shapeLshape :: IIxX sh -> S.ShapeL shapeLshape IZX = [] shapeLshape (n ::@ sh) = n : shapeLshape sh shapeLshape (n ::? sh) = n : shapeLshape sh @@ -182,7 +184,7 @@ lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownINatRank :: IxX sh -> Dict KnownINat (Rank sh) +lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh) lemKnownINatRank IZX = Dict lemKnownINatRank (_ ::@ sh) | Dict <- lemKnownINatRank sh = Dict lemKnownINatRank (_ ::? sh) | Dict <- lemKnownINatRank sh = Dict @@ -206,16 +208,16 @@ lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) where - go :: StaticShapeX sh' -> [Int] -> IxX sh' + go :: StaticShapeX sh' -> [Int] -> IIxX sh' go SZX [] = IZX go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l go (() :$? ssh) (n : l) = n ::? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. Storable a => IxX sh -> VS.Vector a -> XArray sh a +fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) @@ -230,26 +232,26 @@ scalar = XArray . S.scalar unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -constant :: forall sh a. Storable a => IxX sh -> a -> XArray sh a +constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a constant sh x | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.constant (shapeLshape sh) x) -generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, Storable a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownINatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) -indexPartial :: Storable a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a indexPartial (XArray arr) IZX = XArray arr indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (S.index arr i)) idx indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (S.index arr i)) idx -index :: forall sh a. Storable a => XArray sh a -> IxX sh -> a +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a index xarr i | Refl <- lemAppNil @sh = let XArray arr' = indexPartial xarr i :: XArray '[] a -- cgit v1.2.3-70-g09d2