aboutsummaryrefslogtreecommitdiff
path: root/src/Array.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Array.hs')
-rw-r--r--src/Array.hs103
1 files changed, 103 insertions, 0 deletions
diff --git a/src/Array.hs b/src/Array.hs
index 693df05..cbf04fc 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -8,6 +8,8 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Array where
import qualified Data.Array.RankedU as U
@@ -15,6 +17,7 @@ import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
+import qualified GHC.TypeLits as GHC
import Unsafe.Coerce (unsafeCoerce)
import Nats
@@ -140,6 +143,24 @@ shapeLshape IZX = []
shapeLshape (n ::@ sh) = n : shapeLshape sh
shapeLshape (n ::? sh) = n : shapeLshape sh
+ssxLength :: StaticShapeX sh -> Int
+ssxLength SZX = 0
+ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
+ssxLength (_ :$? ssh) = 1 + ssxLength ssh
+
+ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
+ssxIotaFrom _ SZX = []
+ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh
+
+lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
+ -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2)
+lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
+
+lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
+ -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1))
+lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
+
lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank IZX = Dict
lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict
@@ -183,6 +204,12 @@ fromVector sh v
toVector :: U.Unbox a => XArray sh a -> VU.Vector a
toVector (XArray arr) = U.toVector arr
+scalar :: U.Unbox a => a -> XArray '[] a
+scalar = XArray . U.scalar
+
+unScalar :: U.Unbox a => XArray '[] a -> a
+unScalar (XArray a) = U.unScalar a
+
generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh)
@@ -207,3 +234,79 @@ append (XArray a) (XArray b)
| Dict <- lemKnownNatRankSSX (knownShapeX @sh)
, Dict <- gknownNat (Proxy @(Rank sh))
= XArray (U.append a b)
+
+rerank :: forall sh sh1 sh2 a b.
+ (U.Unbox a, U.Unbox b)
+ => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ -> (XArray sh1 a -> XArray sh2 b)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
+rerank ssh ssh1 ssh2 f (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ (\a -> unXArray (f (XArray a)))
+ arr)
+ where
+ unXArray (XArray a) = a
+
+rerank2 :: forall sh sh1 sh2 a b c.
+ (U.Unbox a, U.Unbox b, U.Unbox c)
+ => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
+rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ (\a b -> unXArray (f (XArray a) (XArray b)))
+ arr1 arr2)
+ where
+ unXArray (XArray a) = a
+
+-- | The list argument gives indices into the original dimension list.
+transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+transpose perm (XArray arr)
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.transpose perm arr)
+
+transpose2 :: forall sh1 sh2 a.
+ StaticShapeX sh1 -> StaticShapeX sh2
+ -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
+transpose2 ssh1 ssh2 (XArray arr)
+ | Refl <- lemRankApp ssh1 ssh2
+ , Refl <- lemRankApp ssh2 ssh1
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2)))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
+ , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))
+ , Refl <- lemRankAppComm ssh1 ssh2
+ , let n1 = ssxLength ssh1
+ = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+
+sumFull :: (U.Unbox a, Num a) => XArray sh a -> a
+sumFull (XArray arr) = U.sumA arr
+
+sumInner :: forall sh sh' a. (U.Unbox a, Num a)
+ => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
+sumInner ssh ssh'
+ | Refl <- lemAppNil @sh
+ = rerank ssh ssh' SZX (scalar . sumFull)
+
+sumOuter :: forall sh sh' a. (U.Unbox a, Num a)
+ => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
+sumOuter ssh ssh'
+ | Refl <- lemAppNil @sh
+ = sumInner ssh' ssh . transpose2 ssh ssh'