{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

-- | Arrays whose shape is a type-level list of @Maybe Nat@ dimensions
-- ('Mixed'), layered on top of 'XArray' from the "Array" module.
module Fancy where

import Control.Monad (forM_)
import Control.Monad.ST
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import Data.Type.Ord
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
import qualified Array as X


-- | @Replicate n a@ is the type-level list consisting of @n@ copies of @a@.
type family Replicate n a where
  Replicate 0 a = '[]
  Replicate n a = a : Replicate (n - 1) a

-- | Wrap every element of a type-level list in 'Just'.
type family MapJust l where
  MapJust '[] = '[]
  MapJust (x : xs) = Just x : MapJust xs

-- | Absurdity eliminator. No natural @n@ satisfies both @0 < n@ and @1 > n@,
-- so a call to this function is unreachable. It is used to discharge the
-- impossible branch in 'lemKnownReplicate'.
lemCompareFalse1 :: (0 < n, 1 > n) => Proxy n -> a
lemCompareFalse1 = error "Incoherence"

-- | Produce a 'KnownShapeX' dictionary for the shape consisting of @n@
-- 'Nothing' dimensions. It builds the corresponding 'StaticShapeX' one
-- dimension at a time and hands it to 'X.lemKnownShapeX'.
lemKnownReplicate :: forall n. KnownNat n => Proxy n -> X.Dict KnownShapeX (Replicate n Nothing)
lemKnownReplicate _ = X.lemKnownShapeX (go (natSing @n))
  where
    -- Recurse on m by comparing it against 1 (and then 0) with 'cmpNat'.
    go :: forall m. SNat m -> StaticShapeX (Replicate m Nothing)
    go SNat = case cmpNat (Proxy @1) (Proxy @m) of
      -- m > 1. GHC cannot select the second 'Replicate' equation for an
      -- abstract m, because it cannot see that m is apart from 0. The
      -- unfolding step is therefore asserted with unsafeCoerce. It is sound
      -- because m > 1 > 0 in this branch.
      LTI | Refl <- (unsafeCoerce Refl :: Nothing : Replicate (m - 1) Nothing :~: Replicate m Nothing) ->
        () :$? go (SNat @(m - 1))
      -- m == 1: a single dynamic dimension.
      EQI -> () :$? SZX
      -- m < 1, so m must be 0.
      GTI -> case cmpNat (Proxy @0) (Proxy @m) of
        -- 0 < m < 1 is impossible.
        LTI -> lemCompareFalse1 (Proxy @m)
        EQI -> SZX
        -- Impossible: naturals are never below 0.
        GTI -> error "0 > natural"
    go _ = error "COMPLETE"


-- | An array with shape @sh@ and elements of type @a@. The representation
-- depends on the element type. Primitive element types are stored directly
-- as an 'XArray'.
type Mixed :: [Maybe Nat] -> Type -> Type
data family Mixed sh a

newtype instance Mixed sh Int = M_Int (XArray sh Int)
newtype instance Mixed sh Double = M_Double (XArray sh Double)
-- etc.
-- | Arrays of units carry no payload; only the shape is stored.
newtype instance Mixed sh () = M_Nil (IxX sh)  -- store the shape

-- | Arrays of tuples are stored as a tuple of arrays, one per component.
data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)
data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d)

-- | A nested array is flattened into one array whose shape is the outer
-- shape followed by the inner shape.
newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)


-- | Mutable backing vectors for a 'Mixed' array that is under construction
-- in 'ST'. They are mirrored per element type like 'Mixed' itself.
type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
data family MixedVecs s sh a

newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
-- etc.
data instance MixedVecs s sh () = MV_Nil
data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)
data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d)
-- | The 'IxX' field records the inner shape, taken from the example element
-- passed to 'mvecsUnsafeNew'.
data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)


