diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-05-13 22:47:42 +0200 |
commit | e4e23a33f77d250af1e9b6614cf249128ba1510a (patch) | |
tree | 34bb40910003749becbaf8005a7b7ca62024fff2 /src | |
parent | 7c9865354442326d55094087ad6a74b6e96341fb (diff) |
Shape/index hygiene
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Mixed.hs | 246 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 239 | ||||
-rw-r--r-- | src/Data/INat.hs | 1 |
4 files changed, 257 insertions, 231 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 47027cb..94b7cdf 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -1,6 +1,10 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -32,6 +36,9 @@ pattern GHC_SNat :: () => KnownNat n => SNat n pattern GHC_SNat = SNat {-# COMPLETE GHC_SNat #-} +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + -- | Type-level list append. type family l1 ++ l2 where @@ -44,9 +51,6 @@ lemAppNil = unsafeCoerce Refl lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) lemAppAssoc _ _ _ = unsafeCoerce Refl --- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype --- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)? - type IxX :: [Maybe Nat] -> Type -> Type data IxX sh i where ZIX :: IxX '[] i @@ -55,31 +59,48 @@ data IxX sh i where deriving instance Show i => Show (IxX sh i) deriving instance Eq i => Eq (IxX sh i) deriving instance Ord i => Ord (IxX sh i) +deriving instance Functor (IxX sh) +deriving instance Foldable (IxX sh) infixr 3 :.@ infixr 3 :.? type IIxX sh = IxX sh Int --- | The part of a shape that is statically known. -type StaticShapeX :: [Maybe Nat] -> Type -data StaticShapeX sh where - ZSX :: StaticShapeX '[] - (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) - (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) -deriving instance Show (StaticShapeX sh) +type ShX :: [Maybe Nat] -> Type -> Type +data ShX sh i where + ZSX :: ShX '[] i + (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i + (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i +deriving instance Show i => Show (ShX sh i) +deriving instance Eq i => Eq (ShX sh i) +deriving instance Ord i => Ord (ShX sh i) +deriving instance Functor (ShX sh) +deriving instance Foldable (ShX sh) infixr 3 :$@ infixr 3 :$? +type IShX sh = ShX sh Int + +-- | The part of a shape that is statically known. +type StaticShX :: [Maybe Nat] -> Type +data StaticShX sh where + ZKSX :: StaticShX '[] + (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh) + (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh) +deriving instance Show (StaticShX sh) +infixr 3 :!$@ +infixr 3 :!$? + -- | Evidence for the static part of a shape. type KnownShapeX :: [Maybe Nat] -> Constraint class KnownShapeX sh where - knownShapeX :: StaticShapeX sh + knownShapeX :: StaticShX sh instance KnownShapeX '[] where - knownShapeX = ZSX + knownShapeX = ZKSX instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing :$@ knownShapeX + knownShapeX = natSing :!$@ knownShapeX instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :$? knownShapeX + knownShapeX = () :!$? knownShapeX type family Rank sh where Rank '[] = Z @@ -89,138 +110,149 @@ type XArray :: [Maybe Nat] -> Type -> Type newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) deriving (Show) -zeroIxX :: StaticShapeX sh -> IIxX sh -zeroIxX ZSX = ZIX -zeroIxX (_ :$@ ssh) = 0 :.@ zeroIxX ssh -zeroIxX (_ :$? ssh) = 0 :.? zeroIxX ssh +zeroIxX :: StaticShX sh -> IIxX sh +zeroIxX ZKSX = ZIX +zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh +zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh + +zeroIxX' :: IShX sh -> IIxX sh +zeroIxX' ZSX = ZIX +zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh +zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh -zeroIxX' :: IIxX sh -> IIxX sh -zeroIxX' ZIX = ZIX -zeroIxX' (_ :.@ sh) = 0 :.@ zeroIxX' sh -zeroIxX' (_ :.? sh) = 0 :.? zeroIxX' sh +-- This is a weird operation, so it has a long name +completeShXzeros :: StaticShX sh -> IShX sh +completeShXzeros ZKSX = ZSX +completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh +completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') ixAppend ZIX idx' = idx' ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx' +shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') +shAppend ZSX sh' = sh' +shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh' +shAppend (n :$? sh) sh' = n :$? shAppend sh sh' + ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' ixDrop sh ZIX = sh ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx -ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') -ssxAppend ZSX sh' = sh' -ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh' -ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh' +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKSX sh' = sh' +ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' +ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh' -shapeSize :: IIxX sh -> Int -shapeSize ZIX = 1 -shapeSize (n :.@ sh) = n * shapeSize sh -shapeSize (n :.? sh) = n * shapeSize sh +shapeSize :: IShX sh -> Int +shapeSize ZSX = 1 +shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh +shapeSize (n :$? sh) = n * shapeSize sh -- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh) -ssxToShape' ZSX = Just ZIX -ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) :.@) <$> ssxToShape' sh -ssxToShape' (_ :$? _) = Nothing +ssxToShape' :: StaticShX sh -> Maybe (IShX sh) +ssxToShape' ZKSX = Just ZSX +ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh +ssxToShape' (_ :!$? _) = Nothing -fromLinearIdx :: IIxX sh -> Int -> IIxX sh +fromLinearIdx :: IShX sh -> Int -> IIxX sh fromLinearIdx = \sh i -> case go sh i of (idx, 0) -> idx _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")" where -- returns (index in subarray, remaining index in enclosing array) - go :: IIxX sh -> Int -> (IIxX sh, Int) - go ZIX i = (ZIX, i) - go (n :.@ sh) i = + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$@ sh) i = let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n + (upi, locali) = i' `quotRem` fromSNat' n in (locali :.@ idx, upi) - go (n :.? sh) i = + go (n :$? sh) i = let (idx, i') = go sh i (upi, locali) = i' `quotRem` n in (locali :.? idx, upi) -toLinearIdx :: IIxX sh -> IIxX sh -> Int +toLinearIdx :: IShX sh -> IIxX sh -> Int toLinearIdx = \sh i -> fst (go sh i) where -- returns (index in subarray, size of subarray) - go :: IIxX sh -> IIxX sh -> (Int, Int) - go ZIX ZIX = (0, 1) - go (n :.@ sh) (i :.@ ix) = + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$@ sh) (i :.@ ix) = let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - go (n :.? sh) (i :.? ix) = + in (sz * i + lidx, fromSNat' n * sz) + go (n :$? sh) (i :.? ix) = let (lidx, sz) = go sh ix in (sz * i + lidx, n * sz) -enumShape :: IIxX sh -> [IIxX sh] +enumShape :: IShX sh -> [IIxX sh] enumShape = \sh -> go sh id [] where - go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZIX f = (f ZIX :) - go (n :.@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. n-1]] - go (n :.? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] - -shapeLshape :: IIxX sh -> S.ShapeL -shapeLshape ZIX = [] -shapeLshape (n :.@ sh) = n : shapeLshape sh -shapeLshape (n :.? sh) = n : shapeLshape sh - -ssxLength :: StaticShapeX sh -> Int -ssxLength ZSX = 0 -ssxLength (_ :$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :$? ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] -ssxIotaFrom _ ZSX = [] -ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh - -lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]] + go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] + +shapeLshape :: IShX sh -> S.ShapeL +shapeLshape ZSX = [] +shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh +shapeLshape (n :$? sh) = n : shapeLshape sh + +ssxLength :: StaticShX sh -> Int +ssxLength ZKSX = 0 +ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh +ssxLength (_ :!$? ssh) = 1 + ssxLength ssh + +ssxIotaFrom :: Int -> StaticShX sh -> [Int] +ssxIotaFrom _ ZKSX = [] +ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh + +lemRankApp :: StaticShX sh1 -> StaticShX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this -lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 +lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZIX = Dict -lemKnownINatRank (_ :.@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :.? sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) +lemKnownINatRank ZSX = Dict +lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZSX = Dict -lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) +lemKnownINatRankSSX ZKSX = Dict +lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh -lemKnownShapeX ZSX = Dict -lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh +lemKnownShapeX ZKSX = Dict +lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict -lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX ZSX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh' +lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' +lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -lemAppKnownShapeX (() :$? ssh) ssh' +lemAppKnownShapeX (() :!$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) where - go :: StaticShapeX sh' -> [Int] -> IIxX sh' - go ZSX [] = ZIX - go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) :.@ go ssh l - go (() :$? ssh) (n : l) = n :.? go ssh l + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKSX [] = ZSX + go (n :!$@ ssh) (_ : l) = n :$@ go ssh l + go (() :!$? ssh) (n : l) = n :$? go ssh l go _ _ = error "Invalid shapeL" -fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a fromVector sh v | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) @@ -235,16 +267,16 @@ scalar = XArray . S.scalar unScalar :: Storable a => XArray '[] a -> a unScalar (XArray a) = S.unScalar a -constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a +constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a constant sh x | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.constant (shapeLshape sh) x) -generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) --- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) -- generateM sh f | Dict <- lemKnownINatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) @@ -274,7 +306,7 @@ append (XArray a) (XArray b) rerank :: forall sh sh1 sh2 a b. (Storable a, Storable b) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f (XArray arr) @@ -294,14 +326,14 @@ rerank ssh ssh1 ssh2 f (XArray arr) rerankTop :: forall sh sh1 sh2 a b. (Storable a, Storable b) - => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh + => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh rerank2 :: forall sh sh1 sh2 a b c. (Storable a, Storable b, Storable c) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) @@ -327,7 +359,7 @@ transpose perm (XArray arr) = XArray (S.transpose perm arr) transpose2 :: forall sh1 sh2 a. - StaticShapeX sh1 -> StaticShapeX sh2 + StaticShX sh1 -> StaticShX sh2 -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a transpose2 ssh1 ssh2 (XArray arr) | Refl <- lemRankApp ssh1 ssh2 @@ -344,24 +376,24 @@ sumFull :: (Storable a, Num a) => XArray sh a -> a sumFull (XArray arr) = S.sumA arr sumInner :: forall sh sh' a. (Storable a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a sumInner ssh ssh' | Refl <- lemAppNil @sh - = rerank ssh ssh' ZSX (scalar . sumFull) + = rerank ssh ssh' ZKSX (scalar . sumFull) sumOuter :: forall sh sh' a. (Storable a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a sumOuter ssh ssh' | Refl <- lemAppNil @sh = sumInner ssh' ssh . transpose2 ssh ssh' fromList :: forall n sh a. Storable a - => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a fromList ssh l | Dict <- lemKnownINatRankSSX ssh , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) = case ssh of - m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) -> + m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ "does not match the type (" ++ show (natVal m) ++ ")" _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) @@ -369,5 +401,13 @@ fromList ssh l toList :: Storable a => XArray (n : sh) a -> [XArray sh a] toList (XArray arr) = coerce (ORB.toList (S.unravel arr)) +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh + | Dict <- lemKnownINatRank sh + , Dict <- knownNatFromINat (Proxy @(Rank sh)) + = XArray (S.constant (shapeLshape sh) + (error "Data.Array.Mixed.empty: shape was not empty")) + slice :: [(Int, Int)] -> XArray sh a -> XArray sh a slice ivs (XArray arr) = XArray (S.slice ivs arr) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index 370e30a..145598e 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -28,7 +28,7 @@ module Data.Array.Nested ( -- * Mixed arrays Mixed, IxX(..), IIxX, - KnownShapeX(..), StaticShapeX(..), + KnownShapeX(..), StaticShX(..), mgenerate, mtranspose, mappend, mfromVector, munScalar, mconstant, mfromList1, mtoList1, mslice, diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs index 49ed7cb..2f1e79e 100644 --- a/src/Data/Array/Nested/Internal.hs +++ b/src/Data/Array/Nested/Internal.hs @@ -8,12 +8,12 @@ {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -23,7 +23,7 @@ {-| TODO: -* We should be more consistent in whether functions take a 'StaticShapeX' +* We should be more consistent in whether functions take a 'StaticShX' argument or a 'KnownShapeX' constraint. * Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point @@ -51,7 +51,7 @@ import qualified Data.Vector.Storable.Mutable as VSM import Foreign.Storable (Storable) import GHC.TypeLits -import Data.Array.Mixed (XArray, IxX(..), IIxX, KnownShapeX(..), StaticShapeX(..), type (++), pattern GHC_SNat) +import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat) import qualified Data.Array.Mixed as X import Data.INat @@ -100,9 +100,9 @@ type family MapJust l where lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n)) where - go :: SINat m -> StaticShapeX (Replicate m Nothing) - go SZ = ZSX - go (SS n) = () :$? go n + go :: SINat m -> StaticShX (Replicate m Nothing) + go SZ = ZKSX + go (SS n) = () :!$? go n lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n lemRankReplicate _ = go (inatSing @n) @@ -119,10 +119,10 @@ lemReplicatePlusApp _ _ _ = go (inatSing @n) go SZ = Refl go (SS n) | Refl <- go n = Refl -ixAppSplit :: Proxy sh' -> StaticShapeX sh -> IIxX (sh ++ sh') -> (IIxX sh, IIxX sh') -ixAppSplit _ ZSX idx = (ZIX, idx) -ixAppSplit p (_ :$@ ssh) (i :.@ idx) = first (i :.@) (ixAppSplit p ssh idx) -ixAppSplit p (_ :$? ssh) (i :.? idx) = first (i :.?) (ixAppSplit p ssh idx) +shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') +shAppSplit _ ZKSX idx = (ZSX, idx) +shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx) +shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx) -- | Wrapper type used as a tag to attach instances on. The instances on arrays @@ -184,7 +184,7 @@ newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MV data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) -- etc. -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IIxX sh2) !(MixedVecs s (sh1 ++ sh2) a) +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) -- | Tree giving the shape of every array component. @@ -196,9 +196,9 @@ type family ShapeTree a where ShapeTree () = () ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed sh a) = (IIxX sh, ShapeTree a) - ShapeTree (Ranked n a) = (IIxR n, ShapeTree a) - ShapeTree (Shaped sh a) = (IIxS sh, ShapeTree a) + ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) + ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) -- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or @@ -207,7 +207,7 @@ type family ShapeTree a where class Elt a where -- ====== PUBLIC METHODS ====== -- - mshape :: KnownShapeX sh => Mixed sh a -> IIxX sh + mshape :: KnownShapeX sh => Mixed sh a -> IShX sh mindex :: Mixed sh a -> IIxX sh -> a mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a mscalar :: a -> Mixed '[] a @@ -240,15 +240,12 @@ class Elt a where -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a -- ====== PRIVATE METHODS ====== -- - -- Remember I said that this module needed better management of exports? -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IIxX sh -> Mixed sh a + memptyArray :: IShX sh -> Mixed sh a mshapeTree :: a -> ShapeTree a - mshapeTreeZero :: Proxy a -> ShapeTree a - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool @@ -257,20 +254,20 @@ class Elt a where -- | Create uninitialised vectors for this array type, given the shape of -- this vector and an example for the contents. - mvecsUnsafeNew :: IIxX sh -> a -> ST s (MixedVecs s sh a) + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWrite :: IIxX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () -- | Given the shape of this array, an index and a value, write the value at -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IIxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) -- Arrays of scalars are basically just arrays of scalars. @@ -299,9 +296,8 @@ instance Storable a => Elt (Primitive a) where , Refl <- X.lemAppNil @sh3 = M_Primitive (f Proxy a b) - memptyArray sh = M_Primitive (X.generate sh (error $ "memptyArray Int: shape was not empty (" ++ show sh ++ ")")) + memptyArray sh = M_Primitive (X.empty sh) mshapeTree _ = () - mshapeTreeZero _ = () mshapeTreeEq _ () () = True mshapeTreeEmpty _ () = False mshowShapeTree _ () = "()" @@ -312,7 +308,7 @@ instance Storable a => Elt (Primitive a) where -- TODO: this use of toVector is suboptimal mvecsWritePartial :: forall sh' sh s. KnownShapeX sh' - => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr))) VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) @@ -338,7 +334,6 @@ instance (Elt a, Elt b) => Elt (a, b) where memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeZero _ = (mshapeTreeZero (Proxy @a), mshapeTreeZero (Proxy @b)) mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" @@ -357,10 +352,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where -- TODO: this is quadratic in the nesting depth because it repeatedly -- truncates the shape vector to one a little shorter. Fix with a -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IIxX sh + mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh mshape (M_Nest arr) | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = fst (ixAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) + = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) mindex (M_Nest arr) i = mindexPartial arr i @@ -410,12 +405,10 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) = f (Proxy @(sh' ++ shT)) - memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIxX (knownShapeX @sh')))) + memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh')))) mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) - mshapeTreeZero _ = (X.zeroIxX (knownShapeX @sh'), mshapeTreeZero (Proxy @a)) - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t @@ -424,32 +417,26 @@ instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where mvecsUnsafeNew sh example | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example)) (mindex example (X.zeroIxX (knownShapeX @sh'))) where sh' = mshape example - mvecsNewEmpty _ = MV_Nest (X.zeroIxX (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) + mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) -> ST s () mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs + = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs --- | Check whether a given shape corresponds on the statically-known components of the shape. -checkBounds :: IIxX sh' -> StaticShapeX sh' -> Bool -checkBounds ZIX ZSX = True -checkBounds (n :.@ sh') (n' :$@ ssh') = n == fromIntegral (fromSNat n') && checkBounds sh' ssh' -checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh' -- | Create an array given a size and a function that computes the element at a -- given index. @@ -468,31 +455,25 @@ checkBounds (_ :.? sh') (() :$? ssh') = checkBounds sh' ssh' -- the entire hierarchy (after distributing out tuples) must be a rectangular -- array. The type of 'mgenerate' allows this requirement to be broken very -- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IIxX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f - -- TODO: Do we need this checkBounds check elsewhere as well? - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - -- If the shape is empty, there is no first element, so we should not try to - -- generate it. - | X.shapeSize sh == 0 = memptyArray sh - | otherwise = - let firstidx = X.zeroIxX' sh - firstelem = f (X.zeroIxX' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ (tail (X.enumShape sh)) $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs +mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case X.enumShape sh of + [] -> memptyArray sh + firstidx : restidxs -> + let firstelem = f (X.zeroIxX' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEmpty (Proxy @a) shapetree + then memptyArray sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this array copying inefficient. Should improve this. + forM_ restidxs $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs + mvecsFreeze sh vecs mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a mtranspose perm = @@ -506,12 +487,8 @@ mappend = mlift2 go => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append -mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVector sh v - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mfromVector: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - | otherwise = - M_Primitive (X.fromVector sh v) +mfromVector :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVector sh v = M_Primitive (X.fromVector sh v) mfromList1 :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a mfromList1 = mfromList . fmap mscalar @@ -522,17 +499,13 @@ mtoList1 = map munScalar . mtoList munScalar :: Elt a => Mixed '[] a -> a munScalar arr = mindex arr ZIX -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IIxX sh -> a -> Mixed sh (Primitive a) -mconstantP sh x - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mconstant: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - | otherwise = - M_Primitive (X.constant sh x) +mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) +mconstantP sh x = M_Primitive (X.constant sh x) -- | This 'Coercible' constraint holds for the scalar types for which 'Mixed' -- is defined. mconstant :: forall sh a. (KnownShapeX sh, Storable a, Coercible (Mixed sh (Primitive a)) (Mixed sh a)) - => IIxX sh -> a -> Mixed sh a + => IShX sh -> a -> Mixed sh a mconstant sh x = coerce (mconstantP sh x) mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a @@ -648,7 +621,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ mlift2 f arr1 arr2 - memptyArray :: forall sh. IIxX sh -> Mixed sh (Ranked n a) + memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) memptyArray i | Dict <- lemKnownReplicate (Proxy @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ @@ -657,9 +630,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where mshapeTree (Ranked arr) | Refl <- lemRankReplicate (Proxy @n) , Dict <- lemKnownReplicate (Proxy @n) - = first ixCvtXR (mshapeTree arr) - - mshapeTreeZero _ = (zeroIxR (inatSing @n), mshapeTreeZero (Proxy @a)) + = first shCvtXR (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -675,7 +646,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where | Dict <- lemKnownReplicate (Proxy @n) = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - mvecsWrite :: forall sh s. IIxX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () mvecsWrite sh idx (Ranked arr) vecs | Dict <- lemKnownReplicate (Proxy @n) = mvecsWrite sh idx arr @@ -683,7 +654,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where vecs) mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IIxX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) -> MixedVecs s (sh ++ sh') (Ranked n a) -> ST s () mvecsWritePartial sh idx arr vecs @@ -696,7 +667,7 @@ instance (Elt a, KnownINat n) => Elt (Ranked n a) where @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) vecs) - mvecsFreeze :: forall sh s. IIxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) mvecsFreeze sh vecs | Dict <- lemKnownReplicate (Proxy @n) = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @@ -712,6 +683,8 @@ data ShS sh where ZSS :: ShS '[] (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh) deriving instance Show (ShS sh) +deriving instance Eq (ShS sh) +deriving instance Ord (ShS sh) infixr 3 :$$ -- | A statically-known shape of a shape-typed array. @@ -726,9 +699,9 @@ sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) where - go :: ShS sh' -> StaticShapeX (MapJust sh') - go ZSS = ZSX - go (n :$$ sh) = n :$@ go sh + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKSX + go (n :$$ sh) = n :!$@ go sh lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 @@ -777,7 +750,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ mlift2 f arr1 arr2 - memptyArray :: forall sh'. IIxX sh' -> Mixed sh' (Shaped sh a) + memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) memptyArray i | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ @@ -785,9 +758,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where mshapeTree (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) - = first (ixCvtXS (knownShape @sh)) (mshapeTree arr) - - mshapeTreeZero _ = (zeroIxS (knownShape @sh), mshapeTreeZero (Proxy @a)) + = first (shCvtXS (knownShape @sh)) (mshapeTree arr) mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 @@ -803,7 +774,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where | Dict <- lemKnownMapJust (Proxy @sh) = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - mvecsWrite :: forall sh' s. IIxX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () mvecsWrite sh idx (Shaped arr) vecs | Dict <- lemKnownMapJust (Proxy @sh) = mvecsWrite sh idx arr @@ -811,7 +782,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where vecs) mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IIxX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) -> ST s () mvecsWritePartial sh idx arr vecs @@ -824,7 +795,7 @@ instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) vecs) - mvecsFreeze :: forall sh' s. IIxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) mvecsFreeze sh vecs | Dict <- lemKnownMapJust (Proxy @sh) = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @@ -927,6 +898,8 @@ newtype ShR n i = ShR (ListR n i) deriving (Show, Eq, Ord) deriving newtype (Functor, Foldable) +type IShR n = ShR n Int + pattern ZSR :: forall n i. () => n ~ Z => ShR n i pattern ZSR = ShR ZR @@ -957,20 +930,29 @@ ixCvtXR ZIX = ZIR ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx ixCvtXR (n :.? idx) = n :.: ixCvtXR idx +shCvtXR :: IShX sh -> IShR (X.Rank sh) +shCvtXR ZSX = ZSR +shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx +shCvtXR (n :$? idx) = n :$: shCvtXR idx + ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) ixCvtRX ZIR = ZIX ixCvtRX (n :.: idx) = n :.? ixCvtRX idx -shapeSizeR :: IIxR n -> Int -shapeSizeR ZIR = 1 -shapeSizeR (n :.: sh) = n * shapeSizeR sh +shCvtRX :: IShR n -> IShX (Replicate n Nothing) +shCvtRX ZSR = ZSX +shCvtRX (n :$: idx) = n :$? shCvtRX idx + +shapeSizeR :: IShR n -> Int +shapeSizeR ZSR = 1 +shapeSizeR (n :$: sh) = n * shapeSizeR sh -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IIxR n +rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n rshape (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) - = ixCvtXR (mshape arr) + = shCvtXR (mshape arr) rindex :: Elt a => Ranked n a -> IIxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) @@ -983,12 +965,12 @@ rindexPartial (Ranked arr) idx = -- | __WARNING__: All values returned from the function must have equal shape. -- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IIxR n -> (IIxR n -> a) -> Ranked n a +rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a rgenerate sh f - | Dict <- knownIxR sh + | Dict <- knownShR sh , Dict <- lemKnownReplicate (Proxy @n) , Refl <- lemRankReplicate (Proxy @n) - = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) + = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) -- | See the documentation of 'mlift'. rlift :: forall n1 n2 a. (KnownINat n2, Elt a) @@ -1005,7 +987,7 @@ rsumOuter1 (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = Ranked . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) - . X.sumOuter (() :$? ZSX) (knownShapeX @(Replicate n Nothing)) + . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a) $ arr @@ -1021,10 +1003,10 @@ rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend rscalar :: Elt a => a -> Ranked I0 a rscalar x = Ranked (mscalar x) -rfromVector :: forall n a. (KnownINat n, Storable a) => IIxR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVector :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) rfromVector sh v | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromVector (ixCvtRX sh) v) + = Ranked (mfromVector (shCvtRX sh) v) rfromList :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a rfromList l @@ -1043,13 +1025,13 @@ rtoList1 = map runScalar . rtoList runScalar :: Elt a => Ranked I0 a -> a runScalar arr = rindex arr ZIR -rconstantP :: forall n a. (KnownINat n, Storable a) => IIxR n -> a -> Ranked n (Primitive a) +rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) rconstantP sh x | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mconstantP (ixCvtRX sh) x) + = Ranked (mconstantP (shCvtRX sh) x) rconstant :: forall n a. (KnownINat n, Storable a, Coercible (Mixed (Replicate n Nothing) (Primitive a)) (Mixed (Replicate n Nothing) a)) - => IIxR n -> a -> Ranked n a + => IShR n -> a -> Ranked n a rconstant sh x = coerce (rconstantP sh x) rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a @@ -1141,27 +1123,30 @@ zeroIxS :: ShS sh -> IIxS sh zeroIxS ZSS = ZIS zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh --- TODO: this function should not exist perhaps -cvtShSIxS :: ShS sh -> IIxS sh -cvtShSIxS ZSS = ZIS -cvtShSIxS (n :$$ sh) = fromIntegral (fromSNat n) :.$ cvtShSIxS sh - ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh ixCvtXS ZSS ZIX = ZIS ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx +shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh +shCvtXS ZSS ZSX = ZSS +shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx + ixCvtSX :: IIxS sh -> IIxX (MapJust sh) ixCvtSX ZIS = ZIX ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh -shapeSizeS :: IIxS sh -> Int -shapeSizeS ZIS = 1 -shapeSizeS (n :.$ sh) = n * shapeSizeS sh +shCvtSX :: ShS sh -> IShX (MapJust sh) +shCvtSX ZSS = ZSX +shCvtSX (n :$$ sh) = n :$@ shCvtSX sh + +shapeSizeS :: ShS sh -> Int +shapeSizeS ZSS = 1 +shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh -- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IIxS sh -sshape _ = cvtShSIxS (knownShape @sh) +sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh +sshape _ = knownShape @sh sindex :: Elt a => Shaped sh a -> IIxS sh -> a sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) @@ -1177,7 +1162,7 @@ sindexPartial (Shaped arr) idx = sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a sgenerate f | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (ixCvtSX (cvtShSIxS (knownShape @sh))) (f . ixCvtXS (knownShape @sh))) + = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh))) -- | See the documentation of 'mlift'. slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) @@ -1194,7 +1179,7 @@ ssumOuter1 (Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = Shaped . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) - . X.sumOuter (natSing @n :$@ ZSX) (knownShapeX @(MapJust sh)) + . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh)) . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a) $ arr @@ -1213,7 +1198,7 @@ sscalar x = Shaped (mscalar x) sfromVector :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) sfromVector v | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromVector (ixCvtSX (cvtShSIxS (knownShape @sh))) v) + = Shaped (mfromVector (shCvtSX (knownShape @sh)) v) sfromList :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) => NonEmpty (Shaped sh a) -> Shaped (n : sh) a @@ -1236,7 +1221,7 @@ sunScalar arr = sindex arr ZIS sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) sconstantP x | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mconstantP (ixCvtSX (cvtShSIxS (knownShape @sh))) x) + = Shaped (mconstantP (shCvtSX (knownShape @sh)) x) sconstant :: forall sh a. (KnownShape sh, Storable a, Coercible (Mixed (MapJust sh) (Primitive a)) (Mixed (MapJust sh) a)) => a -> Shaped sh a diff --git a/src/Data/INat.hs b/src/Data/INat.hs index 2d65c53..af8f18b 100644 --- a/src/Data/INat.hs +++ b/src/Data/INat.hs @@ -4,6 +4,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} |