From e4e23a33f77d250af1e9b6614cf249128ba1510a Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 13 May 2024 22:47:42 +0200 Subject: Shape/index hygiene --- src/Data/Array/Mixed.hs | 246 ++++++++++++++++++++++++++++-------------------- 1 file changed, 143 insertions(+), 103 deletions(-) (limited to 'src/Data/Array/Mixed.hs') diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 47027cb..94b7cdf 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -1,6 +1,10 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -32,6 +36,9 @@ pattern GHC_SNat :: () => KnownNat n => SNat n pattern GHC_SNat = SNat {-# COMPLETE GHC_SNat #-} +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + -- | Type-level list append. type family l1 ++ l2 where @@ -44,9 +51,6 @@ lemAppNil = unsafeCoerce Refl lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl --- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype --- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)? - type IxX :: [Maybe Nat] -> Type -> Type data IxX sh i where ZIX :: IxX '[] i @@ -55,31 +59,48 @@ data IxX sh i where deriving instance Show i => Show (IxX sh i) deriving instance Eq i => Eq (IxX sh i) deriving instance Ord i => Ord (IxX sh i) +deriving instance Functor (IxX sh) +deriving instance Foldable (IxX sh) infixr 3 :.@ infixr 3 :.? type IIxX sh = IxX sh Int --- | The part of a shape that is statically known. -type StaticShapeX :: [Maybe Nat] -> Type -data StaticShapeX sh where - ZSX :: StaticShapeX '[] - (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) - (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) -deriving instance Show (StaticShapeX sh) +type ShX :: [Maybe Nat] -> Type -> Type +data ShX sh i where + ZSX :: ShX '[] i + (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i + (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i +deriving instance Show i => Show (ShX sh i) +deriving instance Eq i => Eq (ShX sh i) +deriving instance Ord i => Ord (ShX sh i) +deriving instance Functor (ShX sh) +deriving instance Foldable (ShX sh) infixr 3 :$@ infixr 3 :$? +type IShX sh = ShX sh Int + +-- | The part of a shape that is statically known. +type StaticShX :: [Maybe Nat] -> Type +data StaticShX sh where + ZKSX :: StaticShX '[] + (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh) + (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh) +deriving instance Show (StaticShX sh) +infixr 3 :!$@ +infixr 3 :!$? + -- | Evidence for the static part of a shape. type KnownShapeX :: [Maybe Nat] -> Constraint class KnownShapeX sh where - knownShapeX :: StaticShapeX sh + knownShapeX :: StaticShX sh instance KnownShapeX '[] where - knownShapeX = ZSX + knownShapeX = ZKSX instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing :$@ knownShapeX + knownShapeX = natSing :!$@ knownShapeX instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :$? knownShapeX + knownShapeX = () :!$? knownShapeX type family Rank sh where Rank '[] = Z @@ -89,138 +110,149 @@ type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) deriving (Show) -zeroIxX :: StaticShapeX sh -> IIxX sh -zeroIxX ZSX = ZIX -zeroIxX (_ :$@ ssh) = 0 :.@ zeroIxX ssh -zeroIxX (_ :$? ssh) = 0 :.? zeroIxX ssh +zeroIxX :: StaticShX sh -> IIxX sh +zeroIxX ZKSX = ZIX +zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh +zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh + +zeroIxX' :: IShX sh -> IIxX sh +zeroIxX' ZSX = ZIX +zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh +zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh -zeroIxX' :: IIxX sh -> IIxX sh -zeroIxX' ZIX = ZIX -zeroIxX' (_ :.@ sh) = 0 :.@ zeroIxX' sh -zeroIxX' (_ :.? sh) = 0 :.? zeroIxX' sh +-- This is a weird operation, so it has a long name +completeShXzeros :: StaticShX sh -> IShX sh +completeShXzeros ZKSX = ZSX +completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh +completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend ZIX idx' = idx' ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx' +shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') +shAppend ZSX sh' = sh' +shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh' +shAppend (n :$? sh) sh' = n :$? shAppend sh sh' + ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' ixDrop sh ZIX = sh ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx -ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') -ssxAppend ZSX sh' = sh' -ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh' -ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh' +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKSX sh' = sh' +ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' +ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh' -shapeSize :: IIxX sh -> Int -shapeSize ZIX = 1 -shapeSize (n :.@ sh) = n * shapeSize sh -shapeSize (n :.? sh) = n * shapeSize sh +shapeSize :: IShX sh -> Int +shapeSize ZSX = 1 +shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh +shapeSize (n :$? sh) = n * shapeSize sh -- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh) -ssxToShape' ZSX = Just ZIX -ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) :.@) <$> ssxToShape' sh -ssxToShape' (_ :$? _) = Nothing +ssxToShape' :: StaticShX sh -> Maybe (IShX sh) +ssxToShape' ZKSX = Just ZSX +ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh +ssxToShape' (_ :!$? _) = Nothing -fromLinearIdx :: IIxX sh -> Int -> IIxX sh +fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")" where -- returns (index in subarray, remaining index in enclosing array) - go :: IIxX sh -> Int -> (IIxX sh, Int) - go ZIX i = (ZIX, i) - go (n :.@ sh) i = + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$@ sh) i = let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n + (upi, locali) = i' `quotRem` fromSNat' n in (locali :.@ idx, upi) - go (n :.? sh) i = + go (n :$? sh) i = let (idx, i') = go sh i (upi, locali) = i' `quotRem` n in (locali :.? idx, upi) -toLinearIdx :: IIxX sh -> IIxX sh -> Int +toLinearIdx :: IShX sh -> IIxX sh -> Int toLinearIdx = \sh i -> fst (go sh i) where -- returns (index in subarray, size of subarray) - go :: IIxX sh -> IIxX sh -> (Int, Int) - go ZIX ZIX = (0, 1) - go (n :.@ sh) (i :.@ ix) = + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$@ sh) (i :.@ ix) = let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - go (n :.? sh) (i :.? ix) = + in (sz * i + lidx, fromSNat' n * sz) + go (n :$? sh) (i :.? ix) = let (lidx, sz) = go sh ix in (sz * i + lidx, n * sz) -enumShape :: IIxX sh -> [IIxX sh] +enumShape :: IShX sh -> [IIxX sh] enumShape = \sh -> go sh id [] where - go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZIX f = (f ZIX :) - go (n :.@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. n-1]] - go (n :.? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] - -shapeLshape :: IIxX sh -> S.ShapeL -shapeLshape ZIX = [] -shapeLshape (n :.@ sh) = n : shapeLshape sh -shapeLshape (n :.? sh) = n : shapeLshape sh - -ssxLength :: StaticShapeX sh -> Int -ssxLength ZSX = 0 -ssxLength (_ :$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :$? ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] -ssxIotaFrom _ ZSX = [] -ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh - -lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]] + go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] + +shapeLshape :: IShX sh -> S.ShapeL +shapeLshape ZSX = [] +shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh +shapeLshape (n :$? sh) = n : shapeLshape sh + +ssxLength :: StaticShX sh -> Int +ssxLength ZKSX = 0 +ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh +ssxLength (_ :!$? ssh) = 1 + ssxLength ssh + +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKSX = [] +ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh + +lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this -lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 +lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZIX = Dict -lemKnownINatRank (_ :.@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :.? sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) +lemKnownINatRank ZSX = Dict +lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZSX = Dict -lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) +lemKnownINatRankSSX ZKSX = Dict +lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh -lemKnownShapeX ZSX = Dict -lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh +lemKnownShapeX ZKSX = Dict +lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict -lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX ZSX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh' +lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' +lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (() :$? ssh) ssh' +lemAppKnownShapeX (() :!$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) where - go :: StaticShapeX sh' -> [Int] -> IIxX sh' - go ZSX [] = ZIX - go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) :.@ go ssh l - go (() :$? ssh) (n : l) = n :.? go ssh l + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKSX [] = ZSX + go (n :!$@ ssh) (_ : l) = n :$@ go ssh l + go (() :!$? ssh) (n : l) = n :$? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) @@ -235,16 +267,16 @@ scalar = XArray . S.scalar unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a +constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a constant sh x | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.constant (shapeLshape sh) x) -generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownINatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) @@ -274,7 +306,7 @@ append (XArray a) (XArray b) rerank :: forall sh sh1 sh2 a b. (Storable a, Storable b) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f (XArray arr) @@ -294,14 +326,14 @@ rerank ssh ssh1 ssh2 f (XArray arr) rerankTop :: forall sh sh1 sh2 a b. (Storable a, Storable b) - => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh + => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh rerank2 :: forall sh sh1 sh2 a b c. (Storable a, Storable b, Storable c) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) @@ -327,7 +359,7 @@ transpose perm (XArray arr) = XArray (S.transpose perm arr) transpose2 :: forall sh1 sh2 a. - StaticShapeX sh1 -> StaticShapeX sh2 + StaticShX sh1 -> StaticShX sh2 -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a transpose2 ssh1 ssh2 (XArray arr) | Refl <- lemRankApp ssh1 ssh2 @@ -344,24 +376,24 @@ sumFull :: (Storable a, Num a) => XArray sh a -> a sumFull (XArray arr) = S.sumA arr sumInner :: forall sh sh' a. (Storable a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' | Refl <- lemAppNil @sh - = rerank ssh ssh' ZSX (scalar . sumFull) + = rerank ssh ssh' ZKSX (scalar . sumFull) sumOuter :: forall sh sh' a. (Storable a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' | Refl <- lemAppNil @sh = sumInner ssh' ssh . transpose2 ssh ssh' fromList :: forall n sh a. Storable a - => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromList ssh l | Dict <- lemKnownINatRankSSX ssh , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) = case ssh of - m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) -> + m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ "does not match the type (" ++ show (natVal m) ++ ")" _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) @@ -369,5 +401,13 @@ fromList ssh l toList :: Storable a => XArray (n : sh) a -> [XArray sh a] toList (XArray arr) = coerce (ORB.toList (S.unravel arr)) +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh + | Dict <- lemKnownINatRank sh + , Dict <- knownNatFromINat (Proxy @(Rank sh)) + = XArray (S.constant (shapeLshape sh) + (error "Data.Array.Mixed.empty: shape was not empty")) + slice :: [(Int, Int)] -> XArray sh a -> XArray sh a slice ivs (XArray arr) = XArray (S.slice ivs arr) -- cgit v1.2.3-70-g09d2