summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-03-26 23:55:18 +0100
committerTom Smeding <tom@tomsmeding.com>2024-03-26 23:55:18 +0100
commit4918bbe4c5b560917c3cb53619838ead1ea53b9e (patch)
tree0f702a20b1802065d701e677a8dd853881239394 /src
Initial
Diffstat (limited to 'src')
-rw-r--r--src/Array.hs195
-rw-r--r--src/Fancy.hs237
2 files changed, 432 insertions, 0 deletions
diff --git a/src/Array.hs b/src/Array.hs
new file mode 100644
index 0000000..25db19e
--- /dev/null
+++ b/src/Array.hs
@@ -0,0 +1,195 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Array where
+
+import qualified Data.Array.RankedU as U
+import Data.Kind
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Unboxed as VU
+import GHC.TypeLits
+import Unsafe.Coerce (unsafeCoerce)
+
+
+data Dict c a where
+ Dict :: c a => Dict c a
+
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerce Refl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerce Refl
+
+
+type IxX :: [Maybe Nat] -> Type
+data IxX sh where
+ IZX :: IxX '[]
+ (::@) :: Int -> IxX sh -> IxX (Just n : sh)
+ (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
+deriving instance Show (IxX sh)
+
+type StaticShapeX :: [Maybe Nat] -> Type
+data StaticShapeX sh where
+ SZX :: StaticShapeX '[]
+ (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
+ (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
+
+type KnownShapeX :: [Maybe Nat] -> Constraint
+class KnownShapeX sh where
+ knownShapeX :: StaticShapeX sh
+instance KnownShapeX '[] where
+ knownShapeX = SZX
+instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
+ knownShapeX = natSing @n :$@ knownShapeX
+instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
+ knownShapeX = () :$? knownShapeX
+
+type family Rank sh where
+ Rank '[] = 0
+ Rank (_ : sh) = 1 + Rank sh
+
+type XArray :: [Maybe Nat] -> Type -> Type
+data XArray sh a = XArray (U.Array (Rank sh) a)
+
+zeroIdx :: StaticShapeX sh -> IxX sh
+zeroIdx SZX = IZX
+zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh
+zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh
+
+zeroIdx' :: IxX sh -> IxX sh
+zeroIdx' IZX = IZX
+zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh
+zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh
+
+ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh')
+ixAppend IZX idx' = idx'
+ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx'
+ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx'
+
+ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh'
+ixDrop sh IZX = sh
+ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx
+ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx
+
+ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
+ssxAppend SZX idx' = idx'
+ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx'
+ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx'
+
+shapeSize :: IxX sh -> Int
+shapeSize IZX = 1
+shapeSize (n ::@ sh) = n * shapeSize sh
+shapeSize (n ::? sh) = n * shapeSize sh
+
+fromLinearIdx :: IxX sh -> Int -> IxX sh
+fromLinearIdx = \sh i -> case go sh i of
+ (idx, 0) -> idx
+ _ -> error $ "fromLinearIdx: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
+ where
+ -- returns (index in subarray, remaining index in enclosing array)
+ go :: IxX sh -> Int -> (IxX sh, Int)
+ go IZX i = (IZX, i)
+ go (n ::@ sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` n
+ in (locali ::@ idx, upi)
+ go (n ::? sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` n
+ in (locali ::? idx, upi)
+
+toLinearIdx :: IxX sh -> IxX sh -> Int
+toLinearIdx = \sh i -> fst (go sh i)
+ where
+ -- returns (index in subarray, size of subarray)
+ go :: IxX sh -> IxX sh -> (Int, Int)
+ go IZX IZX = (0, 1)
+ go (n ::@ sh) (i ::@ ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, n * sz)
+ go (n ::? sh) (i ::? ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, n * sz)
+
+enumShape :: IxX sh -> [IxX sh]
+enumShape = \sh -> go 0 sh id []
+ where
+ go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a]
+ go _ IZX _ = id
+ go i (n ::@ sh) f
+ | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@))
+ | otherwise = id
+ go i (n ::? sh) f
+ | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?))
+ | otherwise = id
+
+shapeLshape :: IxX sh -> U.ShapeL
+shapeLshape IZX = []
+shapeLshape (n ::@ sh) = n : shapeLshape sh
+shapeLshape (n ::? sh) = n : shapeLshape sh
+
+lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank IZX = Dict
+lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict
+
+lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
+lemKnownShapeX SZX = Dict
+lemKnownShapeX (SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
+lemKnownShapeX (_ :$@ _) = error "SNat does not have a COMPLETE pragma"
+
+lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
+lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
+lemAppKnownShapeX (SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict
+lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict
+lemAppKnownShapeX (_ :$@ _) _ = error "SNat does not have a COMPLETE pragma"
+
+shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
+shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
+ where
+ go :: StaticShapeX sh' -> [Int] -> IxX sh'
+ go SZX [] = IZX
+ go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l
+ go (() :$? ssh) (n : l) = n ::? go ssh l
+ go _ _ = error "Invalid shapeL"
+
+fromVector :: U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
+fromVector sh v | Dict <- lemKnownNatRank sh = XArray (U.fromVector (shapeLshape sh) v)
+
+toVector :: U.Unbox a => XArray sh a -> VU.Vector a
+toVector (XArray arr) = U.toVector arr
+
+generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh)
+
+-- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM sh f | Dict <- lemKnownNatRank sh =
+-- XArray . U.fromVector (shapeLshape sh)
+-- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh)
+
+indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial (XArray arr) IZX = XArray arr
+indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx
+indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx
+
+index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a
+index xarr i
+ | Refl <- lemAppNil @sh
+ = let XArray arr' = indexPartial xarr i :: XArray '[] a
+ in U.unScalar arr'
diff --git a/src/Fancy.hs b/src/Fancy.hs
new file mode 100644
index 0000000..821073e
--- /dev/null
+++ b/src/Fancy.hs
@@ -0,0 +1,237 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE FlexibleInstances #-}
+module Fancy where
+
+import Control.Monad (forM_)
+import Control.Monad.ST
+import Data.Kind
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Unboxed as VU
+import qualified Data.Vector.Unboxed.Mutable as VUM
+import GHC.TypeLits
+
+import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++))
+import qualified Array as X
+
+
+type family Replicate n a where
+ Replicate 0 a = '[]
+ Replicate n a = a : Replicate (n - 1) a
+
+type family MapJust l where
+ MapJust '[] = '[]
+ MapJust (x : xs) = Just x : MapJust xs
+
+
+type Mixed :: [Maybe Nat] -> Type -> Type
+data family Mixed sh a
+
+newtype instance Mixed sh Int = M_Int (XArray sh Int)
+newtype instance Mixed sh Double = M_Double (XArray sh Double)
+-- etc.
+
+newtype instance Mixed sh () = M_Nil (IxX sh) -- store the shape
+data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)
+data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)
+data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d)
+
+newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)
+
+
+type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type
+data family MixedVecs s sh a
+
+newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)
+newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double)
+-- etc.
+
+data instance MixedVecs s sh () = MV_Nil
+data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)
+data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)
+data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d)
+
+data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a)
+
+
+class GMixed a where
+ mshape :: KnownShapeX sh => Mixed sh a -> IxX sh
+ mindex :: Mixed sh a -> IxX sh -> a
+ mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a
+
+ -- | Create an empty array. The given shape must have size zero; this may or may not be checked.
+ memptyArray :: IxX sh -> Mixed sh a
+
+ -- | Return the size of the individual (SoA) arrays in this value. If @a@
+ -- does not contain tuples, this coincides with the total number of scalars
+ -- in the given value; if @a@ contains tuples, then it is some multiple of
+ -- this number of scalars.
+ mvecsNumElts :: a -> Int
+
+ -- | Create uninitialised vectors for this array type, given the shape of
+ -- this vector and an example for the contents. The shape must not have size
+ -- zero; an error may be thrown otherwise.
+ mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a)
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s ()
+
+ -- | Given the shape of this array, an index and a value, write the value at
+ -- that index in the vectors.
+ mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s ()
+
+ -- | Given the shape of this array, finalise the vectors into 'XArray's.
+ mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a)
+
+-- TODO: this use of toVector is suboptimal
+mvecsWritePartialPrimitive
+ :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a)
+ => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s ()
+mvecsWritePartialPrimitive sh i arr v = do
+ let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr)))
+ VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr)
+
+instance GMixed Int where
+ mshape (M_Int a) = X.shape a
+ mindex (M_Int a) i = X.index a i
+ mindexPartial (M_Int a) i = M_Int (X.indexPartial a i)
+ memptyArray sh = M_Int (X.generate sh (error "memptyArray Int: shape was not empty"))
+
+ mvecsNumElts _ = 1
+ mvecsUnsafeNew sh _ = MV_Int <$> VUM.unsafeNew (X.shapeSize sh)
+ mvecsWrite sh i x (MV_Int v) = VUM.write v (X.toLinearIdx sh i) x
+ mvecsWritePartial sh i (M_Int @sh' arr) (MV_Int v) = mvecsWritePartialPrimitive @sh' sh i arr v
+ mvecsFreeze sh (MV_Int v) = M_Int . X.fromVector sh <$> VU.freeze v
+
+instance GMixed Double where
+ mshape (M_Double a) = X.shape a
+ mindex (M_Double a) i = X.index a i
+ mindexPartial (M_Double a) i = M_Double (X.indexPartial a i)
+ memptyArray sh = M_Double (X.generate sh (error "memptyArray Double: shape was not empty"))
+
+ mvecsNumElts _ = 1
+ mvecsUnsafeNew sh _ = MV_Double <$> VUM.unsafeNew (X.shapeSize sh)
+ mvecsWrite sh i x (MV_Double v) = VUM.write v (X.toLinearIdx sh i) x
+ mvecsWritePartial sh i (M_Double @sh' arr) (MV_Double v) = mvecsWritePartialPrimitive @sh' sh i arr v
+ mvecsFreeze sh (MV_Double v) = M_Double . X.fromVector sh <$> VU.freeze v
+
+instance GMixed () where
+ mshape (M_Nil sh) = sh
+ mindex _ _ = ()
+ mindexPartial = \(M_Nil sh) i -> M_Nil (X.ixDrop sh i)
+ memptyArray sh = M_Nil sh
+
+ mvecsNumElts _ = 1
+ mvecsUnsafeNew _ _ = return MV_Nil
+ mvecsWrite _ _ _ _ = return ()
+ mvecsWritePartial _ _ _ _ = return ()
+ mvecsFreeze sh _ = return (M_Nil sh)
+
+instance (GMixed a, GMixed b) => GMixed (a, b) where
+ mshape (M_Tup2 a _) = mshape a
+ mindex (M_Tup2 a b) i = (mindex a i, mindex b i)
+ mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i)
+ memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)
+
+ mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y
+ mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y
+ mvecsWrite sh i (x, y) (MV_Tup2 a b) = do
+ mvecsWrite sh i x a
+ mvecsWrite sh i y b
+ mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do
+ mvecsWritePartial sh i x a
+ mvecsWritePartial sh i y b
+ mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b
+
+instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where
+ -- TODO: this is quadratic in the nesting level
+ mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh
+ mshape (M_Nest arr)
+ | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh')
+ = ixAppPrefix (knownShapeX @sh) (mshape arr)
+ where
+ ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1
+ ixAppPrefix SZX _ = IZX
+ ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx
+ ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx
+
+ mindex (M_Nest arr) i = mindexPartial arr i
+
+ mindexPartial :: forall sh1 sh2.
+ Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ mindexPartial (M_Nest arr) i
+ | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i)
+
+ memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))
+
+ mvecsNumElts arr =
+ let n = X.shapeSize (mshape arr)
+ in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh')))
+
+ mvecsUnsafeNew sh example
+ | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example"
+ | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example))
+ (mindex example (X.zeroIdx (knownShapeX @sh')))
+ where
+ sh' = mshape example
+
+ mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs
+
+ mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2
+ => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a)
+ -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a)
+ -> ST s ()
+ mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs)
+ | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh'))
+ , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')
+ = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs
+
+ mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs
+
+mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a
+mgenerate sh f
+ -- We need to be very careful here to ensure that neither 'sh' nor
+ -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty.
+ | X.shapeSize sh == 0 = memptyArray sh
+ | otherwise =
+ let firstidx = X.zeroIdx' sh
+ firstelem = f (X.zeroIdx' sh)
+ in if mvecsNumElts firstelem == 0
+ then memptyArray sh
+ else runST $ do
+ vecs <- mvecsUnsafeNew sh firstelem
+ mvecsWrite sh firstidx firstelem vecs
+ forM_ (tail (X.enumShape sh)) $ \idx ->
+ mvecsWrite sh idx (f idx) vecs
+ mvecsFreeze sh vecs
+
+
+type Ranked :: Nat -> Type -> Type
+newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a)
+
+type Shaped :: [Nat] -> Type -> Type
+newtype Shaped sh a = Shaped (Mixed (MapJust sh) a)
+
+
+type IxR :: Nat -> Type
+data IxR n where
+ IZR :: IxR 0
+ (:::) :: Int -> IxR n -> IxR (n + 1)
+
+type IxS :: [Nat] -> Type
+data IxS sh where
+ IZS :: IxS '[]
+ (::$) :: Int -> IxS sh -> IxS (n : sh)
+
+