diff options
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 266 |
1 files changed, 216 insertions, 50 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 0351beb..ce18431 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} @@ -7,12 +8,16 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module Data.Array.Mixed where @@ -21,14 +26,18 @@ import qualified Data.Array.Ranked as ORB import Data.Coerce import Data.Kind import Data.Proxy +import Data.Type.Bool import Data.Type.Equality import qualified Data.Vector.Storable as VS import Foreign.Storable (Storable) +import GHC.TypeError import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.INat +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a -- | The 'SNat' pattern synonym is complete, but it doesn't have a -- @COMPLETE@ pragma. This copy of it does. @@ -39,6 +48,28 @@ pattern GHC_SNat = SNat fromSNat' :: SNat n -> Int fromSNat' = fromIntegral . fromSNat +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + -- | Type-level list append. type family l1 ++ l2 where @@ -51,6 +82,11 @@ lemAppNil = unsafeCoerce Refl lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + + type IxX :: [Maybe Nat] -> Type -> Type data IxX sh i where ZIX :: IxX '[] i @@ -103,11 +139,11 @@ instance KnownShapeX sh => KnownShapeX (Nothing : sh) where knownShapeX = () :!$? knownShapeX type family Rank sh where - Rank '[] = Z - Rank (_ : sh) = S (Rank sh) + Rank '[] = 0 + Rank (_ : sh) = 1 + Rank sh type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) +newtype XArray sh a = XArray (S.Array (Rank sh) a) deriving (Show) zeroIxX :: StaticShX sh -> IIxX sh @@ -157,6 +193,15 @@ ssxToShape' ZKSX = Just ZSX ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh ssxToShape' (_ :!$? _) = Nothing +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerce Refl + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKSX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = () :!$? ssxReplicate n + fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx @@ -211,23 +256,28 @@ ssxIotaFrom _ ZKSX = [] ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh +staticShapeFrom :: IShX sh -> StaticShX sh +staticShapeFrom ZSX = ZKSX +staticShapeFrom (n :$@ sh) = n :!$@ staticShapeFrom sh +staticShapeFrom (_ :$? sh) = () :!$? staticShapeFrom sh + lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZSX = Dict -lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownNatRank :: IShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$@ sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ :$? sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZKSX = Dict -lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKSX = Dict +lemKnownNatRankSSX (_ :!$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +lemKnownNatRankSSX (_ :!$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh lemKnownShapeX ZKSX = Dict @@ -254,8 +304,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.fromVector (shapeLshape sh) v) toVector :: Storable a => XArray sh a -> VS.Vector a @@ -269,15 +318,14 @@ unScalar (XArray a) = S.unScalar a constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a constant sh x - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) x) generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) -- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownINatRank sh = +-- generateM sh f | Dict <- lemKnownNatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) @@ -300,8 +348,7 @@ type family AddMaybe n m where append :: forall n m sh a. (KnownShapeX sh, Storable a) => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a append (XArray a) (XArray b) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) = XArray (S.append a b) rerank :: forall sh sh1 sh2 a b. @@ -310,21 +357,18 @@ rerank :: forall sh sh1 sh2 a b. -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough + = XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) (\a -> unXArray (f (XArray a))) arr) where unXArray (XArray a) = a -rerankTop :: forall sh sh1 sh2 a b. +rerankTop :: forall sh1 sh2 sh a b. (Storable a, Storable b) => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh -> (XArray sh1 a -> XArray sh2 b) @@ -337,26 +381,135 @@ rerank2 :: forall sh sh1 sh2 a b c. -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- should be redundant but the solver is not clever enough + = XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) (\a b -> unXArray (f (XArray a) (XArray b))) arr1 arr2) where unXArray (XArray a) = a +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type Permutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +data HList f list where + HNil :: HList f '[] + HCons :: f a -> HList f l -> HList f (a : l) +infixr 5 `HCons` + +foldHList :: Monoid m => (forall a. f a -> m) -> HList f list -> m +foldHList _ HNil = mempty +foldHList f (x `HCons` l) = f x <> foldHList f l + +class KnownNatList l where makeNatList :: HList SNat l +instance KnownNatList '[] where makeNatList = HNil +instance (KnownNat n, KnownNatList l) => KnownNatList (n : l) where makeNatList = natSing `HCons` makeNatList + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +lemRankPermute :: Proxy sh -> HList SNat is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ HNil = Refl +lemRankPermute p (_ `HCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> HList SNat is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKSX HNil = Refl +lemRankDropLen (_ :!$@ sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$? sh) (_ `HCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!$@ _) HNil = Refl +lemRankDropLen (_ :!$? _) HNil = Refl +lemRankDropLen ZKSX (_ `HCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerce Refl + +ssxTakeLen :: HList SNat is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen HNil _ = ZKSX +ssxTakeLen (_ `HCons` is) (n :!$@ sh) = n :!$@ ssxTakeLen is sh +ssxTakeLen (_ `HCons` is) (n :!$? sh) = n :!$? ssxTakeLen is sh +ssxTakeLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxDropLen :: HList SNat is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen HNil sh = sh +ssxDropLen (_ `HCons` is) (_ :!$@ sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` is) (_ :!$? sh) = ssxDropLen is sh +ssxDropLen (_ `HCons` _) ZKSX = error "Permutation longer than shape" + +ssxPermute :: HList SNat is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute HNil _ = ZKSX +ssxPermute (i `HCons` (is :: HList SNat is')) (sh :: StaticShX sh) = ssxIndex (Proxy @is') (Proxy @sh) i sh (ssxPermute is sh) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> StaticShX (Permute is shT) -> StaticShX (Index i sh : Permute is shT) +ssxIndex _ _ SZ (n :!$@ _) rest = n :!$@ rest +ssxIndex _ _ SZ (n :!$? _) rest = n :!$? rest +ssxIndex p pT (SS (i :: SNat i')) ((_ :: SNat n) :!$@ (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex p pT (SS (i :: SNat i')) (() :!$? (sh :: StaticShX sh')) rest + | Refl <- lemIndexSucc (Proxy @i') (Proxy @Nothing) (Proxy @sh') + = ssxIndex p pT i sh rest +ssxIndex _ _ _ ZKSX _ = error "Index into empty shape" + -- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transpose :: forall is sh a. (Permutation is, Rank is <= Rank sh, KnownShapeX sh) + => HList SNat is + -> XArray sh a + -> XArray (PermutePrefix is sh) a transpose perm (XArray arr) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm (knownShapeX @sh))) (ssxDropLen perm (knownShapeX @sh)) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen (knownShapeX @sh) perm + = let perm' = foldHList (\sn -> [fromSNat' sn]) perm :: [Int] + in XArray (S.transpose perm' arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" transpose2 :: forall sh1 sh2 a. StaticShX sh1 -> StaticShX sh2 @@ -364,10 +517,8 @@ transpose2 :: forall sh1 sh2 a. transpose2 ssh1 ssh2 (XArray arr) | Refl <- lemRankApp ssh1 ssh2 , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1) - , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1))) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) @@ -390,13 +541,12 @@ sumOuter ssh ssh' fromList1 :: forall n sh a. Storable a => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromList1 ssh l - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) + | Dict <- lemKnownNatRankSSX ssh = case ssh of m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ "does not match the type (" ++ show (natVal m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a] toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) @@ -404,13 +554,29 @@ toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) -- | Throws if the given shape is not, in fact, empty. empty :: forall sh a. Storable a => IShX sh -> XArray sh a empty sh - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) + | Dict <- lemKnownNatRank sh = XArray (S.constant (shapeLshape sh) (error "Data.Array.Mixed.empty: shape was not empty")) -slice :: [(Int, Int)] -> XArray sh a -> XArray sh a -slice ivs (XArray arr) = XArray (S.slice ivs arr) +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) rev1 :: XArray (n : sh) a -> XArray (n : sh) a rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 + = XArray (S.reshape (shapeLshape sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (staticShapeFrom sh2) ssh') + = XArray (S.reshape (shapeLshape sh2 ++ drop (length sh2) (S.shapeL arr)) arr) |