summaryrefslogtreecommitdiff
path: root/src/Fancy.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r--src/Fancy.hs598
1 files changed, 0 insertions, 598 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs
deleted file mode 100644
index 7461c1f..0000000
--- a/src/Fancy.hs
+++ /dev/null
@@ -1,598 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DerivingVia #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE InstanceSigs #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-
-{-|
-TODO:
-* This module needs better structure with an Internal module and less public
- exports etc.
-
-* We should be more consistent in whether functions take a 'StaticShapeX'
- argument or a 'KnownShapeX' constraint.
-
--}
-
-module Fancy where
-
-import Control.Monad (forM_)
-import Control.Monad.ST
-import Data.Coerce (coerce, Coercible)
-import Data.Kind
-import Data.Proxy
-import Data.Type.Equality
-import qualified Data.Vector.Unboxed as VU
-import qualified Data.Vector.Unboxed.Mutable as VUM
-
-import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
-import qualified Array as X
-import Nats
-
-
-type family Replicate n a where
- Replicate Z a = '[]
- Replicate (S n) a = a : Replicate n a
-
-type family MapJust l where
- MapJust '[] = '[]
- MapJust (x : xs) = Just x : MapJust xs
-
-lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing)
-lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
- where
- go :: SNat m -> StaticShapeX (Replicate m Nothing)
- go SZ = SZX
- go (SS n) = () :$? go n
-
-lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n
-lemRankReplicate _ = go (knownNat @n)
- where
- go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m
- go SZ = Refl
- go (SS n) | Refl <- go n = Refl
-
-lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a
- -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a
-lemReplicatePlusApp _ _ _ = go (knownNat @n)
- where
- go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a
- go SZ = Refl
- go (SS n) | Refl <- go n = Refl
-
-
--- | Wrapper type used as a tag to attach instances on. The instances on arrays
--- of @'Primitive' a@ are more polymorphic than the direct instances for arrays
--- of scalars; this means that if @orthotope@ supports an element type @T@ that
--- this library does not (directly), it may just work if you use an array of
--- @'Primitive' T@ instead.
-newtype Primitive a = Primitive a
-
-
--- | Mixed arrays: some dimensions are size-typed, some are not. Distributes
--- over product-typed elements using a dat afamily so that the full array is
--- always in struct-of-arrays format.
---
--- Built on top of 'XArray' which is built on top of @orthotope@, meaning that
--- dimension permutations (e.g. 'transpose') are typically free.
-type Mixed :: [Maybe Nat] -> Type -> Type
-data family Mixed sh a
-
-newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
-
-newtype instance Mixed sh Int = M_Int (XArray sh Int)
-newtype instance Mixed sh Double = M_Double (XArray sh Double)
-newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector)
--- etc.
-
-data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
--- etc.
-
-newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
-
-
--- | Internal helper data family mirrorring 'Mixed' that consists of mutable
--- vectors instead of 'XArray's.
-type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
-data family MixedVecs s sh a
-
-newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
-
-newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
-newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
-newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this
--- etc.
-
-data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
--- etc.
-
-data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)
-
-
--- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or
--- 'Shaped' array. Note the polymorphic instance for 'GMixed' of @'Primitive'
--- a@; see the documentation for 'Primitive' for more details.
-class GMixed a where
- -- ====== PUBLIC METHODS ====== --
-
- mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
- mindex :: Mixed sh a -> IxX sh -> a
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
-
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 a -> Mixed sh2 a
-
- -- ====== PRIVATE METHODS ====== --
- -- Remember I said that this module needed better management of exports?
-
- -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
- memptyArray :: IxX sh -> Mixed sh a
-
- -- | Return the size of the individual (SoA) arrays in this value. If @a@
- -- does not contain tuples, this coincides with the total number of scalars
- -- in the given value; if @a@ contains tuples, then it is some multiple of
- -- this number of scalars.
- mvecsNumElts :: a -> Int
-
- -- | Create uninitialised vectors for this array type, given the shape of
- -- this vector and an example for the contents. The shape must not have size
- -- zero; an error may be thrown otherwise.
- mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
-
- -- | Given the shape of this array, an index and a value, write the value at
- -- that index in the vectors.
- mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
-
- -- | Given the shape of this array, finalise the vectors into 'XArray's.
- mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
-
-
--- Arrays of scalars are basically just arrays of scalars.
-instance VU.Unbox a => GMixed (Primitive a) where
- mshape (M_Primitive a) = X.shape a
- mindex (M_Primitive a) i = Primitive (X.index a i)
- mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
-
- mlift :: forall sh1 sh2.
- (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
- -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
- mlift f (M_Primitive a)
- | Refl <- X.lemAppNil @sh1
- , Refl <- X.lemAppNil @sh2
- = M_Primitive (f Proxy a)
-
- memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
- mvecsNumElts _ = 1
- mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh)
- mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x
-
- -- TODO: this use of toVector is suboptimal
- mvecsWritePartial
- :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a)
- => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
- mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
- let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
- VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
-
- mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v
-
--- What a blessing that orthotope's Array has "representational" role on the value type!
-deriving via Primitive Int instance GMixed Int
-deriving via Primitive Double instance GMixed Double
-deriving via Primitive () instance GMixed ()
-
--- Arrays of pairs are pairs of arrays.
-instance (GMixed a, GMixed b) => GMixed (a, b) where
- mshape (M_Tup2 a _) = mshape a
- mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
- mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
- mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
-
- memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
- mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
- mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
- mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
- mvecsWrite sh i x a
- mvecsWrite sh i y b
- mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
- mvecsWritePartial sh i x a
- mvecsWritePartial sh i y b
- mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
-
--- Arrays of arrays are just arrays, but with more dimensions.
-instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
- mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
- mshape (M_Nest arr)
- | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
- = ixAppPrefix (knownShapeX @sh) (mshape arr)
- where
- ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
- ixAppPrefix SZX _ = IZX
- ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
- ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx
-
- mindex (M_Nest arr) i = mindexPartial arr i
-
- mindexPartial :: forall sh1 sh2.
- Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
- mindexPartial (M_Nest arr) i
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
-
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
- -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
- mlift f (M_Nest arr)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- = M_Nest (mlift f' arr)
- where
- f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
- f' _
- | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
- , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
- , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3))
- = f (Proxy @(sh' ++ sh3))
-
- memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
-
- mvecsNumElts arr =
- let n = X.shapeSize (mshape arr)
- in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))
-
- mvecsUnsafeNew sh example
- | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
- | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
- (mindex example (X.zeroIdx (knownShapeX @sh')))
- where
- sh' = mshape example
-
- mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
-
- mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
- -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
- -> ST s ()
- mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
- | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
- , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
- = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
-
- mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
-
-
--- Public method. Turns out this doesn't have to be in the type class!
--- | Create an array given a size and a function that computes the element at a
--- given index.
-mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a
-mgenerate sh f
- -- TODO: Do we need this checkBounds check elsewhere as well?
- | not (checkBounds sh (knownShapeX @sh)) =
- error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh)
- -- We need to be very careful here to ensure that neither 'sh' nor
- -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
- | X.shapeSize sh == 0 = memptyArray sh
- | otherwise =
- let firstidx = X.zeroIdx' sh
- firstelem = f (X.zeroIdx' sh)
- in if mvecsNumElts firstelem == 0
- then memptyArray sh
- else runST $ do
- vecs <- mvecsUnsafeNew sh firstelem
- mvecsWrite sh firstidx firstelem vecs
- -- TODO: This is likely fine if @a@ is big, but if @a@ is a
- -- scalar this feels inefficient. Should improve this.
- forM_ (tail (X.enumShape sh)) $ \idx ->
- mvecsWrite sh idx (f idx) vecs
- mvecsFreeze sh vecs
- where
- checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool
- checkBounds IZX SZX = True
- checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh'
- checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh'
-
-
--- | Newtype around a 'Mixed' of 'Nothing's. This works like a rank-typed array
--- as in @orthotope@.
-type Ranked :: Nat -> Type -> Type
-newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
-
--- | Newtype around a 'Mixed' of 'Just's. This works like a shape-typed array
--- as in @orthotope@.
-type Shaped :: [Nat] -> Type -> Type
-newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
-
--- just unwrap the newtype and defer to the general instance for nested arrays
-newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
-newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a))
-
-newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
-newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a))
-
-
--- 'Ranked' and 'Shaped' can already be used at the top level of an array nest;
--- these instances allow them to also be used as elements of arrays, thus
--- making them first-class in the API.
-instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
- mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr
- mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i)
-
- mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a)
- mindexPartial (M_Ranked arr) i
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $
- mindexPartial arr i
-
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
- mlift f (M_Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $
- mlift f arr
-
- memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
- memptyArray i
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $
- memptyArray i
-
- mvecsNumElts (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsNumElts arr
-
- mvecsUnsafeNew idx (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = MV_Ranked <$> mvecsUnsafeNew idx arr
-
- mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s ()
- mvecsWrite sh idx (Ranked arr) vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsWritePartial :: forall sh sh' s. KnownShapeX sh'
- => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a)
- -> MixedVecs s (sh ++ sh') (Ranked n a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = mvecsWritePartial sh idx
- (coerce @(Mixed sh' (Ranked n a))
- @(Mixed sh' (Mixed (Replicate n Nothing) a))
- arr)
- (coerce @(MixedVecs s (sh ++ sh') (Ranked n a))
- @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a))
- vecs)
-
- mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a))
- mvecsFreeze sh vecs
- | Dict <- lemKnownReplicate (Proxy @n)
- = coerce @(Mixed sh (Mixed (Replicate n Nothing) a))
- @(Mixed sh (Ranked n a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh (Ranked n a))
- @(MixedVecs s sh (Mixed (Replicate n Nothing) a))
- vecs)
-
-
-data SShape sh where
- ShNil :: SShape '[]
- ShCons :: SNat n -> SShape sh -> SShape (n : sh)
-deriving instance Show (SShape sh)
-
-class KnownShape sh where knownShape :: SShape sh
-instance KnownShape '[] where knownShape = ShNil
-instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape
-
-lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh)
-lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh))
- where
- go :: SShape sh' -> StaticShapeX (MapJust sh')
- go ShNil = SZX
- go (ShCons n sh) = n :$@ go sh
-
-lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2
- -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2
-lemMapJustPlusApp _ _ = go (knownShape @sh1)
- where
- go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2
- go ShNil = Refl
- go (ShCons _ sh) | Refl <- go sh = Refl
-
-instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where
- mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr
- mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i)
-
- mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
- mindexPartial (M_Shaped arr) i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mindexPartial arr i
-
- mlift :: forall sh1 sh2. KnownShapeX sh2
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
- -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a)
- mlift f (M_Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $
- mlift f arr
-
- memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a)
- memptyArray i
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $
- memptyArray i
-
- mvecsNumElts (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsNumElts arr
-
- mvecsUnsafeNew idx (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = MV_Shaped <$> mvecsUnsafeNew idx arr
-
- mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s ()
- mvecsWrite sh idx (Shaped arr) vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsWrite sh idx arr
- (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
- mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
- => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a)
- -> MixedVecs s (sh1 ++ sh2) (Shaped sh a)
- -> ST s ()
- mvecsWritePartial sh idx arr vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = mvecsWritePartial sh idx
- (coerce @(Mixed sh2 (Shaped sh a))
- @(Mixed sh2 (Mixed (MapJust sh) a))
- arr)
- (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a))
- @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a))
- vecs)
-
- mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a))
- mvecsFreeze sh vecs
- | Dict <- lemKnownMapJust (Proxy @sh)
- = coerce @(Mixed sh' (Mixed (MapJust sh) a))
- @(Mixed sh' (Shaped sh a))
- <$> mvecsFreeze sh
- (coerce @(MixedVecs s sh' (Shaped sh a))
- @(MixedVecs s sh' (Mixed (MapJust sh) a))
- vecs)
-
-
--- Utility function to satisfy the type checker sometimes
-rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a
-rewriteMixed Refl x = x
-
-
--- ====== API OF RANKED ARRAYS ====== --
-
--- | An index into a rank-typed array.
-type IxR :: Nat -> Type
-data IxR n where
- IZR :: IxR Z
- (:::) :: Int -> IxR n -> IxR (S n)
-
-ixCvtXR :: IxX sh -> IxR (X.Rank sh)
-ixCvtXR IZX = IZR
-ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx
-ixCvtXR (n ::? idx) = n ::: ixCvtXR idx
-
-ixCvtRX :: IxR n -> IxX (Replicate n Nothing)
-ixCvtRX IZR = IZX
-ixCvtRX (n ::: idx) = n ::? ixCvtRX idx
-
-
-rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n
-rshape (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemRankReplicate (Proxy @n)
- = ixCvtXR (mshape arr)
-
-rindex :: GMixed a => Ranked n a -> IxR n -> a
-rindex (Ranked arr) idx = mindex arr (ixCvtRX idx)
-
-rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a
-rindexPartial (Ranked arr) idx =
- Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing)
- (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr)
- (ixCvtRX idx))
-
-rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a
-rgenerate sh f
- | Dict <- lemKnownReplicate (Proxy @n)
- , Refl <- lemRankReplicate (Proxy @n)
- = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
-
-rlift :: forall n1 n2 a. (KnownNat n2, GMixed a)
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
- -> Ranked n1 a -> Ranked n2 a
-rlift f (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n2)
- = Ranked (mlift f arr)
-
-rsumOuter1 :: forall n a.
- (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a))
- => Ranked (S n) a -> Ranked n a
-rsumOuter1 (Ranked arr)
- | Dict <- lemKnownReplicate (Proxy @n)
- = Ranked
- . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a)
- . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing))
- . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a)
- $ arr
-
-
--- ====== API OF SHAPED ARRAYS ====== --
-
--- | An index into a shape-typed array.
-type IxS :: [Nat] -> Type
-data IxS sh where
- IZS :: IxS '[]
- (::$) :: Int -> IxS sh -> IxS (n : sh)
-
-cvtSShapeIxS :: SShape sh -> IxS sh
-cvtSShapeIxS ShNil = IZS
-cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh
-
-ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh
-ixCvtXS ShNil IZX = IZS
-ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx
-
-ixCvtSX :: IxS sh -> IxX (MapJust sh)
-ixCvtSX IZS = IZX
-ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh
-
-
-sshape :: forall sh a. (KnownShape sh, GMixed a) => Shaped sh a -> IxS sh
-sshape _ = cvtSShapeIxS (knownShape @sh)
-
-sindex :: GMixed a => Shaped sh a -> IxS sh -> a
-sindex (Shaped arr) idx = mindex arr (ixCvtSX idx)
-
-sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, GMixed a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a
-sindexPartial (Shaped arr) idx =
- Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2)
- (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr)
- (ixCvtSX idx))
-
-sgenerate :: forall sh a. (KnownShape sh, GMixed a) => IxS sh -> (IxS sh -> a) -> Shaped sh a
-sgenerate sh f
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh)))
-
-slift :: forall sh1 sh2 a. (KnownShape sh2, GMixed a)
- => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b)
- -> Shaped sh1 a -> Shaped sh2 a
-slift f (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh2)
- = Shaped (mlift f arr)
-
-ssumOuter1 :: forall sh n a.
- (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a))
- => Shaped (n : sh) a -> Shaped sh a
-ssumOuter1 (Shaped arr)
- | Dict <- lemKnownMapJust (Proxy @sh)
- = Shaped
- . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a)
- . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh))
- . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a)
- $ arr