diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-03 12:37:35 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-04-03 12:37:35 +0200 |
commit | 92902c4f66db111b439f3b7eba9de50ad7c73f7b (patch) | |
tree | 27f12853825b7dd13d4bc8040dd2be6781deb635 /src/Data/Array/Mixed.hs | |
parent | 264c8e601f49cebed9280f0da2e73f380bb5be52 (diff) |
Reorganise, documentation
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r-- | src/Data/Array/Mixed.hs | 314 |
1 files changed, 314 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs new file mode 100644 index 0000000..e1e2d5a --- /dev/null +++ b/src/Data/Array/Mixed.hs @@ -0,0 +1,314 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed where + +import qualified Data.Array.RankedU as U +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import qualified Data.Vector.Unboxed as VU +import qualified GHC.TypeLits as GHC +import Unsafe.Coerce (unsafeCoerce) + +import Data.Nat + + +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerce Refl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerce Refl + + +type IxX :: [Maybe Nat] -> Type +data IxX sh where + IZX :: IxX '[] + (::@) :: Int -> IxX sh -> IxX (Just n : sh) + (::?) :: Int -> IxX sh -> IxX (Nothing : sh) +deriving instance Show (IxX sh) + +-- | The part of a shape that is statically known. +type StaticShapeX :: [Maybe Nat] -> Type +data StaticShapeX sh where + SZX :: StaticShapeX '[] + (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) + (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) +deriving instance Show (StaticShapeX sh) + +-- | Evidence for the static part of a shape. +type KnownShapeX :: [Maybe Nat] -> Constraint +class KnownShapeX sh where + knownShapeX :: StaticShapeX sh +instance KnownShapeX '[] where + knownShapeX = SZX +instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where + knownShapeX = knownNat :$@ knownShapeX +instance KnownShapeX sh => KnownShapeX (Nothing : sh) where + knownShapeX = () :$? knownShapeX + +type family Rank sh where + Rank '[] = Z + Rank (_ : sh) = S (Rank sh) + +type XArray :: [Maybe Nat] -> Type -> Type +data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) + +zeroIdx :: StaticShapeX sh -> IxX sh +zeroIdx SZX = IZX +zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh +zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh + +zeroIdx' :: IxX sh -> IxX sh +zeroIdx' IZX = IZX +zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh +zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh + +ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh') +ixAppend IZX idx' = idx' +ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx' +ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx' + +ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh' +ixDrop sh IZX = sh +ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx +ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx + +ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') +ssxAppend SZX idx' = idx' +ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx' +ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx' + +shapeSize :: IxX sh -> Int +shapeSize IZX = 1 +shapeSize (n ::@ sh) = n * shapeSize sh +shapeSize (n ::? sh) = n * shapeSize sh + +fromLinearIdx :: IxX sh -> Int -> IxX sh +fromLinearIdx = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IxX sh -> Int -> (IxX sh, Int) + go IZX i = (IZX, i) + go (n ::@ sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` n + in (locali ::@ idx, upi) + go (n ::? sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` n + in (locali ::? idx, upi) + +toLinearIdx :: IxX sh -> IxX sh -> Int +toLinearIdx = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IxX sh -> IxX sh -> (Int, Int) + go IZX IZX = (0, 1) + go (n ::@ sh) (i ::@ ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, n * sz) + go (n ::? sh) (i ::? ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, n * sz) + +enumShape :: IxX sh -> [IxX sh] +enumShape = \sh -> go 0 sh id [] + where + go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a] + go _ IZX _ = id + go i (n ::@ sh) f + | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@)) + | otherwise = id + go i (n ::? sh) f + | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?)) + | otherwise = id + +shapeLshape :: IxX sh -> U.ShapeL +shapeLshape IZX = [] +shapeLshape (n ::@ sh) = n : shapeLshape sh +shapeLshape (n ::? sh) = n : shapeLshape sh + +ssxLength :: StaticShapeX sh -> Int +ssxLength SZX = 0 +ssxLength (_ :$@ ssh) = 1 + ssxLength ssh +ssxLength (_ :$? ssh) = 1 + ssxLength ssh + +ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] +ssxIotaFrom _ SZX = [] +ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh + +lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2) +lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this + +lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1)) +lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this + +lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank IZX = Dict +lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX SZX = Dict +lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + +lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh +lemKnownShapeX SZX = Dict +lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict +lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict + +lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' +lemAppKnownShapeX (n :$@ ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + , Dict <- snatKnown n + = Dict +lemAppKnownShapeX (() :$? ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + = Dict + +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh +shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) + where + go :: StaticShapeX sh' -> [Int] -> IxX sh' + go SZX [] = IZX + go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l + go (() :$? ssh) (n : l) = n ::? go ssh l + go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.fromVector (shapeLshape sh) v) + +toVector :: U.Unbox a => XArray sh a -> VU.Vector a +toVector (XArray arr) = U.toVector arr + +scalar :: U.Unbox a => a -> XArray '[] a +scalar = XArray . U.scalar + +unScalar :: U.Unbox a => XArray '[] a -> a +unScalar (XArray a) = U.unScalar a + +generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) + +-- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +-- XArray . U.fromVector (shapeLshape sh) +-- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh) + +indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial (XArray arr) IZX = XArray arr +indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx +indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx + +index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a +index xarr i + | Refl <- lemAppNil @sh + = let XArray arr' = indexPartial xarr i :: XArray '[] a + in U.unScalar arr' + +append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a +append (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.append a b) + +rerank :: forall sh sh1 sh2 a b. + (U.Unbox a, U.Unbox b) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a -> unXArray (f (XArray a))) + arr) + where + unXArray (XArray a) = a + +rerank2 :: forall sh sh1 sh2 a b c. + (U.Unbox a, U.Unbox b, U.Unbox c) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a b -> unXArray (f (XArray a) (XArray b))) + arr1 arr2) + where + unXArray (XArray a) = a + +-- | The list argument gives indices into the original dimension list. +transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transpose perm (XArray arr) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.transpose perm arr) + +transpose2 :: forall sh1 sh2 a. + StaticShapeX sh1 -> StaticShapeX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2))) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1))) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (U.Unbox a, Num a) => XArray sh a -> a +sumFull (XArray arr) = U.sumA arr + +sumInner :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' + | Refl <- lemAppNil @sh + = rerank ssh ssh' SZX (scalar . sumFull) + +sumOuter :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' + | Refl <- lemAppNil @sh + = sumInner ssh' ssh . transpose2 ssh ssh' |