diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-14 23:30:53 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-14 23:30:53 +0200 |
commit | 43ddff2e7f1e9f4d8855f573384e26b63d34f697 (patch) | |
tree | 86b6989d4a1b935fd6f6338b4699d8e5c0083a2c /src/Data/Array/Mixed.hs | |
parent | 77ab86ede90938fa43f7f9988ac10f7026440a1c (diff) |
WIP GHC nats
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 93 |
1 files changed, 39 insertions, 54 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index d782e9f..672b832 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -13,6 +13,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed where @@ -27,8 +28,10 @@ import Foreign.Storable (Storable) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.INat +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a -- | The 'SNat' pattern synonym is complete, but it doesn't have a -- @COMPLETE@ pragma. This copy of it does. @@ -103,11 +106,11 @@ instance KnownShapeX sh => KnownShapeX (Nothing : sh) where knownShapeX = () :!$? knownShapeX type family Rank sh where - Rank '[] = Z - Rank (_ : sh) = S (Rank sh) + Rank '[] = 0 + Rank (_ : sh) = 1 + Rank sh type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) +newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show) zeroIxX :: StaticShX sh -> IIxX sh @@ -217,22 +220,22 @@ staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZSX = Dict -lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZKSX = Dict -lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKSX = Dict +lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh lemKnownShapeX ZKSX = Dict @@ -259,8 +262,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.fromVector (shapeLshape sh) v) toVector :: Storable a => XArray sh a -> VS.Vector a @@ -274,15 +276,14 @@ unScalar (XArray a) = S.unScalar a constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a constant sh x - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) x) generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) -- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownINatRank sh = +-- generateM sh f | Dict <- lemKnownNatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) @@ -305,8 +306,7 @@ type family AddMaybe n m where append :: forall n m sh a. (KnownShapeX sh, Storable a) => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append (XArray a) (XArray b) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) = XArray (S.append a b) rerank :: forall sh sh1 sh2 a b. @@ -315,15 +315,12 @@ rerank :: forall sh sh1 sh2 a b. -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough + = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) (\a -> unXArray (f (XArray a))) arr) where @@ -342,15 +339,12 @@ rerank2 :: forall sh sh1 sh2 a b c. -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) (\a b -> unXArray (f (XArray a) (XArray b))) arr1 arr2) where @@ -359,8 +353,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) -- | The list argument gives indices into the original dimension list. transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a transpose perm (XArray arr) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) = XArray (S.transpose perm arr) transpose2 :: forall sh1 sh2 a. @@ -369,10 +362,8 @@ transpose2 :: forall sh1 sh2 a. transpose2 ssh1 ssh2 (XArray arr) | Refl <- lemRankApp ssh1 ssh2 , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1) - , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1))) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) @@ -395,13 +386,12 @@ sumOuter ssh ssh' fromList1 :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromList1 ssh l - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) + | Dict <- lemKnownNatRankSSX ssh = case ssh of m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ "does not match the type (" ++ show (natVal m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a] toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) @@ -409,8 +399,7 @@ toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) -- | Throws if the given shape is not, in fact, empty. empty :: forall sh a. Storable a => IShX sh -> XArray sh a empty sh - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) (error "Data.Array.Mixed.empty: shape was not empty")) @@ -423,17 +412,13 @@ rev1 (XArray arr) = XArray (S.rev [0] arr) -- | Throws if the given array and the target shape do not have the same number of elements. reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a reshape ssh1 sh2 (XArray arr) - | Dict <- lemKnownINatRankSSX ssh1 - , Dict <- knownNatFromINat (Proxy @(Rank sh1)) - , Dict <- lemKnownINatRank sh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 = XArray (S.reshape (shapeLshape sh2) arr) -- | Throws if the given array and the target shape do not have the same number of elements. reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a reshapePartial ssh1 ssh' sh2 (XArray arr) - | Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh') - , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh'))) - , Dict <- lemKnownINatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') - , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh'))) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) |