{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Array where

import qualified Data.Array.RankedU as U
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Unboxed as VU
import qualified GHC.TypeLits as GHC
import Unsafe.Coerce (unsafeCoerce)

import Nats


-- | Type-level list append.
type family l1 ++ l2 where
  '[] ++ l2 = l2
  (x : xs) ++ l2 = x : xs ++ l2

-- | @'[]@ is a right identity of '++'.  Asserted with 'unsafeCoerce' rather
-- than proven, since no singleton for @l@ is available to recurse on.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerce Refl

-- | Associativity of '++'.  Likewise asserted with 'unsafeCoerce'; the
-- proxies only fix the type arguments.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl

-- | An index into (or the runtime shape of) an array of shape @sh@.  Each
-- entry of @sh@ is either a statically known size (@Just n@, built with
-- '::@') or a dynamic one (@Nothing@, built with '::?'); either way the
-- component stored at runtime is an 'Int'.
type IxX :: [Maybe Nat] -> Type
data IxX sh where
  IZX :: IxX '[]
  (::@) :: Int -> IxX sh -> IxX (Just n : sh)
  (::?) :: Int -> IxX sh -> IxX (Nothing : sh)
deriving instance Show (IxX sh)

-- | Compile-time description of a shape: an 'SNat' for every static
-- dimension and @()@ for every dynamic one.
type StaticShapeX :: [Maybe Nat] -> Type
data StaticShapeX sh where
  SZX :: StaticShapeX '[]
  (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
  (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)

-- | Shapes whose 'StaticShapeX' can be produced implicitly from the type.
type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
  knownShapeX :: StaticShapeX sh

instance KnownShapeX '[] where
  knownShapeX = SZX

instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
  knownShapeX = knownNat :$@ knownShapeX

instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
  knownShapeX = () :$? knownShapeX

-- | Number of dimensions of a shape, as a 'Nat' built from @Z@ / @S@.
type family Rank sh where
  Rank '[] = Z
  Rank (_ : sh) = S (Rank sh)

-- | An array of shape @sh@.  The wrapped 'U.Array' is indexed only by the
-- rank of @sh@; the static sizes in @sh@ are not reflected in its type.
type XArray :: [Maybe Nat] -> Type -> Type
data XArray sh a = XArray (U.Array (GNat (Rank sh)) a)

-- | The all-zeros index for the shape described by a 'StaticShapeX'.
zeroIdx :: StaticShapeX sh -> IxX sh
zeroIdx SZX = IZX
zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh
zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh

-- | Replace every component of an index with 0, keeping its structure.
zeroIdx' :: IxX sh -> IxX sh
zeroIdx' IZX = IZX
zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh
zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh

-- | Concatenate two indices.
ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh')
ixAppend IZX idx' = idx'
ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx'
ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx'

-- | Drop a prefix from an index.  The second argument only determines how
-- many components to drop; its values are ignored.
ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh'
ixDrop sh IZX = sh
ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx
ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx

-- | Concatenate two static shapes.
ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
ssxAppend SZX idx' = idx'
ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx'
ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx'

-- | Number of elements in an array of the given shape: the product of the
-- dimensions (1 for rank 0).
shapeSize :: IxX sh -> Int
shapeSize IZX = 1
shapeSize (n ::@ sh) = n * shapeSize sh
shapeSize (n ::? sh) = n * shapeSize sh

-- | Convert a row-major linear index (innermost dimension varying fastest)
-- into a multi-dimensional index for the given shape; the inverse of
-- 'toLinearIdx'.  Calls 'error' if the linear index is negative or not less
-- than 'shapeSize' of the shape.  That includes every index into a shape
-- with a zero-sized dimension.
fromLinearIdx :: IxX sh -> Int -> IxX sh
fromLinearIdx = \sh ->
  -- Range-check up front: the digit decomposition in 'go' would silently
  -- produce negative components for a negative index and divide by zero on
  -- a zero-sized dimension.  The size is computed once per shape so that
  -- partial applications such as in 'generate' do not recompute it.
  let size = shapeSize sh
  in \i ->
       if i < 0 || i >= size
         then error $ "fromLinearIdx: out of range (" ++ show i ++ " in array of shape " ++ show sh ++ ")"
         else fst (go sh i)
  where
    -- returns (index in subarray, remaining index in enclosing array)
    go :: IxX sh -> Int -> (IxX sh, Int)
    go IZX i = (IZX, i)
    go (n ::@ sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` n
      in (locali ::@ idx, upi)
    go (n ::? sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` n
      in (locali ::? idx, upi)

-- | Row-major linear index of the second argument within the shape given
-- as the first argument.  Components are not bounds-checked.
toLinearIdx :: IxX sh -> IxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
  where
    -- returns (index in subarray, size of subarray)
    go :: IxX sh -> IxX sh -> (Int, Int)
    go IZX IZX = (0, 1)
    go (n ::@ sh) (i ::@ ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, n * sz)
    go (n ::? sh) (i ::? ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, n * sz)

-- | All indices into an array of the given shape, in row-major order.  This
-- is the same order as 'fromLinearIdx', so the @k@-th element is
-- @fromLinearIdx sh k@.  A rank-0 shape has exactly one index, 'IZX'; a
-- shape with a zero-sized dimension has none.
enumShape :: IxX sh -> [IxX sh]
enumShape = \sh -> go sh id []
  where
    -- @go sh f@ is a difference list holding @f ix@ for every index @ix@ of
    -- shape @sh@, in ascending order.  The base case must emit the (unique)
    -- empty index; for each dimension the per-component lists are composed
    -- left to right so that smaller components come first.
    go :: IxX sh -> (IxX sh -> a) -> [a] -> [a]
    go IZX f = (f IZX :)
    go (n ::@ sh) f = foldr (.) id [go sh (f . (i ::@)) | i <- [0 .. n - 1]]
    go (n ::? sh) f = foldr (.) id [go sh (f . (i ::?)) | i <- [0 .. n - 1]]

-- | Convert a shape to the 'U.ShapeL' list used by "Data.Array.RankedU".
shapeLshape :: IxX sh -> U.ShapeL
shapeLshape IZX = []
shapeLshape (n ::@ sh) = n : shapeLshape sh
shapeLshape (n ::? sh) = n : shapeLshape sh

-- | Number of dimensions described by a static shape.
ssxLength :: StaticShapeX sh -> Int
ssxLength SZX = 0
ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
ssxLength (_ :$? ssh) = 1 + ssxLength ssh

-- | @[i, i+1, ...]@, with one entry per dimension of the static shape.
ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
ssxIotaFrom _ SZX = []
ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh

-- | The rank of a concatenated shape is the sum of the ranks.  Asserted
-- with 'unsafeCoerce'.
lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
           -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2)
lemRankApp _ _ = unsafeCoerce Refl  -- TODO improve this

-- | Concatenating shapes in either order gives the same rank.  Asserted
-- with 'unsafeCoerce'.
lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
               -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1))
