diff options
Diffstat (limited to 'src/Array.hs')
-rw-r--r-- | src/Array.hs | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/src/Array.hs b/src/Array.hs index 693df05..cbf04fc 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -8,6 +8,8 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} module Array where import qualified Data.Array.RankedU as U @@ -15,6 +17,7 @@ import Data.Kind import Data.Proxy import Data.Type.Equality import qualified Data.Vector.Unboxed as VU +import qualified GHC.TypeLits as GHC import Unsafe.Coerce (unsafeCoerce) import Nats @@ -140,6 +143,24 @@ shapeLshape IZX = [] shapeLshape (n ::@ sh) = n : shapeLshape sh shapeLshape (n ::? sh) = n : shapeLshape sh +ssxLength :: StaticShapeX sh -> Int +ssxLength SZX = 0 +ssxLength (_ :$@ ssh) = 1 + ssxLength ssh +ssxLength (_ :$? ssh) = 1 + ssxLength ssh + +ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] +ssxIotaFrom _ SZX = [] +ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh + +lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2) +lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this + +lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1)) +lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this + lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh) lemKnownNatRank IZX = Dict lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict @@ -183,6 +204,12 @@ fromVector sh v toVector :: U.Unbox a => XArray sh a -> VU.Vector a toVector (XArray arr) = U.toVector arr +scalar :: U.Unbox a => a -> XArray '[] a +scalar = XArray . U.scalar + +unScalar :: U.Unbox a => XArray '[] a -> a +unScalar (XArray a) = U.unScalar a + generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) @@ -207,3 +234,79 @@ append (XArray a) (XArray b) | Dict <- lemKnownNatRankSSX (knownShapeX @sh) , Dict <- gknownNat (Proxy @(Rank sh)) = XArray (U.append a b) + +rerank :: forall sh sh1 sh2 a b. + (U.Unbox a, U.Unbox b) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a -> unXArray (f (XArray a))) + arr) + where + unXArray (XArray a) = a + +rerank2 :: forall sh sh1 sh2 a b c. + (U.Unbox a, U.Unbox b, U.Unbox c) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a b -> unXArray (f (XArray a) (XArray b))) + arr1 arr2) + where + unXArray (XArray a) = a + +-- | The list argument gives indices into the original dimension list. +transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transpose perm (XArray arr) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.transpose perm arr) + +transpose2 :: forall sh1 sh2 a. + StaticShapeX sh1 -> StaticShapeX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2))) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1))) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (U.Unbox a, Num a) => XArray sh a -> a +sumFull (XArray arr) = U.sumA arr + +sumInner :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' + | Refl <- lemAppNil @sh + = rerank ssh ssh' SZX (scalar . sumFull) + +sumOuter :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' + | Refl <- lemAppNil @sh + = sumInner ssh' ssh . transpose2 ssh ssh' |