summaryrefslogtreecommitdiff
path: root/src/Data/Array/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Mixed.hs')
-rw-r--r--src/Data/Array/Mixed.hs314
1 files changed, 314 insertions, 0 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs
new file mode 100644
index 0000000..e1e2d5a
--- /dev/null
+++ b/src/Data/Array/Mixed.hs
@@ -0,0 +1,314 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
+module Data.Array.Mixed where
+
+import qualified Data.Array.RankedU as U
+import Data.Kind
+import Data.Proxy
+import Data.Type.Equality
+import qualified Data.Vector.Unboxed as VU
+import qualified GHC.TypeLits as GHC
+import Unsafe.Coerce (unsafeCoerce)
+
+import Data.Nat
+
+
+-- | Type-level list append.
+type family l1 ++ l2 where
+ '[] ++ l2 = l2
+ (x : xs) ++ l2 = x : xs ++ l2
+
+lemAppNil :: l ++ '[] :~: l
+lemAppNil = unsafeCoerce Refl
+
+lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
+lemAppAssoc _ _ _ = unsafeCoerce Refl
+
+
+type IxX :: [Maybe Nat] -> Type
+data IxX sh where
+ IZX :: IxX '[]
+ (::@) :: Int -> IxX sh -> IxX (Just n : sh)
+ (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
+deriving instance Show (IxX sh)
+
+-- | The part of a shape that is statically known.
+type StaticShapeX :: [Maybe Nat] -> Type
+data StaticShapeX sh where
+ SZX :: StaticShapeX '[]
+ (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
+ (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
+deriving instance Show (StaticShapeX sh)
+
+-- | Evidence for the static part of a shape.
+type KnownShapeX :: [Maybe Nat] -> Constraint
+class KnownShapeX sh where
+ knownShapeX :: StaticShapeX sh
+instance KnownShapeX '[] where
+ knownShapeX = SZX
+instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
+ knownShapeX = knownNat :$@ knownShapeX
+instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
+ knownShapeX = () :$? knownShapeX
+
+type family Rank sh where
+ Rank '[] = Z
+ Rank (_ : sh) = S (Rank sh)
+
+type XArray :: [Maybe Nat] -> Type -> Type
+data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)
+
+zeroIdx :: StaticShapeX sh -> IxX sh
+zeroIdx SZX = IZX
+zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh
+zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh
+
+zeroIdx' :: IxX sh -> IxX sh
+zeroIdx' IZX = IZX
+zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh
+zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh
+
+ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh')
+ixAppend IZX idx' = idx'
+ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx'
+ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx'
+
+ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh'
+ixDrop sh IZX = sh
+ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx
+ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx
+
+ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
+ssxAppend SZX idx' = idx'
+ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx'
+ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx'
+
+shapeSize :: IxX sh -> Int
+shapeSize IZX = 1
+shapeSize (n ::@ sh) = n * shapeSize sh
+shapeSize (n ::? sh) = n * shapeSize sh
+
+fromLinearIdx :: IxX sh -> Int -> IxX sh
+fromLinearIdx = \sh i -> case go sh i of
+ (idx, 0) -> idx
+ _ -> error $ "fromLinearIdx: out of range (" ++ show i ++
+ " in array of shape " ++ show sh ++ ")"
+ where
+ -- returns (index in subarray, remaining index in enclosing array)
+ go :: IxX sh -> Int -> (IxX sh, Int)
+ go IZX i = (IZX, i)
+ go (n ::@ sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` n
+ in (locali ::@ idx, upi)
+ go (n ::? sh) i =
+ let (idx, i') = go sh i
+ (upi, locali) = i' `quotRem` n
+ in (locali ::? idx, upi)
+
+toLinearIdx :: IxX sh -> IxX sh -> Int
+toLinearIdx = \sh i -> fst (go sh i)
+ where
+ -- returns (index in subarray, size of subarray)
+ go :: IxX sh -> IxX sh -> (Int, Int)
+ go IZX IZX = (0, 1)
+ go (n ::@ sh) (i ::@ ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, n * sz)
+ go (n ::? sh) (i ::? ix) =
+ let (lidx, sz) = go sh ix
+ in (sz * i + lidx, n * sz)
+
+enumShape :: IxX sh -> [IxX sh]
+enumShape = \sh -> go 0 sh id []
+ where
+ go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a]
+ go _ IZX _ = id
+ go i (n ::@ sh) f
+ | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@))
+ | otherwise = id
+ go i (n ::? sh) f
+ | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?))
+ | otherwise = id
+
+shapeLshape :: IxX sh -> U.ShapeL
+shapeLshape IZX = []
+shapeLshape (n ::@ sh) = n : shapeLshape sh
+shapeLshape (n ::? sh) = n : shapeLshape sh
+
+ssxLength :: StaticShapeX sh -> Int
+ssxLength SZX = 0
+ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
+ssxLength (_ :$? ssh) = 1 + ssxLength ssh
+
+ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
+ssxIotaFrom _ SZX = []
+ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
+ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh
+
+lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
+ -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2)
+lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this
+
+lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
+ -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1))
+lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this
+
+lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRank IZX = Dict
+lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict
+lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict
+
+lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh)
+lemKnownNatRankSSX SZX = Dict
+lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
+
+lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
+lemKnownShapeX SZX = Dict
+lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict
+lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict
+
+lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
+lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
+lemAppKnownShapeX (n :$@ ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ , Dict <- snatKnown n
+ = Dict
+lemAppKnownShapeX (() :$? ssh) ssh'
+ | Dict <- lemAppKnownShapeX ssh ssh'
+ = Dict
+
+shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
+shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
+ where
+ go :: StaticShapeX sh' -> [Int] -> IxX sh'
+ go SZX [] = IZX
+ go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l
+ go (() :$? ssh) (n : l) = n ::? go ssh l
+ go _ _ = error "Invalid shapeL"
+
+fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
+fromVector sh v
+ | Dict <- lemKnownNatRank sh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.fromVector (shapeLshape sh) v)
+
+toVector :: U.Unbox a => XArray sh a -> VU.Vector a
+toVector (XArray arr) = U.toVector arr
+
+scalar :: U.Unbox a => a -> XArray '[] a
+scalar = XArray . U.scalar
+
+unScalar :: U.Unbox a => XArray '[] a -> a
+unScalar (XArray a) = U.unScalar a
+
+generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
+generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh)
+
+-- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
+-- generateM sh f | Dict <- lemKnownNatRank sh =
+-- XArray . U.fromVector (shapeLshape sh)
+-- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh)
+
+indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
+indexPartial (XArray arr) IZX = XArray arr
+indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx
+indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx
+
+index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a
+index xarr i
+ | Refl <- lemAppNil @sh
+ = let XArray arr' = indexPartial xarr i :: XArray '[] a
+ in U.unScalar arr'
+
+append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
+append (XArray a) (XArray b)
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.append a b)
+
+rerank :: forall sh sh1 sh2 a b.
+ (U.Unbox a, U.Unbox b)
+ => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ -> (XArray sh1 a -> XArray sh2 b)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
+rerank ssh ssh1 ssh2 f (XArray arr)
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ (\a -> unXArray (f (XArray a)))
+ arr)
+ where
+ unXArray (XArray a) = a
+
+rerank2 :: forall sh sh1 sh2 a b c.
+ (U.Unbox a, U.Unbox b, U.Unbox c)
+ => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
+ -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
+ -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
+rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
+ | Dict <- lemKnownNatRankSSX ssh
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ , Dict <- lemKnownNatRankSSX ssh2
+ , Dict <- gknownNat (Proxy @(Rank sh2))
+ , Refl <- lemRankApp ssh ssh1
+ , Refl <- lemRankApp ssh ssh2
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the
+ , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough
+ = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
+ (\a b -> unXArray (f (XArray a) (XArray b)))
+ arr1 arr2)
+ where
+ unXArray (XArray a) = a
+
+-- | The list argument gives indices into the original dimension list.
+transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
+transpose perm (XArray arr)
+ | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
+ , Dict <- gknownNat (Proxy @(Rank sh))
+ = XArray (U.transpose perm arr)
+
+transpose2 :: forall sh1 sh2 a.
+ StaticShapeX sh1 -> StaticShapeX sh2
+ -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
+transpose2 ssh1 ssh2 (XArray arr)
+ | Refl <- lemRankApp ssh1 ssh2
+ , Refl <- lemRankApp ssh2 ssh1
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
+ , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2)))
+ , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
+ , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))
+ , Refl <- lemRankAppComm ssh1 ssh2
+ , let n1 = ssxLength ssh1
+ = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)
+
+sumFull :: (U.Unbox a, Num a) => XArray sh a -> a
+sumFull (XArray arr) = U.sumA arr
+
+sumInner :: forall sh sh' a. (U.Unbox a, Num a)
+ => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
+sumInner ssh ssh'
+ | Refl <- lemAppNil @sh
+ = rerank ssh ssh' SZX (scalar . sumFull)
+
+sumOuter :: forall sh sh' a. (U.Unbox a, Num a)
+ => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
+sumOuter ssh ssh'
+ | Refl <- lemAppNil @sh
+ = sumInner ssh' ssh . transpose2 ssh ssh'