diff options
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | cabal.project | 2 | ||||
| -rw-r--r-- | ox-arrays.cabal | 19 | ||||
| -rw-r--r-- | src/Array.hs | 195 | ||||
| -rw-r--r-- | src/Fancy.hs | 237 | 
5 files changed, 454 insertions, 0 deletions
| diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c33954f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +dist-newstyle/ diff --git a/cabal.project b/cabal.project new file mode 100644 index 0000000..a13761a --- /dev/null +++ b/cabal.project @@ -0,0 +1,2 @@ +packages: . +with-compiler: ghc-9.6.4 diff --git a/ox-arrays.cabal b/ox-arrays.cabal new file mode 100644 index 0000000..aea0a94 --- /dev/null +++ b/ox-arrays.cabal @@ -0,0 +1,19 @@ +cabal-version:   3.0 +name:            ox-arrays +version:         0.1.0.0 +author:          Tom Smeding +license:         BSD-3-Clause +build-type:      Simple + +library +  exposed-modules: +    Array +    Fancy +  build-depends: +    base >=4.18, +    ghc-typelits-knownnat, +    orthotope, +    vector +  hs-source-dirs: src +  default-language: Haskell2010 +  ghc-options: -Wall diff --git a/src/Array.hs b/src/Array.hs new file mode 100644 index 0000000..25db19e --- /dev/null +++ b/src/Array.hs @@ -0,0 +1,195 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Array where + +import qualified Data.Array.RankedU as U +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import qualified Data.Vector.Unboxed as VU +import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) + + +data Dict c a where +  Dict :: c a => Dict c a + +type family l1 ++ l2 where +  '[] ++ l2 = l2 +  (x : xs) ++ l2 = x : xs ++ l2 + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerce Refl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerce Refl + + +type IxX :: [Maybe Nat] -> Type +data IxX sh where +  IZX :: IxX '[] +  (::@) :: Int -> IxX sh -> IxX (Just n : sh) +  (::?) :: Int -> IxX sh -> IxX (Nothing : sh) +deriving instance Show (IxX sh) + +type StaticShapeX :: [Maybe Nat] -> Type +data StaticShapeX sh where +  SZX :: StaticShapeX '[] +  (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) +  (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) + +type KnownShapeX :: [Maybe Nat] -> Constraint +class KnownShapeX sh where +  knownShapeX :: StaticShapeX sh +instance KnownShapeX '[] where +  knownShapeX = SZX +instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where +  knownShapeX = natSing @n :$@ knownShapeX +instance KnownShapeX sh => KnownShapeX (Nothing : sh) where +  knownShapeX = () :$? knownShapeX + +type family Rank sh where +  Rank '[] = 0 +  Rank (_ : sh) = 1 + Rank sh + +type XArray :: [Maybe Nat] -> Type -> Type +data XArray sh a = XArray (U.Array (Rank sh) a) + +zeroIdx :: StaticShapeX sh -> IxX sh +zeroIdx SZX = IZX +zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh +zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh + +zeroIdx' :: IxX sh -> IxX sh +zeroIdx' IZX = IZX +zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh +zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh + +ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh') +ixAppend IZX idx' = idx' +ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx' +ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx' + +ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh' +ixDrop sh IZX = sh +ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx +ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx + +ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') +ssxAppend SZX idx' = idx' +ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx' +ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx' + +shapeSize :: IxX sh -> Int +shapeSize IZX = 1 +shapeSize (n ::@ sh) = n * shapeSize sh +shapeSize (n ::? sh) = n * shapeSize sh + +fromLinearIdx :: IxX sh -> Int -> IxX sh +fromLinearIdx = \sh i -> case go sh i of +  (idx, 0) -> idx +  _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ +               " in array of shape " ++ show sh ++ ")" +  where +    -- returns (index in subarray, remaining index in enclosing array) +    go :: IxX sh -> Int -> (IxX sh, Int) +    go IZX i = (IZX, i) +    go (n ::@ sh) i = +      let (idx, i') = go sh i +          (upi, locali) = i' `quotRem` n +      in (locali ::@ idx, upi) +    go (n ::? sh) i = +      let (idx, i') = go sh i +          (upi, locali) = i' `quotRem` n +      in (locali ::? idx, upi) + +toLinearIdx :: IxX sh -> IxX sh -> Int +toLinearIdx = \sh i -> fst (go sh i) +  where +    -- returns (index in subarray, size of subarray) +    go :: IxX sh -> IxX sh -> (Int, Int) +    go IZX IZX = (0, 1) +    go (n ::@ sh) (i ::@ ix) = +      let (lidx, sz) = go sh ix +      in (sz * i + lidx, n * sz) +    go (n ::? sh) (i ::? ix) = +      let (lidx, sz) = go sh ix +      in (sz * i + lidx, n * sz) + +enumShape :: IxX sh -> [IxX sh] +enumShape = \sh -> go 0 sh id [] +  where +    go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a] +    go _ IZX _ = id +    go i (n ::@ sh) f +      | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@)) +      | otherwise = id +    go i (n ::? sh) f +      | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?)) +      | otherwise = id + +shapeLshape :: IxX sh -> U.ShapeL +shapeLshape IZX = [] +shapeLshape (n ::@ sh) = n : shapeLshape sh +shapeLshape (n ::? sh) = n : shapeLshape sh + +lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank IZX = Dict +lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh +lemKnownShapeX SZX = Dict +lemKnownShapeX (SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict +lemKnownShapeX (_ :$@ _) = error "SNat does not have a COMPLETE pragma" + +lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' +lemAppKnownShapeX (SNat :$@ ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict +lemAppKnownShapeX (() :$? ssh) ssh' | Dict <- lemAppKnownShapeX ssh ssh' = Dict +lemAppKnownShapeX (_ :$@ _) _ = error "SNat does not have a COMPLETE pragma" + +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh +shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) +  where +    go :: StaticShapeX sh' -> [Int] -> IxX sh' +    go SZX [] = IZX +    go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) ::@ go ssh l +    go (() :$? ssh) (n : l) = n ::? go ssh l +    go _ _ = error "Invalid shapeL" + +fromVector :: U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a  +fromVector sh v | Dict <- lemKnownNatRank sh = XArray (U.fromVector (shapeLshape sh) v) + +toVector :: U.Unbox a => XArray sh a -> VU.Vector a +toVector (XArray arr) = U.toVector arr + +generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) + +-- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +--   XArray . U.fromVector (shapeLshape sh) +--     <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh) + +indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial (XArray arr) IZX = XArray arr +indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx +indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx + +index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a +index xarr i +  | Refl <- lemAppNil @sh +  = let XArray arr' = indexPartial xarr i :: XArray '[] a +    in U.unScalar arr' diff --git a/src/Fancy.hs b/src/Fancy.hs new file mode 100644 index 0000000..821073e --- /dev/null +++ b/src/Fancy.hs @@ -0,0 +1,237 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE FlexibleInstances #-} +module Fancy where + +import Control.Monad (forM_) +import Control.Monad.ST +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import qualified Data.Vector.Unboxed as VU +import qualified Data.Vector.Unboxed.Mutable as VUM +import GHC.TypeLits + +import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) +import qualified Array as X + + +type family Replicate n a where +  Replicate 0 a = '[] +  Replicate n a = a : Replicate (n - 1) a + +type family MapJust l where +  MapJust '[] = '[] +  MapJust (x : xs) = Just x : MapJust xs + + +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a + +newtype instance Mixed sh Int = M_Int (XArray sh Int) +newtype instance Mixed sh Double = M_Double (XArray sh Double) +-- etc. + +newtype instance Mixed sh () = M_Nil (IxX sh)  -- store the shape +data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b) +data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c) +data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d) + +newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) + + +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int) +newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) +-- etc. + +data instance MixedVecs s sh () = MV_Nil +data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) +data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) +data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d) + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) + + +class GMixed a where +  mshape :: KnownShapeX sh => Mixed sh a -> IxX sh +  mindex :: Mixed sh a -> IxX sh -> a +  mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a + +  -- | Create an empty array. The given shape must have size zero; this may or may not be checked. +  memptyArray :: IxX sh -> Mixed sh a + +  -- | Return the size of the individual (SoA) arrays in this value. If @a@ +  -- does not contain tuples, this coincides with the total number of scalars +  -- in the given value; if @a@ contains tuples, then it is some multiple of +  -- this number of scalars. +  mvecsNumElts :: a -> Int + +  -- | Create uninitialised vectors for this array type, given the shape of +  -- this vector and an example for the contents. The shape must not have size +  -- zero; an error may be thrown otherwise. +  mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) + +  -- | Given the shape of this array, an index and a value, write the value at +  -- that index in the vectors. +  mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () + +  -- | Given the shape of this array, an index and a value, write the value at +  -- that index in the vectors. +  mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + +  -- | Given the shape of this array, finalise the vectors into 'XArray's. +  mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + +-- TODO: this use of toVector is suboptimal +mvecsWritePartialPrimitive +  :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a) +  => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s () +mvecsWritePartialPrimitive sh i arr v = do +  let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) +  VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) + +instance GMixed Int where +  mshape (M_Int a) = X.shape a +  mindex (M_Int a) i = X.index a i +  mindexPartial (M_Int a) i = M_Int (X.indexPartial a i) +  memptyArray sh = M_Int (X.generate sh (error "memptyArray Int: shape was not empty")) + +  mvecsNumElts _ = 1 +  mvecsUnsafeNew sh _ = MV_Int <$> VUM.unsafeNew (X.shapeSize sh) +  mvecsWrite sh i x (MV_Int v) = VUM.write v (X.toLinearIdx sh i) x +  mvecsWritePartial sh i (M_Int @sh' arr) (MV_Int v) = mvecsWritePartialPrimitive @sh' sh i arr v +  mvecsFreeze sh (MV_Int v) = M_Int . X.fromVector sh <$> VU.freeze v + +instance GMixed Double where +  mshape (M_Double a) = X.shape a +  mindex (M_Double a) i = X.index a i +  mindexPartial (M_Double a) i = M_Double (X.indexPartial a i) +  memptyArray sh = M_Double (X.generate sh (error "memptyArray Double: shape was not empty")) + +  mvecsNumElts _ = 1 +  mvecsUnsafeNew sh _ = MV_Double <$> VUM.unsafeNew (X.shapeSize sh) +  mvecsWrite sh i x (MV_Double v) = VUM.write v (X.toLinearIdx sh i) x +  mvecsWritePartial sh i (M_Double @sh' arr) (MV_Double v) = mvecsWritePartialPrimitive @sh' sh i arr v +  mvecsFreeze sh (MV_Double v) = M_Double . X.fromVector sh <$> VU.freeze v + +instance GMixed () where +  mshape (M_Nil sh) = sh +  mindex _ _ = () +  mindexPartial = \(M_Nil sh) i -> M_Nil (X.ixDrop sh i) +  memptyArray sh = M_Nil sh + +  mvecsNumElts _ = 1 +  mvecsUnsafeNew _ _ = return MV_Nil +  mvecsWrite _ _ _ _ = return () +  mvecsWritePartial _ _ _ _ = return () +  mvecsFreeze sh _ = return (M_Nil sh) + +instance (GMixed a, GMixed b) => GMixed (a, b) where +  mshape (M_Tup2 a _) = mshape a +  mindex (M_Tup2 a b) i = (mindex a i, mindex b i) +  mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) +  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) + +  mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y +  mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y +  mvecsWrite sh i (x, y) (MV_Tup2 a b) = do +    mvecsWrite sh i x a +    mvecsWrite sh i y b +  mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do +    mvecsWritePartial sh i x a +    mvecsWritePartial sh i y b +  mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where +  -- TODO: this is quadratic in the nesting level +  mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh +  mshape (M_Nest arr) +    | X.Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') +    = ixAppPrefix (knownShapeX @sh) (mshape arr) +    where +      ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 +      ixAppPrefix SZX _ = IZX +      ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx +      ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx + +  mindex (M_Nest arr) i = mindexPartial arr i + +  mindexPartial :: forall sh1 sh2. +                   Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) +  mindexPartial (M_Nest arr) i +    | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') +    = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + +  memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) + +  mvecsNumElts arr = +    let n = X.shapeSize (mshape arr) +    in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh'))) + +  mvecsUnsafeNew sh example +    | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" +    | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) +                                                 (mindex example (X.zeroIdx (knownShapeX @sh'))) +    where +      sh' = mshape example + +  mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs + +  mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 +                    => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) +                    -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) +                    -> ST s () +  mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) +    | X.Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) +    , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') +    = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs + +  mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs + +mgenerate :: GMixed a => IxX sh -> (IxX sh -> a) -> Mixed sh a +mgenerate sh f +  -- We need to be very careful here to ensure that neither 'sh' nor +  -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. +  | X.shapeSize sh == 0 = memptyArray sh +  | otherwise = +      let firstidx = X.zeroIdx' sh +          firstelem = f (X.zeroIdx' sh) +      in if mvecsNumElts firstelem == 0 +           then memptyArray sh +           else runST $ do +                  vecs <- mvecsUnsafeNew sh firstelem +                  mvecsWrite sh firstidx firstelem vecs +                  forM_ (tail (X.enumShape sh)) $ \idx -> +                    mvecsWrite sh idx (f idx) vecs +                  mvecsFreeze sh vecs + + +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) + +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) + + +type IxR :: Nat -> Type +data IxR n where +  IZR :: IxR 0 +  (:::) :: Int -> IxR n -> IxR (n + 1) + +type IxS :: [Nat] -> Type +data IxS sh where +  IZS :: IxS '[] +  (::$) :: Int -> IxS sh -> IxS (n : sh) + + | 
