summaryrefslogtreecommitdiff
path: root/src/Fancy.hs
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-03-28 17:30:54 +0100
committerTom Smeding <t.j.smeding@uu.nl>2024-03-28 17:30:54 +0100
commit373799f36417c086510847e4d6689bd68582e658 (patch)
tree72ac398279c2e8bc2162ffe77f6f054b4e0d2f77 /src/Fancy.hs
parentae113c0249f3fe8be7df345081b1b51451cd3fdf (diff)
Various improvements
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r--src/Fancy.hs113
1 files changed, 72 insertions, 41 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs
index 6b6d8d4..e8192aa 100644
--- a/src/Fancy.hs
+++ b/src/Fancy.hs
@@ -1,7 +1,9 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -44,14 +46,19 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))
go (SS n) = () :$? go n
+-- Wrapper type used as a tag to attach instances on.
+newtype Primitive a = Primitive a
+
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a
+newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a)
+
newtype instance Mixed sh Int = M_Int (XArray sh Int)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
+newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector)
-- etc.
-newtype instance Mixed sh () = M_Nil (IxX sh) -- store the shape
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)
data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d)
@@ -62,11 +69,13 @@ newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a
+newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a)
+
newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
+newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this
-- etc.
-data instance MixedVecs s sh () = MV_Nil
data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)
data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d)
@@ -79,6 +88,10 @@ class GMixed a where
mindex :: Mixed sh a -> IxX sh -> a
mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 a -> Mixed sh2 a
+
-- | Create an empty array. The given shape must have size zero; this may or may not be checked.
memptyArray :: IxX sh -> Mixed sh a
@@ -104,56 +117,46 @@ class GMixed a where
-- | Given the shape of this array, finalise the vectors into 'XArray's.
mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
--- TODO: this use of toVector is suboptimal
-mvecsWritePartialPrimitive
- :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a)
- => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s ()
-mvecsWritePartialPrimitive sh i arr v = do
- let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
- VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
-instance GMixed Int where
- mshape (M_Int a) = X.shape a
- mindex (M_Int a) i = X.index a i
- mindexPartial (M_Int a) i = M_Int (X.indexPartial a i)
- memptyArray sh = M_Int (X.generate sh (error "memptyArray Int: shape was not empty"))
+instance VU.Unbox a => GMixed (Primitive a) where
+ mshape (M_Primitive a) = X.shape a
+ mindex (M_Primitive a) i = Primitive (X.index a i)
+ mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i)
- mvecsNumElts _ = 1
- mvecsUnsafeNew sh _ = MV_Int <$> VUM.unsafeNew (X.shapeSize sh)
- mvecsWrite sh i x (MV_Int v) = VUM.write v (X.toLinearIdx sh i) x
- mvecsWritePartial sh i (M_Int @sh' arr) (MV_Int v) = mvecsWritePartialPrimitive @sh' sh i arr v
- mvecsFreeze sh (MV_Int v) = M_Int . X.fromVector sh <$> VU.freeze v
-
-instance GMixed Double where
- mshape (M_Double a) = X.shape a
- mindex (M_Double a) i = X.index a i
- mindexPartial (M_Double a) i = M_Double (X.indexPartial a i)
- memptyArray sh = M_Double (X.generate sh (error "memptyArray Double: shape was not empty"))
+ mlift :: forall sh1 sh2.
+ (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a)
+ -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a)
+ mlift f (M_Primitive a)
+ | Refl <- X.lemAppNil @sh1
+ , Refl <- X.lemAppNil @sh2
+ = M_Primitive (f Proxy a)
+ memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))
mvecsNumElts _ = 1
- mvecsUnsafeNew sh _ = MV_Double <$> VUM.unsafeNew (X.shapeSize sh)
- mvecsWrite sh i x (MV_Double v) = VUM.write v (X.toLinearIdx sh i) x
- mvecsWritePartial sh i (M_Double @sh' arr) (MV_Double v) = mvecsWritePartialPrimitive @sh' sh i arr v
- mvecsFreeze sh (MV_Double v) = M_Double . X.fromVector sh <$> VU.freeze v
+ mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh)
+ mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x
-instance GMixed () where
- mshape (M_Nil sh) = sh
- mindex _ _ = ()
- mindexPartial = \(M_Nil sh) i -> M_Nil (X.ixDrop sh i)
- memptyArray sh = M_Nil sh
+ -- TODO: this use of toVector is suboptimal
+ mvecsWritePartial
+ :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a)
+ => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s ()
+ mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do
+ let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
+ VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
- mvecsNumElts _ = 1
- mvecsUnsafeNew _ _ = return MV_Nil
- mvecsWrite _ _ _ _ = return ()
- mvecsWritePartial _ _ _ _ = return ()
- mvecsFreeze sh _ = return (M_Nil sh)
+ mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v
+
+deriving via Primitive Int instance GMixed Int
+deriving via Primitive Double instance GMixed Double
+deriving via Primitive () instance GMixed ()
instance (GMixed a, GMixed b) => GMixed (a, b) where
mshape (M_Tup2 a _) = mshape a
mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
- memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+ mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b)
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
@@ -165,7 +168,6 @@ instance (GMixed a, GMixed b) => GMixed (a, b) where
mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
- -- TODO: this is quadratic in the nesting level
mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
mshape (M_Nest arr)
| Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
@@ -184,6 +186,20 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
| Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
= M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b)
+ -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a)
+ mlift f (M_Nest arr)
+ | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ = M_Nest (mlift f' arr)
+ where
+ f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b
+ f' _
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3)
+ , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3)
+ , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3))
+ = f (Proxy @(sh' ++ sh3))
+
memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
mvecsNumElts arr =
@@ -258,6 +274,14 @@ instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
= coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a)) $
mindexPartial arr i
+ mlift :: forall sh1 sh2. KnownShapeX sh2
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b)
+ -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a)
+ mlift f (M_Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n)
+ = coerce @(Mixed sh2 (Mixed (Replicate n 'Nothing) a)) @(Mixed sh2 (Ranked n a)) $
+ mlift f arr
+
memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)
memptyArray i
| Dict <- lemKnownReplicate (Proxy @n)
@@ -375,3 +399,10 @@ rgenerate sh f
| Dict <- lemKnownReplicate (Proxy @n)
, Refl <- lemRankReplicate (Proxy @n)
= Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR))
+
+rlift :: forall n1 n2 a. (KnownNat n2, GMixed a)
+ => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b)
+ -> Ranked n1 a -> Ranked n2 a
+rlift f (Ranked arr)
+ | Dict <- lemKnownReplicate (Proxy @n2)
+ = Ranked (mlift f arr)