-- | Element types that can be stored in a 'Mixed' array.
--
-- NOTE(review): only the 2-tuple has an instance here; the 3- and 4-tuple
-- representations above have no 'GMixed' instance yet.
class GMixed a where
  -- | The shape of an array.
  mshape :: KnownShapeX sh => Mixed sh a -> IxX sh

  -- | Look up the element at a full index.
  mindex :: Mixed sh a -> IxX sh -> a

  -- | Index with a prefix @sh@ of the shape, returning the sub-array of
  -- shape @sh'@.
  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a

  -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
  memptyArray :: IxX sh -> Mixed sh a

  -- | Return the size of the individual (SoA) arrays in this value. If @a@
  -- does not contain tuples, this coincides with the total number of scalars
  -- in the given value; if @a@ contains tuples, then it is some multiple of
  -- this number of scalars.
  --
  -- 'mgenerate' relies only on whether the result is zero.
  mvecsNumElts :: a -> Int

  -- | Create uninitialised vectors for this array type, given the shape of
  -- this vector and an example for the contents. The shape must not have size
  -- zero; an error may be thrown otherwise.
  mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)

  -- | Given the shape of this array, an index and a value, write the value at
  -- that index in the vectors.
  mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()

  -- | Given the full shape of this array (@sh ++ sh'@), an index into its
  -- @sh@ prefix and a sub-array of shape @sh'@, write the whole sub-array at
  -- that position in the vectors.
  mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()

  -- | Given the shape of this array, finalise the vectors into 'XArray's.
  mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)

-- | Copy a whole primitive sub-array into a flat vector.
--
-- The copy starts at the linear index of @i ++ 0...0@ and covers
-- @shapeSize (shape arr)@ elements. This assumes that 'X.toLinearIdx' lays
-- the trailing @sh'@ dimensions out contiguously (presumably row-major).
--
-- TODO: this use of toVector is suboptimal
mvecsWritePartialPrimitive
  :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a)
  => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s ()
mvecsWritePartialPrimitive sh i arr v = do
  let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
  VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)

