diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-03-26 23:55:18 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-03-26 23:55:18 +0100 |
commit | 4918bbe4c5b560917c3cb53619838ead1ea53b9e (patch) | |
tree | 0f702a20b1802065d701e677a8dd853881239394 /src/Fancy.hs |
Initial
Diffstat (limited to 'src/Fancy.hs')
-rw-r--r-- | src/Fancy.hs | 237 |
1 files changed, 237 insertions, 0 deletions
diff --git a/src/Fancy.hs b/src/Fancy.hs new file mode 100644 index 0000000..821073e --- /dev/null +++ b/src/Fancy.hs @@ -0,0 +1,237 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE FlexibleInstances #-} +module Fancy where + +import Control.Monad (forM_) +import Control.Monad.ST +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import qualified Data.Vector.Unboxed as VU +import qualified Data.Vector.Unboxed.Mutable as VUM +import GHC.TypeLits + +import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) +import qualified Array as X + + +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + +type family MapJust l where + MapJust '[] = '[] + MapJust (x : xs) = Just x : MapJust xs + + +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a + +newtype instance Mixed sh Int = M_Int (XArray sh Int) +newtype instance Mixed sh Double = M_Double (XArray sh Double) +-- etc. + +newtype instance Mixed sh () = M_Nil (IxX sh) -- store the shape +data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b) +data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c) +data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d) + +newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) + + +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int) +newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) +-- etc. + +data instance MixedVecs s sh () = MV_Nil +data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) +data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) +data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d) + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) + + +class GMixed a where + mshape :: KnownShapeX sh => Mixed sh a -> IxX sh + mindex :: Mixed sh a -> IxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a + + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArray :: IxX sh -> Mixed sh a + + -- | Return the size of the individual (SoA) arrays in this value. If @a@ + -- does not contain tuples, this coincides with the total number of scalars + -- in the given value; if @a@ contains tuples, then it is some multiple of + -- this number of scalars. + mvecsNumElts :: a -> Int + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. The shape must not have size + -- zero; an error may be thrown otherwise. + mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + +-- TODO: this use of toVector is suboptimal +mvecsWritePartialPrimitive + :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a) + => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s () +mvecsWritePartialPrimitive sh i arr v = do + let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) + VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) + +instance GMixed Int where + mshape (M_Int a) = X.shape a + mindex (M_Int a) i = X.index a i + mindexPartial (M_Int a) i = M_Int (X.indexPartial a i) + memptyArray sh = M_Int (X.generate sh (error "memptyArray Int: shape was not empty")) + + mvecsNumElts _ = 1 + mvecsUnsafeNew sh _ = MV_Int <$> VUM.unsafeNew (X.shapeSize sh) + mvecsWrite sh i x (MV_Int v) = VUM.write v (X.toLinearIdx sh i) x + mvecsWritePartial sh i (M_Int @sh' arr) (MV_Int v) = mvecsWritePartialPrimitive @sh' sh i arr v + mvecsFreeze sh (MV_Int v) = M_Int . X.fromVector sh <$> VU.freeze v + +instance GMixed Double where + mshape (M_Double a) = X.shape a + mindex (M_Double a) i = X.index a i + mindexPartial (M_Double a) i = M_Double (X.indexPartial a i) + memptyArray sh = M_Double (X.generate sh (error "memptyArray Double: shape was not empty")) + + mvecsNumElts _ = 1 + mvecsUnsafeNew sh _ = MV_Double <$> VUM.unsafeNew (X.shapeSize sh) + mvecsWrite sh i x (MV_Double v) = VUM.write v (X.toLinearIdx sh i) x + mvecsWritePartial sh i (M_Double @sh' arr) (MV_Double v) = mvecsWritePartialPrimitive @sh' sh i arr v + mvecsFreeze sh (MV_Double v) = M_Double . X.fromVector sh <$> VU.freeze v + +instance GMixed () where + mshape (M_Nil sh) = sh + mindex _ _ = () + mindexPartial = \(M_Nil sh) i -> M_Nil (X.ixDrop sh i) + memptyArray sh = M_Nil sh + + mvecsNumElts _ = 1 + mvecsUnsafeNew _ _ = return MV_Nil + mvecsWrite _ _ _ _ = return () + mvecsWritePartial _ _ _ _ = return () + mvecsFreeze sh _ = return (M_Nil sh) + +instance (GMixed a, GMixed b) => GMixed (a, b) where + mshape (M_Tup2 a _) = mshape a + mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) + + mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsWrite sh i (x, y) (MV_Tup2 a b) = do + mvecsWrite sh i x a + mvecsWrite sh i y b + mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartial sh i x a + mvecsWritePartial sh i y b + mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where + -- TODO: this is quadratic in the nesting level + mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh + mshape (M_Nest arr) + | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') + = ixAppPrefix (knownShapeX @sh) (mshape arr) + where + ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 + ixAppPrefix SZX _ = IZX + ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx + ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx + + mindex (M_Nest arr) i = mindexPartial arr i + + mindexPartial :: forall sh1 sh2. + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) + mindexPartial (M_Nest arr) i + | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + + memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) + + mvecsNumElts arr = + let n = X.shapeSize (mshape arr) + in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh'))) + + mvecsUnsafeNew sh example + | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) + (mindex example (X.zeroIdx (knownShapeX @sh'))) + where + sh' = mshape example + + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs + + mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 + => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) + | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) + , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs + + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs + +mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a +mgenerate sh f + -- We need to be very careful here to ensure that neither 'sh' nor + -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. + | X.shapeSize sh == 0 = memptyArray sh + | otherwise = + let firstidx = X.zeroIdx' sh + firstelem = f (X.zeroIdx' sh) + in if mvecsNumElts firstelem == 0 + then memptyArray sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + forM_ (tail (X.enumShape sh)) $ \idx -> + mvecsWrite sh idx (f idx) vecs + mvecsFreeze sh vecs + + +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) + +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) + + +type IxR :: Nat -> Type +data IxR n where + IZR :: IxR 0 + (:::) :: Int -> IxR n -> IxR (n + 1) + +type IxS :: [Nat] -> Type +data IxS sh where + IZS :: IxS '[] + (::$) :: Int -> IxS sh -> IxS (n : sh) + + |