diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2024-03-28 17:30:54 +0100 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-03-28 17:30:54 +0100 | 
| commit | 373799f36417c086510847e4d6689bd68582e658 (patch) | |
| tree | 72ac398279c2e8bc2162ffe77f6f054b4e0d2f77 | |
| parent | ae113c0249f3fe8be7df345081b1b51451cd3fdf (diff) | |
Various improvements
| -rw-r--r-- | src/Array.hs | 11 | ||||
| -rw-r--r-- | src/Fancy.hs | 113 | 
2 files changed, 83 insertions, 41 deletions
| diff --git a/src/Array.hs b/src/Array.hs index 29806d4..693df05 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -145,6 +145,11 @@ lemKnownNatRank IZX = Dict  lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict  lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX SZX = Dict +lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +  lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh  lemKnownShapeX SZX = Dict  lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict @@ -196,3 +201,9 @@ index xarr i    | Refl <- lemAppNil @sh    = let XArray arr' = indexPartial xarr i :: XArray '[] a      in U.unScalar arr' + +append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a +append (XArray a) (XArray b) +  | Dict <- lemKnownNatRankSSX (knownShapeX @sh) +  , Dict <- gknownNat (Proxy @(Rank sh)) +  = XArray (U.append a b) diff --git a/src/Fancy.hs b/src/Fancy.hs index 6b6d8d4..e8192aa 100644 --- a/src/Fancy.hs +++ b/src/Fancy.hs @@ -1,7 +1,9 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingVia #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE InstanceSigs #-}  {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE StandaloneKindSignatures #-} @@ -44,14 +46,19 @@ lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n))      go (SS n) = () :$? go n +-- Wrapper type used as a tag to attach instances on. +newtype Primitive a = Primitive a +  type Mixed :: [Maybe Nat] -> Type -> Type  data family Mixed sh a +newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) +  newtype instance Mixed sh Int = M_Int (XArray sh Int)  newtype instance Mixed sh Double = M_Double (XArray sh Double) +newtype instance Mixed sh () = M_Nil (XArray sh ())  -- no content, orthotope optimises this (via Vector)  -- etc. -newtype instance Mixed sh () = M_Nil (IxX sh)  -- store the shape  data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b)  data instance Mixed sh (a, b, c) = M_Tup3 (Mixed sh a) (Mixed sh b) (Mixed sh c)  data instance Mixed sh (a, b, c, d) = M_Tup4 (Mixed sh a) (Mixed sh b) (Mixed sh c) (Mixed sh d) @@ -62,11 +69,13 @@ newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a)  type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type  data family MixedVecs s sh a +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) +  newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int)  newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) +newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ())  -- no content, MVector optimises this  -- etc. -data instance MixedVecs s sh () = MV_Nil  data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b)  data instance MixedVecs s sh (a, b, c) = MV_Tup3 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c)  data instance MixedVecs s sh (a, b, c, d) = MV_Tup4 (MixedVecs s sh a) (MixedVecs s sh b) (MixedVecs s sh c) (MixedVecs s sh d) @@ -79,6 +88,10 @@ class GMixed a where    mindex :: Mixed sh a -> IxX sh -> a    mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a +  mlift :: forall sh1 sh2. KnownShapeX sh2 +        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        -> Mixed sh1 a -> Mixed sh2 a +    -- | Create an empty array. The given shape must have size zero; this may or may not be checked.    memptyArray :: IxX sh -> Mixed sh a @@ -104,56 +117,46 @@ class GMixed a where    -- | Given the shape of this array, finalise the vectors into 'XArray's.    mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) --- TODO: this use of toVector is suboptimal -mvecsWritePartialPrimitive -  :: forall sh' sh a s. (KnownShapeX sh', VU.Unbox a) -  => IxX (sh ++ sh') -> IxX sh -> XArray sh' a -> VU.MVector s a -> ST s () -mvecsWritePartialPrimitive sh i arr v = do -  let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) -  VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) -instance GMixed Int where -  mshape (M_Int a) = X.shape a -  mindex (M_Int a) i = X.index a i -  mindexPartial (M_Int a) i = M_Int (X.indexPartial a i) -  memptyArray sh = M_Int (X.generate sh (error "memptyArray Int: shape was not empty")) +instance VU.Unbox a => GMixed (Primitive a) where +  mshape (M_Primitive a) = X.shape a +  mindex (M_Primitive a) i = Primitive (X.index a i) +  mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) -  mvecsNumElts _ = 1 -  mvecsUnsafeNew sh _ = MV_Int <$> VUM.unsafeNew (X.shapeSize sh) -  mvecsWrite sh i x (MV_Int v) = VUM.write v (X.toLinearIdx sh i) x -  mvecsWritePartial sh i (M_Int @sh' arr) (MV_Int v) = mvecsWritePartialPrimitive @sh' sh i arr v -  mvecsFreeze sh (MV_Int v) = M_Int . X.fromVector sh <$> VU.freeze v - -instance GMixed Double where -  mshape (M_Double a) = X.shape a -  mindex (M_Double a) i = X.index a i -  mindexPartial (M_Double a) i = M_Double (X.indexPartial a i) -  memptyArray sh = M_Double (X.generate sh (error "memptyArray Double: shape was not empty")) +  mlift :: forall sh1 sh2. +           (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) +        -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) +  mlift f (M_Primitive a) +    | Refl <- X.lemAppNil @sh1 +    , Refl <- X.lemAppNil @sh2 +    = M_Primitive (f Proxy a) +  memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty"))    mvecsNumElts _ = 1 -  mvecsUnsafeNew sh _ = MV_Double <$> VUM.unsafeNew (X.shapeSize sh) -  mvecsWrite sh i x (MV_Double v) = VUM.write v (X.toLinearIdx sh i) x -  mvecsWritePartial sh i (M_Double @sh' arr) (MV_Double v) = mvecsWritePartialPrimitive @sh' sh i arr v -  mvecsFreeze sh (MV_Double v) = M_Double . X.fromVector sh <$> VU.freeze v +  mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh) +  mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x -instance GMixed () where -  mshape (M_Nil sh) = sh -  mindex _ _ = () -  mindexPartial = \(M_Nil sh) i -> M_Nil (X.ixDrop sh i) -  memptyArray sh = M_Nil sh +  -- TODO: this use of toVector is suboptimal +  mvecsWritePartial +    :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a) +    => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () +  mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do +    let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) +    VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) -  mvecsNumElts _ = 1 -  mvecsUnsafeNew _ _ = return MV_Nil -  mvecsWrite _ _ _ _ = return () -  mvecsWritePartial _ _ _ _ = return () -  mvecsFreeze sh _ = return (M_Nil sh) +  mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v + +deriving via Primitive Int instance GMixed Int +deriving via Primitive Double instance GMixed Double +deriving via Primitive () instance GMixed ()  instance (GMixed a, GMixed b) => GMixed (a, b) where    mshape (M_Tup2 a _) = mshape a    mindex (M_Tup2 a b) i = (mindex a i, mindex b i)    mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) -  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) +  mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) +  memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh)    mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y    mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y    mvecsWrite sh i (x, y) (MV_Tup2 a b) = do @@ -165,7 +168,6 @@ instance (GMixed a, GMixed b) => GMixed (a, b) where    mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b  instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where -  -- TODO: this is quadratic in the nesting level    mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh    mshape (M_Nest arr)      | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') @@ -184,6 +186,20 @@ instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where      | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh')      = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) +  mlift :: forall sh1 sh2. KnownShapeX sh2 +        => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) +        -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) +  mlift f (M_Nest arr) +    | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) +    = M_Nest (mlift f' arr) +    where +      f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b +      f' _ +        | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3) +        , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) +        , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3)) +        = f (Proxy @(sh' ++ sh3)) +    memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh'))))    mvecsNumElts arr = @@ -258,6 +274,14 @@ instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where      = coerce @(Mixed sh' (Mixed (Replicate n 'Nothing) a)) @(Mixed sh' (Ranked n a)) $          mindexPartial arr i +  mlift :: forall sh1 sh2. KnownShapeX sh2 +        => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) +        -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) +  mlift f (M_Ranked arr) +    | Dict <- lemKnownReplicate (Proxy @n) +    = coerce @(Mixed sh2 (Mixed (Replicate n 'Nothing) a)) @(Mixed sh2 (Ranked n a)) $ +        mlift f arr +    memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a)    memptyArray i      | Dict <- lemKnownReplicate (Proxy @n) @@ -375,3 +399,10 @@ rgenerate sh f    | Dict <- lemKnownReplicate (Proxy @n)    , Refl <- lemRankReplicate (Proxy @n)    = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) + +rlift :: forall n1 n2 a. (KnownNat n2, GMixed a) +      => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) +      -> Ranked n1 a -> Ranked n2 a +rlift f (Ranked arr) +  | Dict <- lemKnownReplicate (Proxy @n2) +  = Ranked (mlift f arr) | 