instance GMixed Int where
  mshape (M_Int a) = X.shape a
  mindex (M_Int a) i = X.index a i
  mindexPartial (M_Int a) i = M_Int (X.indexPartial a i)
  memptyArray sh = M_Int (X.generate sh (error "memptyArray Int: shape was not empty"))
  mvecsNumElts _ = 1
  mvecsUnsafeNew sh _ = MV_Int <$> VUM.unsafeNew (X.shapeSize sh)
  mvecsWrite sh i x (MV_Int v) = VUM.write v (X.toLinearIdx sh i) x
  mvecsWritePartial sh i (M_Int @sh' arr) (MV_Int v) = mvecsWritePartialPrimitive @sh' sh i arr v
  mvecsFreeze sh (MV_Int v) = M_Int . X.fromVector sh <$> VU.freeze v

instance GMixed Double where
  mshape (M_Double a) = X.shape a
  mindex (M_Double a) i = X.index a i
  mindexPartial (M_Double a) i = M_Double (X.indexPartial a i)
  memptyArray sh = M_Double (X.generate sh (error "memptyArray Double: shape was not empty"))
  mvecsNumElts _ = 1
  mvecsUnsafeNew sh _ = MV_Double <$> VUM.unsafeNew (X.shapeSize sh)
  mvecsWrite sh i x (MV_Double v) = VUM.write v (X.toLinearIdx sh i) x
  mvecsWritePartial sh i (M_Double @sh' arr) (MV_Double v) = mvecsWritePartialPrimitive @sh' sh i arr v
  mvecsFreeze sh (MV_Double v) = M_Double . X.fromVector sh <$> VU.freeze v

-- | Unit arrays only track their shape; writes are no-ops.
instance GMixed () where
  mshape (M_Nil sh) = sh
  mindex _ _ = ()
  mindexPartial = \(M_Nil sh) i -> M_Nil (X.ixDrop sh i)
  memptyArray sh = M_Nil sh
  mvecsNumElts _ = 1
  mvecsUnsafeNew _ _ = return MV_Nil
  mvecsWrite _ _ _ _ = return ()
  mvecsWritePartial _ _ _ _ = return ()
  mvecsFreeze sh _ = return (M_Nil sh)

-- | Pairs are handled component-wise.
instance (GMixed a, GMixed b) => GMixed (a, b) where
  mshape (M_Tup2 a _) = mshape a
  mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
  -- A product rather than a sum. The result is zero as soon as either
  -- component is empty. 'mgenerate' relies on this to avoid calling
  -- 'mvecsUnsafeNew' on an empty nested component, which would error.
  mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
    mvecsWrite sh i x a
    mvecsWrite sh i y b
  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
    mvecsWritePartial sh i x a
    mvecsWritePartial sh i y b
  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b

-- | Nested arrays delegate to the flattened inner array of shape @sh ++ sh'@.
instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
  -- TODO: this is quadratic in the nesting level
  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
  mshape (M_Nest arr)
    | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
    = ixAppPrefix (knownShapeX @sh) (mshape arr)
    where
      -- Keep only the outer (@sh1@) part of the flattened shape.
      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
      ixAppPrefix SZX _ = IZX
      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx

  mindex (M_Nest arr) i = mindexPartial arr i

  mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
  mindexPartial (M_Nest arr) i
    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)

  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))

  mvecsNumElts arr =
    let n = X.shapeSize (mshape arr)
    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))

  mvecsUnsafeNew sh example
    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
                                                 (mindex example (X.zeroIdx (knownShapeX @sh')))
    where
      sh' = mshape example

  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs

  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
                    -> ST s ()
  mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
    | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
    , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs

  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs

-- | Build an array of the given shape by evaluating the function at every
-- index.
--
-- NOTE(review): for nested element types the vectors are sized from the
-- first element's inner shape. Nothing checks that later elements have the
-- same inner shape.
mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a
mgenerate sh f
  -- We need to be very careful here to ensure that neither 'sh' nor
  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
  | X.shapeSize sh == 0 = memptyArray sh
  | otherwise =
      let firstidx = X.zeroIdx' sh
          firstelem = f firstidx
      in if mvecsNumElts firstelem == 0
           then memptyArray sh
           else runST $ do
             vecs <- mvecsUnsafeNew sh firstelem
             mvecsWrite sh firstidx firstelem vecs
             -- The first index has already been written; 'drop 1' (rather
             -- than the partial 'tail') skips it.
             forM_ (drop 1 (X.enumShape sh)) $ \idx ->
               mvecsWrite sh idx (f idx) vecs
             mvecsFreeze sh vecs


-- | Arrays of rank @n@ whose dimensions are all 'Nothing'.
type Ranked :: Nat -> Type -> Type
newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)

-- | Arrays whose dimensions are all 'Just' the given naturals.
type Shaped :: [Nat] -> Type -> Type
newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)

newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a))
newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a))

newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a))
newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a))

-- | NOTE(review): incomplete. Only 'mshape' is defined, so the other methods
-- fail with a missing-method error at runtime. There is also no instance
-- for 'Shaped' yet.
instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where
  mshape (M_Ranked arr) | X.Dict <- lemKnownReplicate (Proxy @n) = mshape arr


-- | Index into a rank-@n@ array.
type IxR :: Nat -> Type
data IxR n where
  IZR :: IxR 0
  (:::) :: Int -> IxR n -> IxR (n + 1)

-- Right-associative, so indices can be written @i ::: j ::: IZR@ without
-- parentheses. The default fixity (infixl 9) would parse that as
-- @(i ::: j) ::: IZR@, which is ill-typed.
infixr 3 :::

-- | Index into an array of static shape @sh@.
type IxS :: [Nat] -> Type
data IxS sh where
  IZS :: IxS '[]
  (::$) :: Int -> IxS sh -> IxS (n : sh)

-- Right-associative for the same reason as ':::'.
infixr 3 ::$