lemRankAppComm _ _ = unsafeCoerce Refl  -- TODO improve this

-- | A 'KnownNat' dictionary for the rank of a shape, obtained by recursion
-- over an index of that shape.
lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh)
lemKnownNatRank IZX = Dict
lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict
lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict

-- | As 'lemKnownNatRank', but recursing over a static shape.
lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankSSX SZX = Dict
lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict
lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict

-- | Turn a static shape value into a 'KnownShapeX' constraint.
lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
lemKnownShapeX SZX = Dict
lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict
lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict

-- | 'KnownShapeX' for the concatenation of two static shapes.
lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh'
lemAppKnownShapeX (n :$@ ssh) ssh'
  | Dict <- lemAppKnownShapeX ssh ssh'
  , Dict <- snatKnown n
  = Dict
lemAppKnownShapeX (() :$? ssh) ssh'
  | Dict <- lemAppKnownShapeX ssh ssh'
  = Dict

-- | The runtime shape of an array.  Sizes of static dimensions are taken
-- from the type; the stored array's size for those dimensions is not
-- consulted.  Sizes of dynamic dimensions come from the stored array.
-- Calls 'error' if the stored rank disagrees with the type.
shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh
shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr)
  where
    go :: StaticShapeX sh' -> [Int] -> IxX sh'
    go SZX [] = IZX
    go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l
    go (() :$? ssh) (n : l) = n ::? go ssh l
    go _ _ = error "Invalid shapeL"

-- | Build an array of the given shape from a flat vector via
-- 'U.fromVector'.  Element order and any length check are those of
-- 'U.fromVector' (presumably row-major; NOTE(review): confirm).
fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a
fromVector sh v
  | Dict <- lemKnownNatRank sh
  , Dict <- gknownNat (Proxy @(Rank sh))
  = XArray (U.fromVector (shapeLshape sh) v)

-- | All elements as a flat vector, via 'U.toVector'.
toVector :: U.Unbox a => XArray sh a -> VU.Vector a
toVector (XArray arr) = U.toVector arr

-- | A rank-0 array holding a single value.
scalar :: U.Unbox a => a -> XArray '[] a
scalar = XArray . U.scalar

-- | The value stored in a rank-0 array.
unScalar :: U.Unbox a => XArray '[] a -> a
unScalar (XArray a) = U.unScalar a

