diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Array.hs | 103 | ||||
-rw-r--r-- | src/Fancy.hs | 288 |
2 files changed, 342 insertions, 49 deletions
diff --git a/src/Array.hs b/src/Array.hs index 693df05..cbf04fc 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -8,6 +8,8 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Array where import qualified Data.Array.RankedU as U @@ -15,6 +17,7 @@ import Data.Kind import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Unboxed as VU +import qualified GHC.TypeLits as GHC import Unsafe.Coerce (unsafeCoerce) import Nats @@ -140,6 +143,24 @@ shapeLshape IZX = [] shapeLshape (n ::@ sh) = n : shapeLshape sh shapeLshape (n ::? sh) = n : shapeLshape sh +ssxLength :: StaticShapeX sh -> Int +ssxLength SZX = 0 +ssxLength (_ :$@ ssh) = 1 + ssxLength ssh +ssxLength (_ :$? ssh) = 1 + ssxLength ssh + +ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] +ssxIotaFrom _ SZX = [] +ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh + +lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2) +lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this + +lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1)) +lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this + lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh) lemKnownNatRank IZX = Dict lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict @@ -183,6 +204,12 @@ fromVector sh v toVector :: U.Unbox a => XArray sh a -> VU.Vector a toVector (XArray arr) = U.toVector arr +scalar :: U.Unbox a => a -> XArray '[] a +scalar = XArray . U.scalar + +unScalar :: U.Unbox a => XArray '[] a -> a +unScalar (XArray a) = U.unScalar a + generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) @@ -207,3 +234,79 @@ append (XArray a) (XArray b) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) , Dict <- gknownNat (Proxy @(Rank sh)) = XArray (U.append a b) + +rerank :: forall sh sh1 sh2 a b. + (U.Unbox a, U.Unbox b) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a -> unXArray (f (XArray a))) + arr) + where + unXArray (XArray a) = a + +rerank2 :: forall sh sh1 sh2 a b c. + (U.Unbox a, U.Unbox b, U.Unbox c) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a b -> unXArray (f (XArray a) (XArray b))) + arr1 arr2) + where + unXArray (XArray a) = a + +-- | The list argument gives indices into the original dimension list. +transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transpose perm (XArray arr) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.transpose perm arr) + +transpose2 :: forall sh1 sh2 a. + StaticShapeX sh1 -> StaticShapeX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2))) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1))) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (U.Unbox a, Num a) => XArray sh a -> a +sumFull (XArray arr) = U.sumA arr + +sumInner :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' + | Refl <- lemAppNil @sh + = rerank ssh ssh' SZX (scalar . sumFull) + +sumOuter :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' + | Refl <- lemAppNil @sh + = sumInner ssh' ssh . transpose2 ssh ssh' diff --git a/src/Fancy.hs b/src/Fancy.hs index e8192aa..7461c1f 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,8 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -10,15 +12,25 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} + +{-| +TODO: +* This module needs better structure with an Internal module and less public + exports etc. + +* We should be more consistent in whether functions take a 'StaticShapeX' + argument or a 'KnownShapeX' constraint. + +-} + module Fancy where import Control.Monad (forM_) import Control.Monad.ST -import Data.Coerce (coerce) +import Data.Coerce (coerce, Coercible) import Data.Kind import Data.Proxy import Data.Type.Equality -import Data.Type.Ord import qualified Data.Vector.Unboxed as VU import qualified Data.Vector.Unboxed.Mutable as VUM @@ -35,9 +47,6 @@ type family MapJust l where MapJust '[] = '[] MapJust (x : xs) = Just x : MapJust xs -lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a -lemCompareFalse1 = error "Incoherence" - lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) where @@ -45,10 +54,36 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) go SZ = SZX go (SS n) = () :$? go n +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = go (knownNat @n) + where + go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m + go SZ = Refl + go (SS n) | Refl <- go n = Refl --- Wrapper type used as a tag to attach instances on. +lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp _ _ _ = go (knownNat @n) + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS n) | Refl <- go n = Refl + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. newtype Primitive a = Primitive a + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a dat afamily so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'transpose') are typically free. type Mixed :: [Maybe Nat] -> Type -> Type data family Mixed sh a @@ -60,12 +95,13 @@ newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope op -- etc. data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b) -data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c) -data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d) +-- etc. newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) +-- | Internal helper data family mirrorring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type data family MixedVecs s sh a @@ -77,13 +113,17 @@ newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MV -- etc. data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) -data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) -data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d) +-- etc. data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) +-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'GMixed' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. class GMixed a where + -- ====== PUBLIC METHODS ====== -- + mshape :: KnownShapeX sh => Mixed sh a -> IxX sh mindex :: Mixed sh a -> IxX sh -> a mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a @@ -92,6 +132,9 @@ class GMixed a where => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) -> Mixed sh1 a -> Mixed sh2 a + -- ====== PRIVATE METHODS ====== -- + -- Remember I said that this module needed better management of exports? + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. memptyArray :: IxX sh -> Mixed sh a @@ -118,6 +161,7 @@ class GMixed a where mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) +-- Arrays of scalars are basically just arrays of scalars. instance VU.Unbox a => GMixed (Primitive a) where mshape (M_Primitive a) = X.shape a mindex (M_Primitive a) i = Primitive (X.index a i) @@ -146,10 +190,12 @@ instance VU.Unbox a => GMixed (Primitive a) where mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v +-- What a blessing that orthotope's Array has "representational" role on the value type! deriving via Primitive Int instance GMixed Int deriving via Primitive Double instance GMixed Double deriving via Primitive () instance GMixed () +-- Arrays of pairs are pairs of arrays. instance (GMixed a, GMixed b) => GMixed (a, b) where mshape (M_Tup2 a _) = mshape a mindex (M_Tup2 a b) i = (mindex a i, mindex b i) @@ -167,6 +213,7 @@ instance (GMixed a, GMixed b) => GMixed (a, b) where mvecsWritePartial sh i y b mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b +-- Arrays of arrays are just arrays, but with more dimensions. instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh mshape (M_Nest arr) @@ -226,8 +273,13 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs + +-- Public method. Turns out this doesn't have to be in the type class! +-- | Create an array given a size and a function that computes the element at a +-- given index. mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a mgenerate sh f + -- TODO: Do we need this checkBounds check elsewhere as well? | not (checkBounds sh (knownShapeX @sh)) = error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) -- We need to be very careful here to ensure that neither 'sh' nor @@ -241,6 +293,8 @@ mgenerate sh f else runST $ do vecs <- mvecsUnsafeNew sh firstelem mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this feels inefficient. Should improve this. forM_ (tail (X.enumShape sh)) $ \idx -> mvecsWrite sh idx (f idx) vecs mvecsFreeze sh vecs @@ -251,19 +305,27 @@ mgenerate sh f checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' +-- | Newtype around a 'Mixed' of 'Nothing's. This works like a rank-typed array +-- as in @orthotope@. type Ranked :: Nat -> Type -> Type newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +-- | Newtype around a 'Mixed' of 'Just's. This works like a shape-typed array +-- as in @orthotope@. type Shaped :: [Nat] -> Type -> Type newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) @@ -271,7 +333,7 @@ instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a) mindexPartial (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a)) $ + = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ mindexPartial arr i mlift :: forall sh1 sh2. KnownShapeX sh2 @@ -279,13 +341,13 @@ instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) mlift f (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh2 (Mixed (Replicate n 'Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ mlift f arr memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) memptyArray i | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n 'Nothing) a)) @(Mixed sh (Ranked n a)) $ + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ memptyArray i mvecsNumElts (Ranked arr) @@ -337,42 +399,106 @@ class KnownShape sh where knownShape :: SShape sh instance KnownShape '[] where knownShape = ShNil instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape --- instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where +lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) +lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) + where + go :: SShape sh' -> StaticShapeX (MapJust sh') + go ShNil = SZX + go (ShCons n sh) = n :$@ go sh + +lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustPlusApp _ _ = go (knownShape @sh1) + where + go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 + go ShNil = Refl + go (ShCons _ sh) | Refl <- go sh = Refl + +instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where + mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr + mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mlift :: forall sh1 sh2. KnownShapeX sh2 + => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift f (M_Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift f arr + + memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a) + memptyArray i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArray i + + mvecsNumElts (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = mvecsNumElts arr + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs + | Dict <- lemKnownMapJust (Proxy @sh) + = mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 + => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs + | Dict <- lemKnownMapJust (Proxy @sh) + = mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + +-- Utility function to satisfy the type checker sometimes +rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a +rewriteMixed Refl x = x +-- ====== API OF RANKED ARRAYS ====== -- + +-- | An index into a rank-typed array. type IxR :: Nat -> Type data IxR n where IZR :: IxR Z (:::) :: Int -> IxR n -> IxR (S n) -type IxS :: [Nat] -> Type -data IxS sh where - IZS :: IxS '[] - (::$) :: Int -> IxS sh -> IxS (n : sh) - ixCvtXR :: IxX sh -> IxR (X.Rank sh) ixCvtXR IZX = IZR -ixCvtXR (n ::@ sh) = n ::: ixCvtXR sh -ixCvtXR (n ::? sh) = n ::: ixCvtXR sh +ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx +ixCvtXR (n ::? idx) = n ::: ixCvtXR idx ixCvtRX :: IxR n -> IxX (Replicate n Nothing) ixCvtRX IZR = IZX -ixCvtRX (n ::: sh) = n ::? ixCvtRX sh - -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (knownNat @n) - where - go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m - go SZ = Refl - go (SS n) | Refl <- go n = Refl - -lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (knownNat @n) - where - go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a - go SZ = Refl - go (SS n) | Refl <- go n = Refl +ixCvtRX (n ::: idx) = n ::? ixCvtRX idx rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n @@ -384,15 +510,11 @@ rshape (Ranked arr) rindex :: GMixed a => Ranked n a -> IxR n -> a rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) -rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a -rewriteMixed Refl x = x - rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a -rindexPartial (Ranked arr) idx - | Refl <- lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing) - = Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) - (ixCvtRX idx)) +rindexPartial (Ranked arr) idx = + Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) + (ixCvtRX idx)) rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a rgenerate sh f @@ -406,3 +528,71 @@ rlift :: forall n1 n2 a. (KnownNat n2, GMixed a) rlift f (Ranked arr) | Dict <- lemKnownReplicate (Proxy @n2) = Ranked (mlift f arr) + +rsumOuter1 :: forall n a. + (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) + => Ranked (S n) a -> Ranked n a +rsumOuter1 (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked + . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a) + . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing)) + . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) + $ arr + + +-- ====== API OF SHAPED ARRAYS ====== -- + +-- | An index into a shape-typed array. +type IxS :: [Nat] -> Type +data IxS sh where + IZS :: IxS '[] + (::$) :: Int -> IxS sh -> IxS (n : sh) + +cvtSShapeIxS :: SShape sh -> IxS sh +cvtSShapeIxS ShNil = IZS +cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh + +ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh +ixCvtXS ShNil IZX = IZS +ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx + +ixCvtSX :: IxS sh -> IxX (MapJust sh) +ixCvtSX IZS = IZX +ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh + + +sshape :: forall sh a. (KnownShape sh, GMixed a) => Shaped sh a -> IxS sh +sshape _ = cvtSShapeIxS (knownShape @sh) + +sindex :: GMixed a => Shaped sh a -> IxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) + +sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, GMixed a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a +sindexPartial (Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) + (ixCvtSX idx)) + +sgenerate :: forall sh a. (KnownShape sh, GMixed a) => IxS sh -> (IxS sh -> a) -> Shaped sh a +sgenerate sh f + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) + +slift :: forall sh1 sh2 a. (KnownShape sh2, GMixed a) + => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift f (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh2) + = Shaped (mlift f arr) + +ssumOuter1 :: forall sh n a. + (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped + . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) + . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) + . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) + $ arr |