{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Data.Array.Mixed where

import qualified Data.Array.RankedS as S
import qualified Data.Array.Ranked as ORB
import Data.Coerce
import Data.Kind
import Data.Proxy
import Data.Type.Equality
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

import Data.INat


-- | The 'SNat' pattern synonym is complete, but it doesn't have a
-- @COMPLETE@ pragma. This copy of it does.
pattern GHC_SNat :: () => KnownNat n => SNat n
pattern GHC_SNat = SNat
{-# COMPLETE GHC_SNat #-}

-- | Type-level list append.
type family l1 ++ l2 where
  '[] ++ l2 = l2
  (x : xs) ++ l2 = x : xs ++ l2

-- | Appending the empty list on the right is the identity. GHC cannot derive
-- this from the definition of '++' on its own, so the proof term is an
-- 'unsafeCoerce' rather than an inductive proof.
lemAppNil :: l ++ '[] :~: l
lemAppNil = unsafeCoerce Refl

-- | Type-level list append is associative. As with 'lemAppNil', the proof
-- term is an 'unsafeCoerce'; the proxies only serve to fix the type variables.
lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c)
lemAppAssoc _ _ _ = unsafeCoerce Refl


-- TODO: ListX? But if so, why is StaticShapeX not defined as a newtype
-- over IxX (so that we can make IxX and StaticShapeX a newtype over ListX)?

-- | A mixed index -- or, with @i ~ Int@, a mixed shape -- holding one @i@
-- per dimension. ':.@' conses the entry for a dimension whose size is static
-- (@Just n@); ':.?' conses the entry for a dimension whose size is only known
-- at runtime (@Nothing@).
type IxX :: [Maybe Nat] -> Type -> Type
data IxX sh i where
  ZIX :: IxX '[] i
  (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i
  (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i
deriving instance Show i => Show (IxX sh i)
deriving instance Eq i => Eq (IxX sh i)
deriving instance Ord i => Ord (IxX sh i)

infixr 3 :.@
infixr 3 :.?

-- | An index (or shape) with 'Int' components.
type IIxX sh = IxX sh Int

-- | The part of a shape that is statically known.
type StaticShapeX :: [Maybe Nat] -> Type
data StaticShapeX sh where
  ZSX :: StaticShapeX '[]
  (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh)
  (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh)
deriving instance Show (StaticShapeX sh)

infixr 3 :$@
infixr 3 :$?

-- | Evidence for the static part of a shape.
type KnownShapeX :: [Maybe Nat] -> Constraint
class KnownShapeX sh where
  knownShapeX :: StaticShapeX sh
instance KnownShapeX '[] where
  knownShapeX = ZSX
instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where
  knownShapeX = natSing :$@ knownShapeX
instance KnownShapeX sh => KnownShapeX (Nothing : sh) where
  knownShapeX = () :$? knownShapeX

-- | The number of dimensions of a mixed shape, as a 'Z'/'S' type-level
-- natural.
type family Rank sh where
  Rank '[] = Z
  Rank (_ : sh) = S (Rank sh)

-- | A mixed array: an orthotope 'S.Array' whose rank is the length of @sh@.
-- The static sizes in @sh@ are not stored separately; where they are needed
-- (e.g. in 'shape') they are recovered from the type.
type XArray :: [Maybe Nat] -> Type -> Type
newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a)
  deriving (Show)


-- | The all-zeros index for a shape described by its static part.
zeroIxX :: StaticShapeX sh -> IIxX sh
zeroIxX ZSX = ZIX
zeroIxX (_ :$@ ssh) = 0 :.@ zeroIxX ssh
zeroIxX (_ :$? ssh) = 0 :.? zeroIxX ssh

-- | The all-zeros index with the same structure as the given index or shape.
zeroIxX' :: IIxX sh -> IIxX sh
zeroIxX' ZIX = ZIX
zeroIxX' (_ :.@ sh) = 0 :.@ zeroIxX' sh
zeroIxX' (_ :.? sh) = 0 :.? zeroIxX' sh

-- | Concatenate two indices.
ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh')
ixAppend ZIX idx' = idx'
ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx'
ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx'

-- | Drop a prefix from an index. Only the structure (the length) of the second
-- argument is used; its components are ignored.
ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh'
ixDrop sh ZIX = sh
ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx
ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx

-- | Concatenate the static parts of two shapes.
ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh')
ssxAppend ZSX sh' = sh'
ssxAppend (n :$@ sh) sh' = n :$@ ssxAppend sh sh'
ssxAppend (() :$? sh) sh' = () :$? ssxAppend sh sh'

-- | Total number of elements of an array with this shape: the product of all
-- dimension sizes (1 for the empty shape).
shapeSize :: IIxX sh -> Int
shapeSize ZIX = 1
shapeSize (n :.@ sh) = n * shapeSize sh
shapeSize (n :.? sh) = n * shapeSize sh

-- | This may fail if @sh@ has @Nothing@s in it.
ssxToShape' :: StaticShapeX sh -> Maybe (IIxX sh)
ssxToShape' ZSX = Just ZIX
ssxToShape' (n :$@ sh) = (fromIntegral (fromSNat n) :.@) <$> ssxToShape' sh
ssxToShape' (_ :$? _) = Nothing

-- | Convert a linear, row-major element offset into a multi-dimensional index
-- for the given shape; on in-range offsets this inverts 'toLinearIdx'.
--
-- Calls 'error' when the offset is negative or not smaller than the number of
-- elements. NOTE(review): for a shape with a zero-sized dimension the
-- 'quotRem' in @go@ divides by zero instead of reaching that error.
fromLinearIdx :: IIxX sh -> Int -> IIxX sh
fromLinearIdx = \sh i -> case go sh i of
    -- 'quotRem' truncates towards zero, so a negative offset can still leave
    -- a zero remainder at the top; the sign therefore needs its own check.
    (idx, 0) | i >= 0 -> idx
    _ -> error $ "fromLinearIdx: out of range (" ++ show i ++
                 " in array of shape " ++ show sh ++ ")"
  where
    -- returns (index in subarray, remaining index in enclosing array)
    go :: IIxX sh -> Int -> (IIxX sh, Int)
    go ZIX i = (ZIX, i)
    go (n :.@ sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` n
      in (locali :.@ idx, upi)
    go (n :.? sh) i =
      let (idx, i') = go sh i
          (upi, locali) = i' `quotRem` n
      in (locali :.? idx, upi)

-- | Convert an index into a linear, row-major element offset (the last
-- dimension varies fastest). First argument is the shape, second the index.
-- No bounds checking is performed.
toLinearIdx :: IIxX sh -> IIxX sh -> Int
toLinearIdx = \sh i -> fst (go sh i)
  where
    -- returns (index in subarray, size of subarray)
    go :: IIxX sh -> IIxX sh -> (Int, Int)
    go ZIX ZIX = (0, 1)
    go (n :.@ sh) (i :.@ ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, n * sz)
    go (n :.? sh) (i :.? ix) =
      let (lidx, sz) = go sh ix
      in (sz * i + lidx, n * sz)

-- | All indices of the given shape, in row-major (lexicographic) order.
enumShape :: IIxX sh -> [IIxX sh]
enumShape = \sh -> go sh id []
  where
    -- Difference-list style: @go sh f@ prepends @f ix@ for every index @ix@
    -- of @sh@, outer dimension first.
    go :: IIxX sh -> (IIxX sh -> a) -> [a] -> [a]
    go ZIX f = (f ZIX :)
    go (n :.@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. n-1]]
    go (n :.? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]]

-- | Forget the static/dynamic distinction and return the plain orthotope
-- shape list.
shapeLshape :: IIxX sh -> S.ShapeL
shapeLshape ZIX = []
shapeLshape (n :.@ sh) = n : shapeLshape sh
shapeLshape (n :.? sh) = n : shapeLshape sh

-- | The number of dimensions described by a static shape.
ssxLength :: StaticShapeX sh -> Int
ssxLength ZSX = 0
ssxLength (_ :$@ ssh) = 1 + ssxLength ssh
ssxLength (_ :$? ssh) = 1 + ssxLength ssh

-- | @ssxIotaFrom i ssh@ is @[i, i+1, ...]@ with one entry per dimension of
-- @ssh@.
ssxIotaFrom :: Int -> StaticShapeX sh -> [Int]
ssxIotaFrom _ ZSX = []
ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh
ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh

-- | The rank of an appended shape is the sum of the ranks (unchecked proof).
lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2
           -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2)
lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this

-- | The rank of an appended shape does not depend on the order of the parts
-- (unchecked proof).
lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2
               -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1))
lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this

-- | 'KnownINat' evidence for the rank, obtained by induction over an index.
lemKnownINatRank :: IIxX sh -> Dict KnownINat (Rank sh)
lemKnownINatRank ZIX = Dict
lemKnownINatRank (_ :.@ sh) | Dict <- lemKnownINatRank sh = Dict
lemKnownINatRank (_ :.? sh) | Dict <- lemKnownINatRank sh = Dict

-- | 'KnownINat' evidence for the rank, obtained by induction over a static
-- shape.
lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh)
lemKnownINatRankSSX ZSX = Dict
lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict
lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict

-- | Turn a static shape value back into a 'KnownShapeX' constraint.
lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh
lemKnownShapeX ZSX = Dict
lemKnownShapeX (GHC_SNat :$@ ssh) | Dict <- lemKnownShapeX ssh = Dict
lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict

-- | 'KnownShapeX' evidence for the concatenation of two static shapes.
lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2)
lemAppKnownShapeX ZSX ssh' = lemKnownShapeX ssh'
lemAppKnownShapeX (GHC_SNat :$@ ssh) ssh'
  | Dict <- lemAppKnownShapeX ssh ssh'
  = Dict
lemAppKnownShapeX (() :$? ssh) ssh'
  | Dict <- lemAppKnownShapeX ssh ssh'
  = Dict

-- | The shape of an array. Sizes of static (@Just n@) dimensions are taken
-- from the type; sizes of dynamic dimensions from the runtime array. Calls
-- 'error' if the runtime rank does not match @sh@.
shape :: forall sh a. KnownShapeX sh => XArray sh a -> IIxX sh
shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr)
  where
    go :: StaticShapeX sh' -> [Int] -> IIxX sh'
    go ZSX [] = ZIX
    go (n :$@ ssh) (_ : l) = fromIntegral (fromSNat n) :.@ go ssh l
    go (() :$? ssh) (n : l) = n :.? go ssh l
    go _ _ = error "Invalid shapeL"

-- | Build an array of the given shape from a flat vector of elements.
-- Element order and length requirements are those of orthotope's
-- 'S.fromVector'.
fromVector :: forall sh a. Storable a => IIxX sh -> VS.Vector a -> XArray sh a
fromVector sh v
  | Dict <- lemKnownINatRank sh
  , Dict <- knownNatFromINat (Proxy @(Rank sh))
  = XArray (S.fromVector (shapeLshape sh) v)

-- | The elements of the array as a flat vector (via 'S.toVector').
toVector :: Storable a => XArray sh a -> VS.Vector a
toVector (XArray arr) = S.toVector arr

-- | A rank-0 array holding a single value.
scalar :: Storable a => a -> XArray '[] a
scalar = XArray . S.scalar

-- | The single value held by a rank-0 array.
unScalar :: Storable a => XArray '[] a -> a
unScalar (XArray a) = S.unScalar a

-- | An array of the given shape with every element equal to @x@.
constant :: forall sh a. Storable a => IIxX sh -> a -> XArray sh a
constant sh x
  | Dict <- lemKnownINatRank sh
  , Dict <- knownNatFromINat (Proxy @(Rank sh))
  = XArray (S.constant (shapeLshape sh) x)

-- | An array of the given shape whose element at each index is computed by
-- the function.
generate :: Storable a => IIxX sh -> (IIxX sh -> a) -> XArray sh a
generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh)

-- generateM :: (Monad m, Storable a) => IIxX sh -> (IIxX sh -> m a) -> m (XArray sh a)
-- generateM sh f | Dict <- lemKnownINatRank sh =
--   XArray . S.fromVector (shapeLshape sh)
--     <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh)

-- | Index the outer @sh@ dimensions, returning the remaining @sh'@-shaped
-- subarray.
indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a
indexPartial (XArray arr) ZIX = XArray arr
indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx
indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx

-- | Index a single element of the array.
index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a
index xarr i
  | Refl <- lemAppNil @sh
  = let XArray arr' = indexPartial xarr i :: XArray '[] a
    in S.unScalar arr'

-- | Addition of possibly-unknown sizes: the result is static only if both
-- operands are.
type family AddMaybe n m where
  AddMaybe Nothing _ = Nothing
  AddMaybe (Just _) Nothing = Nothing
  AddMaybe (Just n) (Just m) = Just (n + m)

-- | Concatenate two arrays along their outermost dimension.
append :: forall n m sh a. (KnownShapeX sh, Storable a)
       => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a
append (XArray a) (XArray b)
  | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
  , Dict <- knownNatFromINat (Proxy @(Rank sh))
  = XArray (S.append a b)

-- | Apply a function to every @sh1@-shaped inner subarray of an array of
-- shape @sh ++ sh1@, collecting the results into an array of shape
-- @sh ++ sh2@. NOTE(review): presumably every result of the function must
-- have the same runtime shape, as orthotope's 'S.rerank' requires -- confirm.
rerank :: forall sh sh1 sh2 a b.
          (Storable a, Storable b)
       => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
       -> (XArray sh1 a -> XArray sh2 b)
       -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b
rerank ssh ssh1 ssh2 f (XArray arr)
  | Dict <- lemKnownINatRankSSX ssh
  , Dict <- knownNatFromINat (Proxy @(Rank sh))
  , Dict <- lemKnownINatRankSSX ssh2
  , Dict <- knownNatFromINat (Proxy @(Rank sh2))
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2)     -- these two should be redundant but the
  , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2)))  -- solver is not clever enough
  = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
                     (\a -> unXArray (f (XArray a)))
                     arr)
  where
    unXArray (XArray a) = a

-- | Like 'rerank', but the function is applied over the outer @sh1@
-- dimensions instead of the inner ones: the array is transposed so @sh1@ is
-- innermost, reranked, and transposed back so @sh2@ is outermost.
rerankTop :: forall sh sh1 sh2 a b.
             (Storable a, Storable b)
          => StaticShapeX sh1 -> StaticShapeX sh2 -> StaticShapeX sh
          -> (XArray sh1 a -> XArray sh2 b)
          -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b
rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh

-- | Binary version of 'rerank': zips the @sh1@-shaped inner subarrays of two
-- arrays with a common outer shape @sh@.
rerank2 :: forall sh sh1 sh2 a b c.
           (Storable a, Storable b, Storable c)
        => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2
        -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c)
        -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c
rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2)
  | Dict <- lemKnownINatRankSSX ssh
  , Dict <- knownNatFromINat (Proxy @(Rank sh))
  , Dict <- lemKnownINatRankSSX ssh2
  , Dict <- knownNatFromINat (Proxy @(Rank sh2))
  , Refl <- lemRankApp ssh ssh1
  , Refl <- lemRankApp ssh ssh2
  , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2)     -- these two should be redundant but the
  , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2)))  -- solver is not clever enough
  = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2))
                      (\a b -> unXArray (f (XArray a) (XArray b)))
                      arr1 arr2)
  where
    unXArray (XArray a) = a

-- | The list argument gives indices into the original dimension list.
--
-- NOTE(review): the result keeps the type @sh@. A permutation that moves
-- dimensions with different static sizes would make the runtime shape
-- disagree with @sh@; nothing here checks that.
transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a
transpose perm (XArray arr)
  | Dict <- lemKnownINatRankSSX (knownShapeX @sh)
  , Dict <- knownNatFromINat (Proxy @(Rank sh))
  = XArray (S.transpose perm arr)

-- | Swap two blocks of dimensions: @sh1 ++ sh2@ becomes @sh2 ++ sh1@.
transpose2 :: forall sh1 sh2 a.
              StaticShapeX sh1 -> StaticShapeX sh2
           -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a
transpose2 ssh1 ssh2 (XArray arr)
  | Refl <- lemRankApp ssh1 ssh2
  , Refl <- lemRankApp ssh2 ssh1
  , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2)
  , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2)))
  , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1)
  , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1)))
  , Refl <- lemRankAppComm ssh1 ssh2
  , let n1 = ssxLength ssh1
  -- New dimension order: the sh2 dimensions (original positions n1..), then
  -- the sh1 dimensions (original positions 0..n1-1).
  = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr)

-- | Sum of all elements of the array (via 'S.sumA').
sumFull :: (Storable a, Num a) => XArray sh a -> a
sumFull (XArray arr) = S.sumA arr

-- | Sum away the inner @sh'@ dimensions, keeping the outer @sh@ ones.
sumInner :: forall sh sh' a. (Storable a, Num a)
         => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a
sumInner ssh ssh'
  | Refl <- lemAppNil @sh
  = rerank ssh ssh' ZSX (scalar . sumFull)

-- | Sum away the outer @sh@ dimensions, keeping the inner @sh'@ ones. The
-- @sh@ dimensions are first moved to the inside with 'transpose2', then
-- removed with 'sumInner'.
sumOuter :: forall sh sh' a. (Storable a, Num a)
         => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a
sumOuter ssh ssh' = sumInner ssh' ssh . transpose2 ssh ssh'

-- | Stack a list of arrays along a new outermost dimension. If that dimension
-- is static (@Just n@), the list must have exactly @n@ elements; otherwise
-- this calls 'error'.
fromList :: forall n sh a. Storable a
         => StaticShapeX (n : sh) -> [XArray sh a] -> XArray (n : sh) a
fromList ssh l
  | Dict <- lemKnownINatRankSSX ssh
  , Dict <- knownNatFromINat (Proxy @(Rank (n : sh)))
  = case ssh of
      m@GHC_SNat :$@ _ | natVal m /= fromIntegral (length l) ->
        error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ") " ++
                "does not match the type (" ++ show (natVal m) ++ ")"
      _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l)))

-- | Split an array along its outermost dimension.
toList :: Storable a => XArray (n : sh) a -> [XArray sh a]
toList (XArray arr) = coerce (ORB.toList (S.unravel arr))

-- | Thin wrapper around orthotope's 'S.slice'; the interval list is passed
-- through unchanged.
--
-- NOTE(review): the result keeps the type @sh@. Slicing a dimension whose
-- size is static in @sh@ would make the runtime shape disagree with the type;
-- nothing here checks that.
slice :: [(Int, Int)] -> XArray sh a -> XArray sh a
slice ivs (XArray arr) = XArray (S.slice ivs arr)