-- | An array whose element at each index is @f@ applied to that index.
-- Elements are generated in linear-index order and decoded with
-- 'fromLinearIdx'.
generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh)

-- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownNatRank sh =
--   XArray . U.fromVector (shapeLshape sh)
--     <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh)

-- | Index the outermost dimensions (those of @sh@), returning the
-- sub-array over the remaining dimensions @sh'@.  Bounds checking, if
-- any, is done by 'U.index'.
indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a
indexPartial (XArray arr) IZX = XArray arr
indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx
indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx

-- | Read the single element at a full index.
index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a
index xarr i
  | Refl <- lemAppNil @sh
  = let XArray arr' = indexPartial xarr i :: XArray '[] a
    in U.unScalar arr'

-- | Concatenate two arrays with 'U.append'.
--
-- NOTE(review): if 'U.append' concatenates along the outermost dimension,
-- the result's outer size is the sum of the inputs' sizes.  When that
-- dimension is static (@Just n@) this disagrees with the result type, and
-- 'shape' would then report @n@.  Confirm, and consider restricting the
-- signature.
append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a
append (XArray a) (XArray b)
  | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
  , Dict <- gknownNat (Proxy @(Rank sh))
  = XArray (U.append a b)

-- | Apply @f@ to every inner sub-array of shape @sh1@, one per index of
-- the outer dimensions @sh@, and collect the results (each of shape
-- @sh2@) under the same outer dimensions.  Delegates to 'U.rerank'.
rerank :: forall sh sh1 sh2 a b.
          (U.Unbox a, U.Unbox b)
       => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
       -> (XArray sh1 a -> XArray sh2 b)
       -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
  | Dict <- lemKnownNatRankSSX ssh
  , Dict <- gknownNat (Proxy @(Rank sh))
  , Dict <- lemKnownNatRankSSX ssh2
  , Dict <- gknownNat (Proxy @(Rank sh2))
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- these two should be redundant but the
  , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2)))    -- solver is not clever enough
  = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
                     (\a -> unXArray (f (XArray a)))
                     arr)
  where
    unXArray (XArray a) = a

-- | Binary version of 'rerank': @f@ receives the matching inner
-- sub-arrays of both inputs.  Delegates to 'U.rerank2'.
rerank2 :: forall sh sh1 sh2 a b c.
           (U.Unbox a, U.Unbox b, U.Unbox c)
        => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
        -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
        -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
  | Dict <- lemKnownNatRankSSX ssh
  , Dict <- gknownNat (Proxy @(Rank sh))
  , Dict <- lemKnownNatRankSSX ssh2
  , Dict <- gknownNat (Proxy @(Rank sh2))
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2)  -- these two should be redundant but the
  , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2)))    -- solver is not clever enough
  = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2))
                      (\a b -> unXArray (f (XArray a) (XArray b)))
                      arr1 arr2)
  where
    unXArray (XArray a) = a

-- | The list argument gives indices into the original dimension list.
-- The shape type is unchanged, so the permutation should only exchange
-- dimensions of equal kind and size.
transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
  | Dict <- lemKnownNatRankSSX (knownShapeX @sh)
  , Dict <- gknownNat (Proxy @(Rank sh))
  = XArray (U.transpose perm arr)

-- | Move the @sh2@ dimensions in front of the @sh1@ dimensions.
transpose2 :: forall sh1 sh2 a.
              StaticShapeX sh1 -> StaticShapeX sh2
           -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
  | Refl <- lemRankApp ssh1 ssh2
  , Refl <- lemRankApp ssh2 ssh1
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2)
  , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2)))
  , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1)
  , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1)))
  , Refl <- lemRankAppComm ssh1 ssh2
  , let n1 = ssxLength ssh1
    -- New dimension order: the old positions of sh2, then those of sh1.
  = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)

-- | Sum of all elements.
sumFull :: (U.Unbox a, Num a) => XArray sh a -> a
sumFull (XArray arr) = U.sumA arr

-- | Sum over the inner dimensions @sh'@, keeping the outer dimensions @sh@.
sumInner :: forall sh sh' a. (U.Unbox a, Num a)
         => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
  | Refl <- lemAppNil @sh
  = rerank ssh ssh' SZX (scalar . sumFull)

-- | Sum over the outer dimensions @sh@, keeping the inner dimensions @sh'@.
-- This moves @sh@ to the inside with 'transpose2' and then uses
-- 'sumInner'.
sumOuter :: forall sh sh' a. (U.Unbox a, Num a)
         => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' = sumInner ssh' ssh . transpose2 ssh ssh'