diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Data/Array/Mixed.hs | 416 | ||||
-rw-r--r-- | src/Data/Array/Nested.hs | 128 | ||||
-rw-r--r-- | src/Data/Array/Nested/Convert.hs | 333 | ||||
-rw-r--r-- | src/Data/Array/Nested/Internal.hs | 1294 | ||||
-rw-r--r-- | src/Data/Array/Nested/Lemmas.hs | 162 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed.hs | 936 | ||||
-rw-r--r-- | src/Data/Array/Nested/Mixed/Shape.hs | 644 | ||||
-rw-r--r-- | src/Data/Array/Nested/Permutation.hs | 283 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked.hs | 323 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Base.hs | 268 | ||||
-rw-r--r-- | src/Data/Array/Nested/Ranked/Shape.hs | 369 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped.hs | 272 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Base.hs | 255 | ||||
-rw-r--r-- | src/Data/Array/Nested/Shaped/Shape.hs | 425 | ||||
-rw-r--r-- | src/Data/Array/Nested/Trace.hs | 72 | ||||
-rw-r--r-- | src/Data/Array/Nested/Trace/TH.hs | 98 | ||||
-rw-r--r-- | src/Data/Array/Nested/Types.hs | 152 | ||||
-rw-r--r-- | src/Data/Array/Strided/Orthotope.hs | 43 | ||||
-rw-r--r-- | src/Data/Array/XArray.hs | 348 | ||||
-rw-r--r-- | src/Data/Bag.hs | 18 | ||||
-rw-r--r-- | src/Data/INat.hs | 121 |
21 files changed, 5102 insertions, 1858 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs deleted file mode 100644 index 0351beb..0000000 --- a/src/Data/Array/Mixed.hs +++ /dev/null @@ -1,416 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.Array.Mixed where - -import qualified Data.Array.RankedS as S -import qualified Data.Array.Ranked as ORB -import Data.Coerce -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import Foreign.Storable (Storable) -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import Data.INat - - --- | The 'SNat' pattern synonym is complete, but it doesn't have a --- @COMPLETE@ pragma. This copy of it does. -pattern GHC_SNat :: () => KnownNat n => SNat n -pattern GHC_SNat = SNat -{-# COMPLETE GHC_SNat #-} - -fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat - - --- | Type-level list append. -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerce Refl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerce Refl - -type IxX :: [Maybe Nat] -> Type -> Type -data IxX sh i where - ZIX :: IxX '[] i - (:.@) :: forall n sh i. i -> IxX sh i -> IxX (Just n : sh) i - (:.?) :: forall sh i. i -> IxX sh i -> IxX (Nothing : sh) i -deriving instance Show i => Show (IxX sh i) -deriving instance Eq i => Eq (IxX sh i) -deriving instance Ord i => Ord (IxX sh i) -deriving instance Functor (IxX sh) -deriving instance Foldable (IxX sh) -infixr 3 :.@ -infixr 3 :.? - -type IIxX sh = IxX sh Int - -type ShX :: [Maybe Nat] -> Type -> Type -data ShX sh i where - ZSX :: ShX '[] i - (:$@) :: forall n sh i. SNat n -> ShX sh i -> ShX (Just n : sh) i - (:$?) :: forall sh i. i -> ShX sh i -> ShX (Nothing : sh) i -deriving instance Show i => Show (ShX sh i) -deriving instance Eq i => Eq (ShX sh i) -deriving instance Ord i => Ord (ShX sh i) -deriving instance Functor (ShX sh) -deriving instance Foldable (ShX sh) -infixr 3 :$@ -infixr 3 :$? - -type IShX sh = ShX sh Int - --- | The part of a shape that is statically known. -type StaticShX :: [Maybe Nat] -> Type -data StaticShX sh where - ZKSX :: StaticShX '[] - (:!$@) :: SNat n -> StaticShX sh -> StaticShX (Just n : sh) - (:!$?) :: () -> StaticShX sh -> StaticShX (Nothing : sh) -deriving instance Show (StaticShX sh) -infixr 3 :!$@ -infixr 3 :!$? - --- | Evidence for the static part of a shape. -type KnownShapeX :: [Maybe Nat] -> Constraint -class KnownShapeX sh where - knownShapeX :: StaticShX sh -instance KnownShapeX '[] where - knownShapeX = ZKSX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = natSing :!$@ knownShapeX -instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :!$? knownShapeX - -type family Rank sh where - Rank '[] = Z - Rank (_ : sh) = S (Rank sh) - -type XArray :: [Maybe Nat] -> Type -> Type -newtype XArray sh a = XArray (S.Array (FromINat (Rank sh)) a) - deriving (Show) - -zeroIxX :: StaticShX sh -> IIxX sh -zeroIxX ZKSX = ZIX -zeroIxX (_ :!$@ ssh) = 0 :.@ zeroIxX ssh -zeroIxX (_ :!$? ssh) = 0 :.? zeroIxX ssh - -zeroIxX' :: IShX sh -> IIxX sh -zeroIxX' ZSX = ZIX -zeroIxX' (_ :$@ sh) = 0 :.@ zeroIxX' sh -zeroIxX' (_ :$? sh) = 0 :.? zeroIxX' sh - --- This is a weird operation, so it has a long name -completeShXzeros :: StaticShX sh -> IShX sh -completeShXzeros ZKSX = ZSX -completeShXzeros (n :!$@ ssh) = n :$@ completeShXzeros ssh -completeShXzeros (_ :!$? ssh) = 0 :$? completeShXzeros ssh - -ixAppend :: IIxX sh -> IIxX sh' -> IIxX (sh ++ sh') -ixAppend ZIX idx' = idx' -ixAppend (i :.@ idx) idx' = i :.@ ixAppend idx idx' -ixAppend (i :.? idx) idx' = i :.? ixAppend idx idx' - -shAppend :: IShX sh -> IShX sh' -> IShX (sh ++ sh') -shAppend ZSX sh' = sh' -shAppend (n :$@ sh) sh' = n :$@ shAppend sh sh' -shAppend (n :$? sh) sh' = n :$? shAppend sh sh' - -ixDrop :: IIxX (sh ++ sh') -> IIxX sh -> IIxX sh' -ixDrop sh ZIX = sh -ixDrop (_ :.@ sh) (_ :.@ idx) = ixDrop sh idx -ixDrop (_ :.? sh) (_ :.? idx) = ixDrop sh idx - -ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') -ssxAppend ZKSX sh' = sh' -ssxAppend (n :!$@ sh) sh' = n :!$@ ssxAppend sh sh' -ssxAppend (() :!$? sh) sh' = () :!$? ssxAppend sh sh' - -shapeSize :: IShX sh -> Int -shapeSize ZSX = 1 -shapeSize (n :$@ sh) = fromSNat' n * shapeSize sh -shapeSize (n :$? sh) = n * shapeSize sh - --- | This may fail if @sh@ has @Nothing@s in it. -ssxToShape' :: StaticShX sh -> Maybe (IShX sh) -ssxToShape' ZKSX = Just ZSX -ssxToShape' (n :!$@ sh) = (n :$@) <$> ssxToShape' sh -ssxToShape' (_ :!$? _) = Nothing - -fromLinearIdx :: IShX sh -> Int -> IIxX sh -fromLinearIdx = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IShX sh -> Int -> (IIxX sh, Int) - go ZSX i = (ZIX, i) - go (n :$@ sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` fromSNat' n - in (locali :.@ idx, upi) - go (n :$? sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali :.? idx, upi) - -toLinearIdx :: IShX sh -> IIxX sh -> Int -toLinearIdx = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IShX sh -> IIxX sh -> (Int, Int) - go ZSX ZIX = (0, 1) - go (n :$@ sh) (i :.@ ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, fromSNat' n * sz) - go (n :$? sh) (i :.? ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - -enumShape :: IShX sh -> [IIxX sh] -enumShape = \sh -> go sh id [] - where - go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] - go ZSX f = (f ZIX :) - go (n :$@ sh) f = foldr (.) id [go sh (f . (i :.@)) | i <- [0 .. fromSNat' n - 1]] - go (n :$? sh) f = foldr (.) id [go sh (f . (i :.?)) | i <- [0 .. n-1]] - -shapeLshape :: IShX sh -> S.ShapeL -shapeLshape ZSX = [] -shapeLshape (n :$@ sh) = fromSNat' n : shapeLshape sh -shapeLshape (n :$? sh) = n : shapeLshape sh - -ssxLength :: StaticShX sh -> Int -ssxLength ZKSX = 0 -ssxLength (_ :!$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :!$? ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShX sh -> [Int] -ssxIotaFrom _ ZKSX = [] -ssxIotaFrom i (_ :!$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :!$? ssh) = i : ssxIotaFrom (i+1) ssh - -lemRankApp :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank sh1) + FromINat (Rank sh2) -lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this - -lemRankAppComm :: StaticShX sh1 -> StaticShX sh2 - -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) -lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this - -lemKnownINatRank :: IShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRank ZSX = Dict -lemKnownINatRank (_ :$@ sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownINatRank (_ :$? sh) | Dict <- lemKnownINatRank sh = Dict - -lemKnownINatRankSSX :: StaticShX sh -> Dict KnownINat (Rank sh) -lemKnownINatRankSSX ZKSX = Dict -lemKnownINatRankSSX (_ :!$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict -lemKnownINatRankSSX (_ :!$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict - -lemKnownShapeX :: StaticShX sh -> Dict KnownShapeX sh -lemKnownShapeX ZKSX = Dict -lemKnownShapeX (GHC_SNat :!$@ ssh) | Dict <- lemKnownShapeX ssh = Dict -lemKnownShapeX (() :!$? ssh) | Dict <- lemKnownShapeX ssh = Dict - -lemAppKnownShapeX :: StaticShX sh1 -> StaticShX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX ZKSX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (GHC_SNat :!$@ ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict -lemAppKnownShapeX (() :!$? ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict - -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IShX sh -shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) - where - go :: StaticShX sh' -> [Int] -> IShX sh' - go ZKSX [] = ZSX - go (n :!$@ ssh) (_ : l) = n :$@ go ssh l - go (() :!$? ssh) (n : l) = n :$? go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.fromVector (shapeLshape sh) v) - -toVector :: Storable a => XArray sh a -> VS.Vector a -toVector (XArray arr) = S.toVector arr - -scalar :: Storable a => a -> XArray '[] a -scalar = XArray . S.scalar - -unScalar :: Storable a => XArray '[] a -> a -unScalar (XArray a) = S.unScalar a - -constant :: forall sh a. Storable a => IShX sh -> a -> XArray sh a -constant sh x - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.constant (shapeLshape sh) x) - -generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) - --- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownINatRank sh = --- XArray . S.fromVector (shapeLshape sh) --- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) - -indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a -indexPartial (XArray arr) ZIX = XArray arr -indexPartial (XArray arr) (i :.@ idx) = indexPartial (XArray (S.index arr i)) idx -indexPartial (XArray arr) (i :.? idx) = indexPartial (XArray (S.index arr i)) idx - -index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in S.unScalar arr' - -type family AddMaybe n m where - AddMaybe Nothing _ = Nothing - AddMaybe (Just _) Nothing = Nothing - AddMaybe (Just n) (Just m) = Just (n + m) - -append :: forall n m sh a. (KnownShapeX sh, Storable a) - => XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a -append (XArray a) (XArray b) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.append a b) - -rerank :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) - (\a -> unXArray (f (XArray a))) - arr) - where - unXArray (XArray a) = a - -rerankTop :: forall sh sh1 sh2 a b. - (Storable a, Storable b) - => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b -rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh - -rerank2 :: forall sh sh1 sh2 a b c. - (Storable a, Storable b, Storable c) - => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownINatRankSSX ssh2 - , Dict <- knownNatFromINat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) - (\a b -> unXArray (f (XArray a) (XArray b))) - arr1 arr2) - where - unXArray (XArray a) = a - --- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a -transpose perm (XArray arr) - | Dict <- lemKnownINatRankSSX (knownShapeX @sh) - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.transpose perm arr) - -transpose2 :: forall sh1 sh2 a. - StaticShX sh1 -> StaticShX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1) - , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1))) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (S.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (Storable a, Num a) => XArray sh a -> a -sumFull (XArray arr) = S.sumA arr - -sumInner :: forall sh sh' a. (Storable a, Num a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' - | Refl <- lemAppNil @sh - = rerank ssh ssh' ZKSX (scalar . sumFull) - -sumOuter :: forall sh sh' a. (Storable a, Num a) - => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' - | Refl <- lemAppNil @sh - = sumInner ssh' ssh . transpose2 ssh ssh' - -fromList1 :: forall n sh a. Storable a - => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a -fromList1 ssh l - | Dict <- lemKnownINatRankSSX ssh - , Dict <- knownNatFromINat (Proxy @(Rank (n : sh))) - = case ssh of - m@GHC_SNat :!$@ _ | natVal m /= fromIntegral (length l) -> - error $ "Data.Array.Mixed.fromList: length of list (" ++ show (length l) ++ ")" ++ - "does not match the type (" ++ show (natVal m) ++ ")" - _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (FromINat (Rank sh)) a] l))) - -toList1 :: Storable a => XArray (n : sh) a -> [XArray sh a] -toList1 (XArray arr) = coerce (ORB.toList (S.unravel arr)) - --- | Throws if the given shape is not, in fact, empty. -empty :: forall sh a. Storable a => IShX sh -> XArray sh a -empty sh - | Dict <- lemKnownINatRank sh - , Dict <- knownNatFromINat (Proxy @(Rank sh)) - = XArray (S.constant (shapeLshape sh) - (error "Data.Array.Mixed.empty: shape was not empty")) - -slice :: [(Int, Int)] -> XArray sh a -> XArray sh a -slice ivs (XArray arr) = XArray (S.slice ivs arr) - -rev1 :: XArray (n : sh) a -> XArray (n : sh) a -rev1 (XArray arr) = XArray (S.rev [0] arr) diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs index ec5f0b5..c3635e9 100644 --- a/src/Data/Array/Nested.hs +++ b/src/Data/Array/Nested.hs @@ -2,52 +2,126 @@ {-# LANGUAGE PatternSynonyms #-} module Data.Array.Nested ( -- * Ranked arrays - Ranked, - ListR(ZR, (:::)), knownListR, - IxR(.., ZIR, (:.:)), IIxR, knownIxR, - ShR(.., ZSR, (:$:)), knownShR, - rshape, rindex, rindexPartial, rgenerate, rsumOuter1, - rtranspose, rappend, rscalar, rfromVector, rtoVector, runScalar, - rconstant, rfromList, rfromList1, rtoList, rtoList1, - rslice, rrev1, + Ranked(Ranked), + ListR(ZR, (:::)), + IxR(.., ZIR, (:.:)), IIxR, + ShR(.., ZSR, (:$:)), IShR, + rshape, rrank, rsize, rindex, rindexPartial, rgenerate, rsumOuter1, rsumAllPrim, + rtranspose, rappend, rconcat, rscalar, rfromVector, rtoVector, runScalar, + remptyArray, + rrerank, + rreplicate, rreplicateScal, + rfromList1, rfromListOuter, rfromListLinear, rfromListPrim, rfromListPrimLinear, + rtoList, rtoListOuter, rtoListLinear, + rslice, rrev1, rreshape, rflatten, riota, + rminIndexPrim, rmaxIndexPrim, rdot1Inner, rdot, + rnest, runNest, rzip, runzip, -- ** Lifting orthotope operations to 'Ranked' arrays - rlift, + rlift, rlift2, + -- ** Conversions + rtoXArrayPrim, rfromXArrayPrim, + rtoMixed, rcastToMixed, rcastToShaped, + rfromOrthotope, rtoOrthotope, + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + rquotArray, rremArray, ratan2Array, -- * Shaped arrays - Shaped, + Shaped(Shaped), ListS(ZS, (::$)), IxS(.., ZIS, (:.$)), IIxS, - ShS(..), KnownShape(..), - sshape, sindex, sindexPartial, sgenerate, ssumOuter1, + ShS(.., ZSS, (:$$)), KnownShS(..), + sshape, srank, ssize, sindex, sindexPartial, sgenerate, ssumOuter1, ssumAllPrim, stranspose, sappend, sscalar, sfromVector, stoVector, sunScalar, - sconstant, sfromList, sfromList1, stoList, stoList1, - sslice, srev1, + -- TODO: sconcat? What should its type be? + semptyArray, + srerank, + sreplicate, sreplicateScal, + sfromList1, sfromListOuter, sfromListLinear, sfromListPrim, sfromListPrimLinear, + stoList, stoListOuter, stoListLinear, + sslice, srev1, sreshape, sflatten, siota, + sminIndexPrim, smaxIndexPrim, sdot1Inner, sdot, + snest, sunNest, szip, sunzip, -- ** Lifting orthotope operations to 'Shaped' arrays - slift, + slift, slift2, + -- ** Conversions + stoXArrayPrim, sfromXArrayPrim, + stoMixed, scastToMixed, stoRanked, + sfromOrthotope, stoOrthotope, + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + squotArray, sremArray, satan2Array, -- * Mixed arrays Mixed, - IxX(..), IIxX, - KnownShapeX(..), StaticShX(..), - mgenerate, mtranspose, mappend, mfromVector, mtoVector, munScalar, - mconstant, mfromList, mtoList, mslice, mrev1, + ListX(ZX, (::%)), + IxX(.., ZIX, (:.%)), IIxX, + ShX(.., ZSX, (:$%)), KnownShX(..), IShX, + StaticShX(.., ZKX, (:!%)), + SMayNat(..), + mshape, mrank, msize, mindex, mindexPartial, mgenerate, msumOuter1, msumAllPrim, + mtranspose, mappend, mconcat, mscalar, mfromVector, mtoVector, munScalar, + memptyArray, + mrerank, + mreplicate, mreplicateScal, + mfromList1, mfromListOuter, mfromListLinear, mfromListPrim, mfromListPrimLinear, + mtoList, mtoListOuter, mtoListLinear, + mslice, mrev1, mreshape, mflatten, miota, + mminIndexPrim, mmaxIndexPrim, mdot1Inner, mdot, + mnest, munNest, mzip, munzip, + -- ** Lifting orthotope operations to 'Mixed' arrays + mlift, mlift2, + -- ** Conversions + mtoXArrayPrim, mfromXArrayPrim, + mcast, + mcastToShaped, mtoRanked, + convert, Conversion(..), + -- ** Additional arithmetic operations + -- + -- $integralRealFloat + mquotArray, mremArray, matan2Array, -- * Array elements - Elt(mshape, mindex, mindexPartial, mscalar, mfromList1, mtoList1, mlift, mlift2), + Elt, PrimElt, Primitive(..), - - -- * Inductive natural numbers - module Data.INat, + KnownElt, -- * Further utilities / re-exports type (++), Storable, + SNat, pattern SNat, + pattern SZ, pattern SS, + Perm(..), + IsPermutation, + KnownPerm(..), + NumElt, IntElt, FloatElt, + Rank, Product, + Replicate, + MapJust, ) where -import Prelude hiding (mappend) +import Prelude hiding (mappend, mconcat) -import Data.Array.Mixed -import Data.Array.Nested.Internal -import Data.INat +import Data.Array.Nested.Convert +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith import Foreign.Storable +import GHC.TypeLits + +-- $integralRealFloat +-- +-- These functions are separate top-level functions, and not exposed in +-- instances for 'RealFloat' and 'Integral', because those classes include a +-- variety of other functions that make no sense for arrays. +-- This problem already occurs with 'fromInteger', 'fromRational' and 'pi', but +-- having 'Num', 'Fractional' and 'Floating' available is just too useful. diff --git a/src/Data/Array/Nested/Convert.hs b/src/Data/Array/Nested/Convert.hs new file mode 100644 index 0000000..2438f68 --- /dev/null +++ b/src/Data/Array/Nested/Convert.hs @@ -0,0 +1,333 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +module Data.Array.Nested.Convert ( + -- * Shape\/index\/list casting functions + -- ** To ranked + ixrFromIxS, ixrFromIxX, shrFromShS, shrFromShX, shrFromShX2, + listrCast, ixrCast, shrCast, + -- ** To shaped + ixsFromIxR, ixsFromIxR', ixsFromIxX, ixsFromIxX', withShsFromShR, shsFromShX, withShsFromShX, shsFromSSX, + ixsCast, + -- ** To mixed + ixxFromIxR, ixxFromIxS, shxFromShR, shxFromShS, + ixxCast, shxCast, shxCast', + + -- * Array conversions + convert, + Conversion(..), + + -- * Special cases of array conversions + -- + -- | These functions can all be implemented using 'convert' in some way, + -- but some have fewer constraints. + rtoMixed, rcastToMixed, rcastToShaped, + stoMixed, scastToMixed, stoRanked, + mcast, mcastToShaped, mtoRanked, +) where + +import Control.Category +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Shaped.Base +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + +-- * Shape or index or list casting functions + +-- * To ranked + +ixrFromIxS :: IxS sh i -> IxR (Rank sh) i +ixrFromIxS ZIS = ZIR +ixrFromIxS (i :.$ ix) = i :.: ixrFromIxS ix + +ixrFromIxX :: IxX sh i -> IxR (Rank sh) i +ixrFromIxX ZIX = ZIR +ixrFromIxX (n :.% idx) = n :.: ixrFromIxX idx + +shrFromShS :: ShS sh -> IShR (Rank sh) +shrFromShS ZSS = ZSR +shrFromShS (n :$$ sh) = fromSNat' n :$: shrFromShS sh + +-- shrFromShX re-exported +-- shrFromShX2 re-exported +-- listrCast re-exported +-- ixrCast re-exported +-- shrCast re-exported + +-- * To shaped + +-- TODO: these take a ShS because there are KnownNats inside IxS. + +ixsFromIxR :: ShS sh -> IxR (Rank sh) i -> IxS sh i +ixsFromIxR ZSS ZIR = ZIS +ixsFromIxR (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR sh idx +ixsFromIxR _ _ = error "unreachable" + +-- | Performs a runtime check that @n@ matches @Rank sh@. Equivalent to the +-- following, but more efficient: +-- +-- > ixsFromIxR' sh idx = ixsFromIxR sh (ixrCast (shsRank sh) idx) +ixsFromIxR' :: ShS sh -> IxR n i -> IxS sh i +ixsFromIxR' ZSS ZIR = ZIS +ixsFromIxR' (_ :$$ sh) (n :.: idx) = n :.$ ixsFromIxR' sh idx +ixsFromIxR' _ _ = error "ixsFromIxR': index rank does not match shape rank" + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsFromIxX :: ShS sh -> IxX (MapJust sh) i -> IxS sh i +ixsFromIxX ZSS ZIX = ZIS +ixsFromIxX (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX sh idx + +-- | Performs a runtime check that @Rank sh'@ match @Rank sh@. Equivalent to +-- the following, but more efficient: +-- +-- > ixsFromIxX' sh idx = ixsFromIxX sh (ixxCast (shxFromShS sh) idx) +ixsFromIxX' :: ShS sh -> IxX sh' i -> IxS sh i +ixsFromIxX' ZSS ZIX = ZIS +ixsFromIxX' (_ :$$ sh) (n :.% idx) = n :.$ ixsFromIxX' sh idx +ixsFromIxX' _ _ = error "ixsFromIxX': index rank does not match shape rank" + +-- | Produce an existential 'ShS' from an 'IShR'. +withShsFromShR :: IShR n -> (forall sh. Rank sh ~ n => ShS sh -> r) -> r +withShsFromShR ZSR k = k ZSS +withShsFromShR (n :$: sh) k = + withShsFromShR sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShR: negative dimension size (" ++ show n ++ ")" + +-- shsFromShX re-exported + +-- | Produce an existential 'ShS' from an 'IShX'. If you already know that +-- @sh'@ is @MapJust@ of something, use 'shsFromShX' instead. +withShsFromShX :: IShX sh' -> (forall sh. Rank sh ~ Rank sh' => ShS sh -> r) -> r +withShsFromShX ZSX k = k ZSS +withShsFromShX (SKnown sn@SNat :$% sh) k = + withShsFromShX sh $ \sh' -> + k (sn :$$ sh') +withShsFromShX (SUnknown n :$% sh) k = + withShsFromShX sh $ \sh' -> + withSomeSNat (fromIntegral @Int @Integer n) $ \case + Just sn@SNat -> k (sn :$$ sh') + Nothing -> error $ "withShsFromShX: negative SUnknown dimension size (" ++ show n ++ ")" + +shsFromSSX :: StaticShX (MapJust sh) -> ShS sh +shsFromSSX = shsFromShX Prelude.. shxFromSSX + +-- ixsCast re-exported + +-- * To mixed + +ixxFromIxR :: IxR n i -> IxX (Replicate n Nothing) i +ixxFromIxR ZIR = ZIX +ixxFromIxR (n :.: (idx :: IxR m i)) = + castWith (subst2 @IxX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (n :.% ixxFromIxR idx) + +ixxFromIxS :: IxS sh i -> IxX (MapJust sh) i +ixxFromIxS ZIS = ZIX +ixxFromIxS (n :.$ sh) = n :.% ixxFromIxS sh + +shxFromShR :: ShR n i -> ShX (Replicate n Nothing) i +shxFromShR ZSR = ZSX +shxFromShR (n :$: (idx :: ShR m i)) = + castWith (subst2 @ShX @i (lemReplicateSucc @(Nothing @Nat) @m)) + (SUnknown n :$% shxFromShR idx) + +shxFromShS :: ShS sh -> IShX (MapJust sh) +shxFromShS ZSS = ZSX +shxFromShS (n :$$ sh) = SKnown n :$% shxFromShS sh + +-- ixxCast re-exported +-- shxCast re-exported +-- shxCast' re-exported + + +-- * Array conversions + +-- | The constructors that perform runtime shape checking are marked with a +-- tick (@'@): 'ConvXS'' and 'ConvXX''. For the other constructors, the types +-- ensure that the shapes are already compatible. To convert between 'Ranked' +-- and 'Shaped', go via 'Mixed'. +-- +-- The guiding principle behind 'Conversion' is that it should represent the +-- array restructurings, or perhaps re-presentations, that do not change the +-- underlying 'XArray's. This leads to the inclusion of some operations that do +-- not look like simple conversions (casts) at first glance, like 'ConvZip'. +-- +-- /Note/: Haddock gleefully renames type variables in constructors so that +-- they match the data type head as much as possible. See the source for a more +-- readable presentation of this data type. +data Conversion a b where + ConvId :: Conversion a a + ConvCmp :: Conversion b c -> Conversion a b -> Conversion a c + + ConvRX :: Conversion (Ranked n a) (Mixed (Replicate n Nothing) a) + ConvSX :: Conversion (Shaped sh a) (Mixed (MapJust sh) a) + + ConvXR :: Elt a + => Conversion (Mixed sh a) (Ranked (Rank sh) a) + ConvXS :: Conversion (Mixed (MapJust sh) a) (Shaped sh a) + ConvXS' :: (Rank sh ~ Rank sh', Elt a) + => ShS sh' + -> Conversion (Mixed sh a) (Shaped sh' a) + + ConvXX' :: (Rank sh ~ Rank sh', Elt a) + => StaticShX sh' + -> Conversion (Mixed sh a) (Mixed sh' a) + + ConvRR :: Conversion a b + -> Conversion (Ranked n a) (Ranked n b) + ConvSS :: Conversion a b + -> Conversion (Shaped sh a) (Shaped sh b) + ConvXX :: Conversion a b + -> Conversion (Mixed sh a) (Mixed sh b) + ConvT2 :: Conversion a a' + -> Conversion b b' + -> Conversion (a, b) (a', b') + + Conv0X :: Elt a + => Conversion a (Mixed '[] a) + ConvX0 :: Conversion (Mixed '[] a) a + + ConvNest :: Elt a => StaticShX sh + -> Conversion (Mixed (sh ++ sh') a) (Mixed sh (Mixed sh' a)) + ConvUnnest :: Conversion (Mixed sh (Mixed sh' a)) (Mixed (sh ++ sh') a) + + ConvZip :: (Elt a, Elt b) + => Conversion (Mixed sh a, Mixed sh b) (Mixed sh (a, b)) + ConvUnzip :: (Elt a, Elt b) + => Conversion (Mixed sh (a, b)) (Mixed sh a, Mixed sh b) +deriving instance Show (Conversion a b) + +instance Category Conversion where + id = ConvId + (.) = ConvCmp + +convert :: (Elt a, Elt b) => Conversion a b -> a -> b +convert = \c x -> munScalar (go c (mscalar x)) + where + -- The 'esh' is the extension shape: the conversion happens under a whole + -- bunch of additional dimensions that it does not touch. These dimensions + -- are 'esh'. + -- The strategy is to unwind step-by-step to a large Mixed array, and to + -- perform the required checks and conversions when re-nesting back up. + go :: Conversion a b -> Mixed esh a -> Mixed esh b + go ConvId x = x + go (ConvCmp c1 c2) x = go c1 (go c2 x) + go ConvRX (M_Ranked x) = x + go ConvSX (M_Shaped x) = x + go (ConvXR @_ @sh) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqRepNo (Proxy @esh) (Proxy @sh) + = let ssx' = ssxAppend (ssxFromShX esh) + (ssxReplicate (shxRank (shxDropSSX @esh @sh (ssxFromShX esh) (mshape x)))) + in M_Ranked (M_Nest esh (mcast ssx' x)) + go ConvXS (M_Nest esh x) = M_Shaped (M_Nest esh x) + go (ConvXS' @sh @sh' sh') (M_Nest @esh esh x) + | Refl <- lemRankAppRankEqMapJust (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Shaped (M_Nest esh (mcast (ssxFromShX (shxAppend esh (shxFromShS sh'))) + x)) + go (ConvXX' @sh @sh' ssx) (M_Nest @esh esh x) + | Refl <- lemRankAppRankEq (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh $ mcast (ssxFromShX esh `ssxAppend` ssx) x + go (ConvRR c) (M_Ranked (M_Nest esh x)) = M_Ranked (M_Nest esh (go c x)) + go (ConvSS c) (M_Shaped (M_Nest esh x)) = M_Shaped (M_Nest esh (go c x)) + go (ConvXX c) (M_Nest esh x) = M_Nest esh (go c x) + go (ConvT2 c1 c2) (M_Tup2 x1 x2) = M_Tup2 (go c1 x1) (go c2 x2) + go Conv0X (x :: Mixed esh a) + | Refl <- lemAppNil @esh + = M_Nest (mshape x) x + go ConvX0 (M_Nest @esh _ x) + | Refl <- lemAppNil @esh + = x + go (ConvNest @_ @sh @sh' ssh) (M_Nest @esh esh x) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh (M_Nest (shxTakeSSX (Proxy @sh') (ssxFromShX esh `ssxAppend` ssh) (mshape x)) x) + go (ConvUnnest @sh @sh') (M_Nest @esh esh (M_Nest _ x)) + | Refl <- lemAppAssoc (Proxy @esh) (Proxy @sh) (Proxy @sh') + = M_Nest esh x + go ConvZip x = + -- no need to check that the two esh's are equal because they were zipped previously + let (M_Nest esh x1, M_Nest _ x2) = munzip x + in M_Nest esh (mzip x1 x2) + go ConvUnzip (M_Nest esh x) = + let (x1, x2) = munzip x + in mzip (M_Nest esh x1) (M_Nest esh x2) + + lemRankAppRankEq :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ sh') + lemRankAppRankEq _ _ _ = unsafeCoerceRefl + + lemRankAppRankEqRepNo :: Proxy esh -> Proxy sh + -> Rank (esh ++ sh) :~: Rank (esh ++ Replicate (Rank sh) Nothing) + lemRankAppRankEqRepNo _ _ = unsafeCoerceRefl + + lemRankAppRankEqMapJust :: Rank sh ~ Rank sh' + => Proxy esh -> Proxy sh -> Proxy sh' + -> Rank (esh ++ sh) :~: Rank (esh ++ MapJust sh') + lemRankAppRankEqMapJust _ _ _ = unsafeCoerceRefl + + +-- * Special cases of array conversions + +mcast :: forall sh1 sh2 a. (Rank sh1 ~ Rank sh2, Elt a) + => StaticShX sh2 -> Mixed sh1 a -> Mixed sh2 a +mcast ssh2 arr + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = mcastPartial (ssxFromShX (mshape arr)) ssh2 (Proxy @'[]) arr + +mtoRanked :: forall sh a. Elt a => Mixed sh a -> Ranked (Rank sh) a +mtoRanked = convert ConvXR + +rtoMixed :: forall n a. Ranked n a -> Mixed (Replicate n Nothing) a +rtoMixed (Ranked arr) = arr + +-- | A more weakly-typed version of 'rtoMixed' that does a runtime shape +-- compatibility check. +rcastToMixed :: (Rank sh ~ n, Elt a) => StaticShX sh -> Ranked n a -> Mixed sh a +rcastToMixed sshx rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank rarr) + = mcast sshx arr + +mcastToShaped :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => ShS sh' -> Mixed sh a -> Shaped sh' a +mcastToShaped targetsh = convert (ConvXS' targetsh) + +stoMixed :: forall sh a. Shaped sh a -> Mixed (MapJust sh) a +stoMixed (Shaped arr) = arr + +-- | A more weakly-typed version of 'stoMixed' that does a runtime shape +-- compatibility check. +scastToMixed :: forall sh sh' a. (Elt a, Rank sh ~ Rank sh') + => StaticShX sh' -> Shaped sh a -> Mixed sh' a +scastToMixed sshx sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mcast sshx arr + +stoRanked :: Elt a => Shaped sh a -> Ranked (Rank sh) a +stoRanked sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + = mtoRanked arr + +rcastToShaped :: Elt a => Ranked (Rank sh) a -> ShS sh -> Shaped sh a +rcastToShaped (Ranked arr) targetsh + | Refl <- lemRankReplicate (shxRank (shxFromShS targetsh)) + , Refl <- lemRankMapJust targetsh + = mcastToShaped targetsh arr diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs deleted file mode 100644 index 350eb6f..0000000 --- a/src/Data/Array/Nested/Internal.hs +++ /dev/null @@ -1,1294 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} - -{-| -TODO: -* We should be more consistent in whether functions take a 'StaticShX' - argument or a 'KnownShapeX' constraint. - -* Document the choice of using 'INat' for ranks and 'Nat' for shapes. Point - being that we need to do induction over the former, but the latter need to be - able to get large. - --} - -module Data.Array.Nested.Internal where - -import Prelude hiding (mappend) - -import Control.Monad (forM_, when) -import Control.Monad.ST -import qualified Data.Array.RankedS as S -import Data.Bifunctor (first) -import Data.Coerce (coerce, Coercible) -import Data.Foldable (toList) -import Data.Kind -import Data.List.NonEmpty (NonEmpty) -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VSM -import Foreign.Storable (Storable) -import GHC.TypeLits - -import Data.Array.Mixed (XArray, IxX(..), IIxX, ShX(..), IShX, KnownShapeX(..), StaticShX(..), type (++), pattern GHC_SNat) -import qualified Data.Array.Mixed as X -import Data.INat - - --- Invariant in the API --- ==================== --- --- In the underlying XArray, there is some shape for elements of an empty --- array. For example, for this array: --- --- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) --- rshape arr == 0 :.: 0 :.: 0 :.: ZIR --- --- the two underlying XArrays have a shape, and those shapes might be anything. --- The invariant is that these element shapes are unobservable in the API. --- (This is possible because you ought to not be able to get to such an element --- without indexing out of bounds.) --- --- Note, though, that the converse situation may arise: the outer array might --- be nonempty but then the inner arrays might. This is fine, an invariant only --- applies if the _outer_ array is empty. --- --- TODO: can we enforce that the elements of an empty (nested) array have --- all-zero shape? --- -> no, because mlift and also any kind of internals probing from outsiders - - --- Primitive element types --- ======================= --- --- There are a few primitive element types; arrays containing elements of such --- type are a newtype over an XArray, which it itself a newtype over a Vector. --- Unfortunately, the setup of the library requires us to list these primitive --- element types multiple times; to aid in extending the list, all these lists --- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. - - -type family Replicate n a where - Replicate Z a = '[] - Replicate (S n) a = a : Replicate n a - -type family MapJust l where - MapJust '[] = '[] - MapJust (x : xs) = Just x : MapJust xs - -lemKnownReplicate :: forall n. KnownINat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (inatSing @n)) - where - go :: SINat m -> StaticShX (Replicate m Nothing) - go SZ = ZKSX - go (SS n) = () :!$? go n - -lemRankReplicate :: forall n. KnownINat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (inatSing @n) - where - go :: SINat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m - go SZ = Refl - go (SS n) | Refl <- go n = Refl - -lemReplicatePlusApp :: forall n m a. KnownINat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n +! m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (inatSing @n) - where - go :: SINat n' -> Replicate (n' +! m) a :~: Replicate n' a ++ Replicate m a - go SZ = Refl - go (SS n) | Refl <- go n = Refl - -shAppSplit :: Proxy sh' -> StaticShX sh -> IShX (sh ++ sh') -> (IShX sh, IShX sh') -shAppSplit _ ZKSX idx = (ZSX, idx) -shAppSplit p (_ :!$@ ssh) (i :$@ idx) = first (i :$@) (shAppSplit p ssh idx) -shAppSplit p (_ :!$? ssh) (i :$? idx) = first (i :$?) (shAppSplit p ssh idx) - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - --- | Element types that are primitive; arrays of these types are just a newtype --- wrapper over an array. -class PrimElt a where - fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a - toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) - - default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a - fromPrimitive = coerce - - default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) - toPrimitive = coerce - --- [PRIMITIVE ELEMENT TYPES LIST] -instance PrimElt Int -instance PrimElt Double -instance PrimElt () - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a data family so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'mtranspose') are typically free. --- --- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type --- class. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a --- NOTE: When opening up the Mixed abstraction, you might see dimension sizes --- that you're not supposed to see. In particular, you might see (nonempty) --- sizes of the elements of an empty array, which is information that should --- ostensibly not exist; the full array is still empty. - -newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) - deriving (Show) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance Mixed sh Int = M_Int (XArray sh Int) - deriving (Show) -newtype instance Mixed sh Double = M_Double (XArray sh Double) - deriving (Show) -newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) - deriving (Show) --- etc. - -data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) -deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) --- etc. - -newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) -deriving instance Show (Mixed (sh1 ++ sh2) a) => Show (Mixed sh1 (Mixed sh2 a)) - - --- | Internal helper data family mirroring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) - --- [PRIMITIVE ELEMENT TYPES LIST] -newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) -newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) -newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) - - --- | Tree giving the shape of every array component. -type family ShapeTree a where - ShapeTree (Primitive _) = () - -- [PRIMITIVE ELEMENT TYPES LIST] - ShapeTree Int = () - ShapeTree Double = () - ShapeTree () = () - - ShapeTree (a, b) = (ShapeTree a, ShapeTree b) - ShapeTree (Mixed sh a) = (IShX sh, ShapeTree a) - ShapeTree (Ranked n a) = (IShR n, ShapeTree a) - ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) - - --- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class Elt a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: KnownShapeX sh => Mixed sh a -> IShX sh - mindex :: Mixed sh a -> IIxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a - mscalar :: a -> Mixed '[] a - - -- | All arrays in the list, even subarrays inside @a@, must have the same - -- shape; if they do not, a runtime error will be thrown. See the - -- documentation of 'mgenerate' for more information about this restriction. - -- Furthermore, the length of the list must correspond with @n@: if @n@ is - -- @Just m@ and @m@ does not equal the length of the list, a runtime error is - -- thrown. - -- - -- If you want a single-dimensional array from your list, map 'mscalar' - -- first. - mfromList1 :: forall n sh. KnownShapeX (n : sh) => NonEmpty (Mixed sh a) -> Mixed (n : sh) a - - mtoList1 :: Mixed (n : sh) a -> [Mixed sh a] - - -- | Note: this library makes no particular guarantees about the shapes of - -- arrays "inside" an empty array. With 'mlift' and 'mlift2' you can see the - -- full 'XArray' and as such you can distinguish different empty arrays by - -- the "shapes" of their elements. This information is meaningless, so you - -- should not use it. - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- | See the documentation for 'mlift'. - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a - - -- ====== PRIVATE METHODS ====== -- - - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IShX sh -> Mixed sh a - - mshapeTree :: a -> ShapeTree a - - mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool - - mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool - - mshowShapeTree :: Proxy a -> ShapeTree a -> String - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. - mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) - - mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance Storable a => Elt (Primitive a) where - mshape (M_Primitive a) = X.shape a - mindex (M_Primitive a) i = Primitive (X.index a i) - mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) - mscalar (Primitive x) = M_Primitive (X.scalar x) - mfromList1 l = M_Primitive (X.fromList1 knownShapeX (coerce (toList l))) - mtoList1 (M_Primitive arr) = coerce (X.toList1 arr) - - mlift :: forall sh1 sh2. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift f (M_Primitive a) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - = M_Primitive (f Proxy a) - - mlift2 :: forall sh1 sh2 sh3. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) - mlift2 f (M_Primitive a) (M_Primitive b) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - , Refl <- X.lemAppNil @sh3 - = M_Primitive (f Proxy a b) - - memptyArray sh = M_Primitive (X.empty sh) - mshapeTree _ = () - mshapeTreeEq _ () () = True - mshapeTreeEmpty _ () = False - mshowShapeTree _ () = "()" - mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (X.shapeSize sh) - mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (X.toLinearIdx sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. KnownShapeX sh' - => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do - let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIxX' (X.shape arr))) - VS.copy (VSM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VS.freeze v - --- [PRIMITIVE ELEMENT TYPES LIST] -deriving via Primitive Int instance Elt Int -deriving via Primitive Double instance Elt Double -deriving via Primitive () instance Elt () - --- Arrays of pairs are pairs of arrays. -instance (Elt a, Elt b) => Elt (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) - mfromList1 l = M_Tup2 (mfromList1 ((\(M_Tup2 x _) -> x) <$> l)) - (mfromList1 ((\(M_Tup2 _ y) -> y) <$> l)) - mtoList1 (M_Tup2 a b) = zipWith M_Tup2 (mtoList1 a) (mtoList1 b) - mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) - mlift2 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 f a x) (mlift2 f b y) - - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mshapeTree (x, y) = (mshapeTree x, mshapeTree y) - mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' - mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 - mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - --- Arrays of arrays are just arrays, but with more dimensions. -instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where - -- TODO: this is quadratic in the nesting depth because it repeatedly - -- truncates the shape vector to one a little shorter. Fix with a - -- moverlongShape method, a prefix of which is mshape. - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IShX sh - mshape (M_Nest arr) - | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = fst (shAppSplit (Proxy @sh') (knownShapeX @sh) (mshape arr)) - - mindex (M_Nest arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest arr) i - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mscalar = M_Nest - - mfromList1 :: forall n sh. KnownShapeX (n : sh) - => NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (n : sh) (Mixed sh' a) - mfromList1 l - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @(n : sh)) (knownShapeX @sh')) - = M_Nest (mfromList1 (coerce l)) - - mtoList1 (M_Nest arr) = coerce (mtoList1 arr) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift f (M_Nest arr) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - = M_Nest (mlift f' arr) - where - f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b - f' _ - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) - = f (Proxy @(sh' ++ shT)) - - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) - mlift2 f (M_Nest arr1) (M_Nest arr2) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh3) (knownShapeX @sh')) - = M_Nest (mlift2 f' arr1 arr2) - where - f' :: forall shT b. (KnownShapeX shT, Storable b) => Proxy shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b - f' _ - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) - , Refl <- X.lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @shT)) - = f (Proxy @(sh' ++ shT)) - - memptyArray sh = M_Nest (memptyArray (X.shAppend sh (X.completeShXzeros (knownShapeX @sh')))) - - mshapeTree arr = (mshape arr, mshapeTree (mindex arr (X.zeroIxX (knownShapeX @sh')))) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = X.shapeSize sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.shAppend sh (mshape example)) - (mindex example (X.zeroIxX (knownShapeX @sh'))) - where - sh' = mshape example - - mvecsNewEmpty _ = MV_Nest (X.completeShXzeros (knownShapeX @sh')) <$> mvecsNewEmpty (Proxy @a) - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.shAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.shAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.shAppend sh sh') vecs - - --- | Create an array given a size and a function that computes the element at a --- given index. --- --- __WARNING__: It is required that every @a@ returned by the argument to --- 'mgenerate' has the same shape. For example, the following will throw a --- runtime error: --- --- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) --- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> --- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> --- > ... --- --- because the size of the inner 'mgenerate' is not always the same (it depends --- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so --- the entire hierarchy (after distributing out tuples) must be a rectangular --- array. The type of 'mgenerate' allows this requirement to be broken very --- easily, hence the runtime check. -mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IShX sh -> (IIxX sh -> a) -> Mixed sh a -mgenerate sh f = case X.enumShape sh of - [] -> memptyArray sh - firstidx : restidxs -> - let firstelem = f (X.zeroIxX' sh) - shapetree = mshapeTree firstelem - in if mshapeTreeEmpty (Proxy @a) shapetree - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this array copying inefficient. Should improve this. - forM_ restidxs $ \idx -> do - let val = f idx - when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ - error "Data.Array.Nested mgenerate: generated values do not have equal shapes" - mvecsWrite sh idx val vecs - mvecsFreeze sh vecs - -mtranspose :: forall sh a. (KnownShapeX sh, Elt a) => [Int] -> Mixed sh a -> Mixed sh a -mtranspose perm = - mlift (\(Proxy @sh') -> X.rerankTop (knownShapeX @sh) (knownShapeX @sh) (knownShapeX @sh') - (X.transpose perm)) - -mappend :: forall n m sh a. (KnownShapeX sh, KnownShapeX (n : sh), KnownShapeX (m : sh), KnownShapeX (X.AddMaybe n m : sh), Elt a) - => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (X.AddMaybe n m : sh) a -mappend = mlift2 go - where go :: forall sh' b. (KnownShapeX sh', Storable b) - => Proxy sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (X.AddMaybe n m : sh ++ sh') b - go Proxy | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') = X.append - -mfromVectorP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) -mfromVectorP sh v = M_Primitive (X.fromVector sh v) - -mfromVector :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) => IShX sh -> VS.Vector a -> Mixed sh a -mfromVector sh v = fromPrimitive (mfromVectorP sh v) - -mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a -mtoVectorP (M_Primitive v) = X.toVector v - -mtoVector :: (Storable a, PrimElt a) => Mixed sh a -> VS.Vector a -mtoVector arr = mtoVectorP (coerce toPrimitive arr) - -mfromList :: (KnownShapeX '[n], Elt a) => NonEmpty a -> Mixed '[n] a -mfromList = mfromList1 . fmap mscalar - -mtoList :: Elt a => Mixed '[n] a -> [a] -mtoList = map munScalar . mtoList1 - -munScalar :: Elt a => Mixed '[] a -> a -munScalar arr = mindex arr ZIX - -mconstantP :: forall sh a. (KnownShapeX sh, Storable a) => IShX sh -> a -> Mixed sh (Primitive a) -mconstantP sh x = M_Primitive (X.constant sh x) - -mconstant :: forall sh a. (KnownShapeX sh, Storable a, PrimElt a) - => IShX sh -> a -> Mixed sh a -mconstant sh x = fromPrimitive (mconstantP sh x) - -mslice :: (KnownShapeX sh, Elt a) => [(Int, Int)] -> Mixed sh a -> Mixed sh a -mslice ivs = mlift $ \_ -> X.slice ivs - -mrev1 :: (KnownShapeX (n : sh), Elt a) => Mixed (n : sh) a -> Mixed (n : sh) a -mrev1 = mlift $ \_ -> X.rev1 - -mliftPrim :: (KnownShapeX sh, Storable a) - => (a -> a) - -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -mliftPrim f (M_Primitive (X.XArray arr)) = M_Primitive (X.XArray (S.mapA f arr)) - -mliftPrim2 :: (KnownShapeX sh, Storable a) - => (a -> a -> a) - -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -> Mixed sh (Primitive a) -mliftPrim2 f (M_Primitive (X.XArray arr1)) (M_Primitive (X.XArray arr2)) = - M_Primitive (X.XArray (S.zipWithA f arr1 arr2)) - -instance (KnownShapeX sh, Storable a, Num a) => Num (Mixed sh (Primitive a)) where - (+) = mliftPrim2 (+) - (-) = mliftPrim2 (-) - (*) = mliftPrim2 (*) - negate = mliftPrim negate - abs = mliftPrim abs - signum = mliftPrim signum - fromInteger n = - case X.ssxToShape' (knownShapeX @sh) of - Just sh -> M_Primitive (X.constant sh (fromInteger n)) - Nothing -> error "Data.Array.Nested.fromIntegral: \ - \Unknown components in shape, use explicit mconstant" - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Mixed sh (Primitive Int) instance KnownShapeX sh => Num (Mixed sh Int) -deriving via Mixed sh (Primitive Double) instance KnownShapeX sh => Num (Mixed sh Double) - - --- | A rank-typed array: the number of dimensions of the array (its /rank/) is --- represented on the type level as a 'INat'. --- --- Valid elements of a ranked arrays are described by the 'Elt' type class. --- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are --- supported (and are represented as a single, flattened, struct-of-arrays --- array internally). --- --- Note that this 'INat' is not a "GHC.TypeLits" natural, because we want a --- type-level natural that supports induction. --- --- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. -type Ranked :: INat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) -deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) - --- | A shape-typed array: the full shape of the array (the sizes of its --- dimensions) is represented on the type level as a list of 'Nat's. Note that --- these are "GHC.TypeLits" naturals, because we do not need induction over --- them and we want very large arrays to be possible. --- --- Like for 'Ranked', the valid elements are described by the 'Elt' type class, --- and 'Shaped' itself is again an instance of 'Elt' as well. --- --- 'Shaped' is a newtype around a 'Mixed' of 'Just's. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) -deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) -deriving instance Show (Mixed sh (Mixed (MapJust sh' ) a)) => Show (Mixed sh (Shaped sh' a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) - - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance (Elt a, KnownINat n) => Elt (Ranked n a) where - mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr - mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mscalar (Ranked x) = M_Ranked (M_Nest x) - - mfromList1 :: forall m sh. KnownShapeX (m : sh) - => NonEmpty (Mixed sh (Ranked n a)) -> Mixed (m : sh) (Ranked n a) - mfromList1 l - | Dict <- lemKnownReplicate (Proxy @n) - = M_Ranked (mfromList1 (coerce l)) - - mtoList1 :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] - mtoList1 (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoList1 arr) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift f (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift f arr - - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) - mlift2 f (M_Ranked arr1) (M_Ranked arr2) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ - mlift2 f arr1 arr2 - - memptyArray :: forall sh. IShX sh -> Mixed sh (Ranked n a) - memptyArray i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArray i - - mshapeTree (Ranked arr) - | Refl <- lemRankReplicate (Proxy @n) - , Dict <- lemKnownReplicate (Proxy @n) - = first shCvtXR (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shapeSizeR sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) - - mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - --- | The shape of a shape-typed array given as a list of 'SNat' values. -data ShS sh where - ZSS :: ShS '[] - (:$$) :: forall n sh. SNat n -> ShS sh -> ShS (n : sh) -deriving instance Show (ShS sh) -deriving instance Eq (ShS sh) -deriving instance Ord (ShS sh) -infixr 3 :$$ - --- | A statically-known shape of a shape-typed array. -class KnownShape sh where knownShape :: ShS sh -instance KnownShape '[] where knownShape = ZSS -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = natSing :$$ knownShape - -sshapeKnown :: ShS sh -> Dict KnownShape sh -sshapeKnown ZSS = Dict -sshapeKnown (GHC_SNat :$$ sh) | Dict <- sshapeKnown sh = Dict - -lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) -lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) - where - go :: ShS sh' -> StaticShX (MapJust sh') - go ZSS = ZKSX - go (n :$$ sh) = n :!$@ go sh - -lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustPlusApp _ _ = go (knownShape @sh1) - where - go :: ShS sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ZSS = Refl - go (_ :$$ sh) | Refl <- go sh = Refl - -instance (Elt a, KnownShape sh) => Elt (Shaped sh a) where - mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr - mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mscalar (Shaped x) = M_Shaped (M_Nest x) - - mfromList1 :: forall n sh'. KnownShapeX (n : sh') - => NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (n : sh') (Shaped sh a) - mfromList1 l - | Dict <- lemKnownMapJust (Proxy @sh) - = M_Shaped (mfromList1 (coerce l)) - - mtoList1 :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] - mtoList1 (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoList1 arr) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift f (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift f arr - - mlift2 :: forall sh1 sh2 sh3. (KnownShapeX sh2, KnownShapeX sh3) - => (forall sh' b. (KnownShapeX sh', Storable b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) - mlift2 f (M_Shaped arr1) (M_Shaped arr2) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ - mlift2 f arr1 arr2 - - memptyArray :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) - memptyArray i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArray i - - mshapeTree (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = first (shCvtXS (knownShape @sh)) (mshapeTree arr) - - mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 - - mshapeTreeEmpty _ (sh, t) = shapeSizeS sh == 0 && mshapeTreeEmpty (Proxy @a) t - - mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsNewEmpty _ - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) - - mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - --- Utility functions to satisfy the type checker sometimes - -rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a -rewriteMixed Refl x = x - - --- ====== API OF RANKED ARRAYS ====== -- - -arithPromoteRanked :: forall n a. KnownINat n - => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -arithPromoteRanked | Dict <- lemKnownReplicate (Proxy @n) = coerce - -arithPromoteRanked2 :: forall n a. KnownINat n - => (forall sh. KnownShapeX sh => Mixed sh a -> Mixed sh a -> Mixed sh a) - -> Ranked n a -> Ranked n a -> Ranked n a -arithPromoteRanked2 | Dict <- lemKnownReplicate (Proxy @n) = coerce - -instance (KnownINat n, Storable a, Num a) => Num (Ranked n (Primitive a)) where - (+) = arithPromoteRanked2 (+) - (-) = arithPromoteRanked2 (-) - (*) = arithPromoteRanked2 (*) - negate = arithPromoteRanked negate - abs = arithPromoteRanked abs - signum = arithPromoteRanked signum - fromInteger n = case inatSing @n of - SZ -> Ranked (M_Primitive (X.scalar (fromInteger n))) - SS _ -> error "Data.Array.Nested.fromIntegral(Ranked): \ - \Rank non-zero, use explicit mconstant" - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Ranked n (Primitive Int) instance KnownINat n => Num (Ranked n Int) -deriving via Ranked n (Primitive Double) instance KnownINat n => Num (Ranked n Double) - -type role ListR nominal representational -type ListR :: INat -> Type -> Type -data ListR n i where - ZR :: ListR Z i - (:::) :: forall n {i}. i -> ListR n i -> ListR (S n) i -deriving instance Show i => Show (ListR n i) -deriving instance Eq i => Eq (ListR n i) -deriving instance Ord i => Ord (ListR n i) -deriving instance Functor (ListR n) -infixr 3 ::: - -instance Foldable (ListR n) where - foldr f z l = foldr f z (listRToList l) - -listRToList :: ListR n i -> [i] -listRToList ZR = [] -listRToList (i ::: is) = i : listRToList is - -knownListR :: ListR n i -> Dict KnownINat n -knownListR ZR = Dict -knownListR (_ ::: l) | Dict <- knownListR l = Dict - --- | An index into a rank-typed array. -type role IxR nominal representational -type IxR :: INat -> Type -> Type -newtype IxR n i = IxR (ListR n i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIR :: forall n i. () => n ~ Z => IxR n i -pattern ZIR = IxR ZR - -pattern (:.:) - :: forall {n1} {i}. - forall n. (S n ~ n1) - => i -> IxR n i -> IxR n1 i -pattern i :.: sh <- (unconsIxR -> Just (UnconsIxRRes sh i)) - where i :.: IxR sh = IxR (i ::: sh) -{-# COMPLETE ZIR, (:.:) #-} -infixr 3 :.: - -data UnconsIxRRes i n1 = - forall n. ((S n) ~ n1) => UnconsIxRRes (IxR n i) i -unconsIxR :: IxR n1 i -> Maybe (UnconsIxRRes i n1) -unconsIxR (IxR (i ::: sh')) = Just (UnconsIxRRes (IxR sh') i) -unconsIxR (IxR ZR) = Nothing - -type IIxR n = IxR n Int - -knownIxR :: IxR n i -> Dict KnownINat n -knownIxR (IxR sh) = knownListR sh - -type role ShR nominal representational -type ShR :: INat -> Type -> Type -newtype ShR n i = ShR (ListR n i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -type IShR n = ShR n Int - -pattern ZSR :: forall n i. () => n ~ Z => ShR n i -pattern ZSR = ShR ZR - -pattern (:$:) - :: forall {n1} {i}. - forall n. (S n ~ n1) - => i -> ShR n i -> ShR n1 i -pattern i :$: sh <- (unconsShR -> Just (UnconsShRRes sh i)) - where i :$: (ShR sh) = ShR (i ::: sh) -{-# COMPLETE ZSR, (:$:) #-} -infixr 3 :$: - -data UnconsShRRes i n1 = - forall n. S n ~ n1 => UnconsShRRes (ShR n i) i -unconsShR :: ShR n1 i -> Maybe (UnconsShRRes i n1) -unconsShR (ShR (i ::: sh')) = Just (UnconsShRRes (ShR sh') i) -unconsShR (ShR ZR) = Nothing - -knownShR :: ShR n i -> Dict KnownINat n -knownShR (ShR sh) = knownListR sh - -zeroIxR :: SINat n -> IIxR n -zeroIxR SZ = ZIR -zeroIxR (SS n) = 0 :.: zeroIxR n - -ixCvtXR :: IIxX sh -> IIxR (X.Rank sh) -ixCvtXR ZIX = ZIR -ixCvtXR (n :.@ idx) = n :.: ixCvtXR idx -ixCvtXR (n :.? idx) = n :.: ixCvtXR idx - -shCvtXR :: IShX sh -> IShR (X.Rank sh) -shCvtXR ZSX = ZSR -shCvtXR (n :$@ idx) = X.fromSNat' n :$: shCvtXR idx -shCvtXR (n :$? idx) = n :$: shCvtXR idx - -ixCvtRX :: IIxR n -> IIxX (Replicate n Nothing) -ixCvtRX ZIR = ZIX -ixCvtRX (n :.: idx) = n :.? ixCvtRX idx - -shCvtRX :: IShR n -> IShX (Replicate n Nothing) -shCvtRX ZSR = ZSX -shCvtRX (n :$: idx) = n :$? shCvtRX idx - -shapeSizeR :: IShR n -> Int -shapeSizeR ZSR = 1 -shapeSizeR (n :$: sh) = n * shapeSizeR sh - - -rshape :: forall n a. (KnownINat n, Elt a) => Ranked n a -> IShR n -rshape (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = shCvtXR (mshape arr) - -rindex :: Elt a => Ranked n a -> IIxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. (KnownINat n, Elt a) => Ranked (n +! m) a -> IIxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) - (ixCvtRX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -rgenerate :: forall n a. Elt a => IShR n -> (IIxR n -> a) -> Ranked n a -rgenerate sh f - | Dict <- knownShR sh - , Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = Ranked (mgenerate (shCvtRX sh) (f . ixCvtXR)) - --- | See the documentation of 'mlift'. -rlift :: forall n1 n2 a. (KnownINat n2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift f (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n2) - = Ranked (mlift f arr) - -rsumOuter1P :: forall n a. - (Storable a, Num a, KnownINat n) - => Ranked (S n) (Primitive a) -> Ranked n (Primitive a) -rsumOuter1P (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked - . coerce @(XArray (Replicate n 'Nothing) a) @(Mixed (Replicate n 'Nothing) (Primitive a)) - . X.sumOuter (() :!$? ZKSX) (knownShapeX @(Replicate n Nothing)) - . coerce @(Mixed (Replicate (S n) Nothing) (Primitive a)) @(XArray (Replicate (S n) Nothing) a) - $ arr - -rsumOuter1 :: forall n a. - (Storable a, Num a, PrimElt a, KnownINat n) - => Ranked (S n) a -> Ranked n a -rsumOuter1 = coerce fromPrimitive . rsumOuter1P @n @a . coerce toPrimitive - -rtranspose :: forall n a. (KnownINat n, Elt a) => [Int] -> Ranked n a -> Ranked n a -rtranspose perm (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mtranspose perm arr) - -rappend :: forall n a. (KnownINat n, Elt a) - => Ranked (S n) a -> Ranked (S n) a -> Ranked (S n) a -rappend | Dict <- lemKnownReplicate (Proxy @n) = coerce mappend - -rscalar :: Elt a => a -> Ranked I0 a -rscalar x = Ranked (mscalar x) - -rfromVectorP :: forall n a. (KnownINat n, Storable a) => IShR n -> VS.Vector a -> Ranked n (Primitive a) -rfromVectorP sh v - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromVectorP (shCvtRX sh) v) - -rfromVector :: forall n a. (KnownINat n, Storable a, PrimElt a) => IShR n -> VS.Vector a -> Ranked n a -rfromVector sh v = coerce fromPrimitive (rfromVectorP sh v) - -rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a -rtoVectorP = coerce mtoVectorP - -rtoVector :: (Storable a, PrimElt a) => Ranked n a -> VS.Vector a -rtoVector = coerce mtoVector - -rfromList1 :: forall n a. (KnownINat n, Elt a) => NonEmpty (Ranked n a) -> Ranked (S n) a -rfromList1 l - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mfromList1 (coerce l)) - -rfromList :: Elt a => NonEmpty a -> Ranked I1 a -rfromList = Ranked . mfromList1 . fmap mscalar - -rtoList :: Elt a => Ranked (S n) a -> [Ranked n a] -rtoList (Ranked arr) = coerce (mtoList1 arr) - -rtoList1 :: Elt a => Ranked I1 a -> [a] -rtoList1 = map runScalar . rtoList - -runScalar :: Elt a => Ranked I0 a -> a -runScalar arr = rindex arr ZIR - -rconstantP :: forall n a. (KnownINat n, Storable a) => IShR n -> a -> Ranked n (Primitive a) -rconstantP sh x - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked (mconstantP (shCvtRX sh) x) - -rconstant :: forall n a. (KnownINat n, Storable a, PrimElt a) - => IShR n -> a -> Ranked n a -rconstant sh x = coerce fromPrimitive (rconstantP sh x) - -rslice :: (KnownINat n, Elt a) => [(Int, Int)] -> Ranked n a -> Ranked n a -rslice ivs = rlift $ \_ -> X.slice ivs - -rrev1 :: (KnownINat n, Elt a) => Ranked (S n) a -> Ranked (S n) a -rrev1 = rlift $ \_ -> X.rev1 - - --- ====== API OF SHAPED ARRAYS ====== -- - -arithPromoteShaped :: forall sh a. KnownShape sh - => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -arithPromoteShaped | Dict <- lemKnownMapJust (Proxy @sh) = coerce - -arithPromoteShaped2 :: forall sh a. KnownShape sh - => (forall shx. KnownShapeX shx => Mixed shx a -> Mixed shx a -> Mixed shx a) - -> Shaped sh a -> Shaped sh a -> Shaped sh a -arithPromoteShaped2 | Dict <- lemKnownMapJust (Proxy @sh) = coerce - -instance (KnownShape sh, Storable a, Num a) => Num (Shaped sh (Primitive a)) where - (+) = arithPromoteShaped2 (+) - (-) = arithPromoteShaped2 (-) - (*) = arithPromoteShaped2 (*) - negate = arithPromoteShaped negate - abs = arithPromoteShaped abs - signum = arithPromoteShaped signum - fromInteger n = sconstantP (fromInteger n) - --- [PRIMITIVE ELEMENT TYPES LIST] (really, a partial list of just the numeric types) -deriving via Shaped sh (Primitive Int) instance KnownShape sh => Num (Shaped sh Int) -deriving via Shaped sh (Primitive Double) instance KnownShape sh => Num (Shaped sh Double) - -type role ListS nominal representational -type ListS :: [Nat] -> Type -> Type -data ListS sh i where - ZS :: ListS '[] i - (::$) :: forall n sh {i}. i -> ListS sh i -> ListS (n : sh) i -deriving instance Show i => Show (ListS sh i) -deriving instance Eq i => Eq (ListS sh i) -deriving instance Ord i => Ord (ListS sh i) -deriving instance Functor (ListS sh) -infixr 3 ::$ - -instance Foldable (ListS sh) where - foldr f z l = foldr f z (listSToList l) - -listSToList :: ListS sh i -> [i] -listSToList ZS = [] -listSToList (i ::$ is) = i : listSToList is - --- | An index into a shape-typed array. --- --- For convenience, this contains regular 'Int's instead of bounded integers --- (traditionally called \"@Fin@\"). Note that because the shape of a --- shape-typed array is known statically, you can also retrieve the array shape --- from a 'KnownShape' dictionary. -type role IxS nominal representational -type IxS :: [Nat] -> Type -> Type -newtype IxS sh i = IxS (ListS sh i) - deriving (Show, Eq, Ord) - deriving newtype (Functor, Foldable) - -pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i -pattern ZIS = IxS ZS - -pattern (:.$) - :: forall {sh1} {i}. - forall n sh. (n : sh ~ sh1) - => i -> IxS sh i -> IxS sh1 i -pattern i :.$ shl <- (unconsIxS -> Just (UnconsIxSRes shl i)) - where i :.$ IxS shl = IxS (i ::$ shl) -{-# COMPLETE ZIS, (:.$) #-} -infixr 3 :.$ - -data UnconsIxSRes i sh1 = - forall n sh. (n : sh ~ sh1) => UnconsIxSRes (IxS sh i) i -unconsIxS :: IxS sh1 i -> Maybe (UnconsIxSRes i sh1) -unconsIxS (IxS (i ::$ shl')) = Just (UnconsIxSRes (IxS shl') i) -unconsIxS (IxS ZS) = Nothing - -type IIxS sh = IxS sh Int - -data UnconsShSRes sh1 = - forall n sh. (n : sh ~ sh1) => UnconsShSRes (ShS sh) (SNat n) -unconsShS :: ShS sh1 -> Maybe (UnconsShSRes sh1) -unconsShS (i :$$ shl') = Just (UnconsShSRes shl' i) -unconsShS ZSS = Nothing - -zeroIxS :: ShS sh -> IIxS sh -zeroIxS ZSS = ZIS -zeroIxS (_ :$$ sh) = 0 :.$ zeroIxS sh - -ixCvtXS :: ShS sh -> IIxX (MapJust sh) -> IIxS sh -ixCvtXS ZSS ZIX = ZIS -ixCvtXS (_ :$$ sh) (n :.@ idx) = n :.$ ixCvtXS sh idx - -shCvtXS :: ShS sh -> IShX (MapJust sh) -> ShS sh -shCvtXS ZSS ZSX = ZSS -shCvtXS (_ :$$ sh) (n :$@ idx) = n :$$ shCvtXS sh idx - -ixCvtSX :: IIxS sh -> IIxX (MapJust sh) -ixCvtSX ZIS = ZIX -ixCvtSX (n :.$ sh) = n :.@ ixCvtSX sh - -shCvtSX :: ShS sh -> IShX (MapJust sh) -shCvtSX ZSS = ZSX -shCvtSX (n :$$ sh) = n :$@ shCvtSX sh - -shapeSizeS :: ShS sh -> Int -shapeSizeS ZSS = 1 -shapeSizeS (n :$$ sh) = X.fromSNat' n * shapeSizeS sh - - --- | This does not touch the passed array, all information comes from 'KnownShape'. -sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> ShS sh -sshape _ = knownShape @sh - -sindex :: Elt a => Shaped sh a -> IIxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a -sindexPartial (Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) - (ixCvtSX idx)) - --- | __WARNING__: All values returned from the function must have equal shape. --- See the documentation of 'mgenerate' for more details. -sgenerate :: forall sh a. (KnownShape sh, Elt a) => (IIxS sh -> a) -> Shaped sh a -sgenerate f - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (shCvtSX (knownShape @sh)) (f . ixCvtXS (knownShape @sh))) - --- | See the documentation of 'mlift'. -slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) - => (forall sh' b. KnownShapeX sh' => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift f (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh2) - = Shaped (mlift f arr) - -ssumOuter1P :: forall sh n a. - (Storable a, Num a, KnownNat n, KnownShape sh) - => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) -ssumOuter1P (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped - . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) (Primitive a)) - . X.sumOuter (natSing @n :!$@ ZKSX) (knownShapeX @(MapJust sh)) - . coerce @(Mixed (Just n : MapJust sh) (Primitive a)) @(XArray (Just n : MapJust sh) a) - $ arr - -ssumOuter1 :: forall sh n a. - (Storable a, Num a, PrimElt a, KnownNat n, KnownShape sh) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 = coerce fromPrimitive . ssumOuter1P @sh @n @a . coerce toPrimitive - -stranspose :: forall sh a. (KnownShape sh, Elt a) => [Int] -> Shaped sh a -> Shaped sh a -stranspose perm (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mtranspose perm arr) - -sappend :: forall n m sh a. (KnownNat n, KnownNat m, KnownShape sh, Elt a) - => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a -sappend | Dict <- lemKnownMapJust (Proxy @sh) = coerce mappend - -sscalar :: Elt a => a -> Shaped '[] a -sscalar x = Shaped (mscalar x) - -sfromVectorP :: forall sh a. (KnownShape sh, Storable a) => VS.Vector a -> Shaped sh (Primitive a) -sfromVectorP v - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromVectorP (shCvtSX (knownShape @sh)) v) - -sfromVector :: forall sh a. (KnownShape sh, Storable a, PrimElt a) => VS.Vector a -> Shaped sh a -sfromVector v = coerce fromPrimitive (sfromVectorP @sh @a v) - -stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a -stoVectorP = coerce mtoVectorP - -stoVector :: (Storable a, PrimElt a) => Shaped sh a -> VS.Vector a -stoVector = coerce mtoVector - -sfromList1 :: forall n sh a. (KnownNat n, KnownShape sh, Elt a) - => NonEmpty (Shaped sh a) -> Shaped (n : sh) a -sfromList1 l - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mfromList1 (coerce l)) - -sfromList :: (KnownNat n, Elt a) => NonEmpty a -> Shaped '[n] a -sfromList = Shaped . mfromList1 . fmap mscalar - -stoList :: Elt a => Shaped (n : sh) a -> [Shaped sh a] -stoList (Shaped arr) = coerce (mtoList1 arr) - -stoList1 :: Elt a => Shaped '[n] a -> [a] -stoList1 = map sunScalar . stoList - -sunScalar :: Elt a => Shaped '[] a -> a -sunScalar arr = sindex arr ZIS - -sconstantP :: forall sh a. (KnownShape sh, Storable a) => a -> Shaped sh (Primitive a) -sconstantP x - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mconstantP (shCvtSX (knownShape @sh)) x) - -sconstant :: forall sh a. (KnownShape sh, Storable a, PrimElt a) - => a -> Shaped sh a -sconstant x = coerce fromPrimitive (sconstantP @sh x) - -sslice :: (KnownShape sh, Elt a) => [(Int, Int)] -> Shaped sh a -> Shaped sh a -sslice ivs = slift $ \_ -> X.slice ivs - -srev1 :: (KnownNat n, KnownShape sh, Elt a) => Shaped (n : sh) a -> Shaped (n : sh) a -srev1 = slift $ \_ -> X.rev1 diff --git a/src/Data/Array/Nested/Lemmas.hs b/src/Data/Array/Nested/Lemmas.hs new file mode 100644 index 0000000..8cac298 --- /dev/null +++ b/src/Data/Array/Nested/Lemmas.hs @@ -0,0 +1,162 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Lemmas where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types + + +-- * Lemmas about numbers and lists + +-- ** Nat + +lemLeqSuccSucc :: k + 1 <= n => Proxy k -> Proxy n -> (k <=? n - 1) :~: True +lemLeqSuccSucc _ _ = unsafeCoerceRefl + +lemLeqPlus :: n <= m => Proxy n -> Proxy m -> Proxy k -> (n <=? (m + k)) :~: 'True +lemLeqPlus _ _ _ = Refl + +-- ** Append + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerceRefl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerceRefl + +lemAppLeft :: Proxy l -> a :~: b -> a ++ l :~: b ++ l +lemAppLeft _ Refl = Refl + +-- ** Simple type families + +lemReplicatePlusApp :: forall n m a. SNat n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp sn _ _ = go sn + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS (n :: SNat n'm1)) + | Refl <- lemReplicateSucc @a @n'm1 + , Refl <- go n + = sym (lemReplicateSucc @a @(n'm1 + m)) + +lemDropLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> DropLen l1 l2 ++ rest :~: DropLen l1 (l2 ++ rest) +lemDropLenApp _ _ _ = unsafeCoerceRefl + +lemTakeLenApp :: Rank l1 <= Rank l2 + => Proxy l1 -> Proxy l2 -> Proxy rest + -> TakeLen l1 l2 :~: TakeLen l1 (l2 ++ rest) +lemTakeLenApp _ _ _ = unsafeCoerceRefl + +lemInitApp :: Proxy l -> Proxy x -> Init (l ++ '[x]) :~: l +lemInitApp _ _ = unsafeCoerceRefl + +lemLastApp :: Proxy l -> Proxy x -> Last (l ++ '[x]) :~: x +lemLastApp _ _ = unsafeCoerceRefl + + +-- ** KnownNat + +lemKnownNatSucc :: KnownNat n => Dict KnownNat (n + 1) +lemKnownNatSucc = Dict + +lemKnownNatRank :: ShX sh i -> Dict KnownNat (Rank sh) +lemKnownNatRank ZSX = Dict +lemKnownNatRank (_ :$% sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX ZKX = Dict +lemKnownNatRankSSX (_ :!% ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + + +-- * Lemmas about shapes + +-- ** Known shapes + +lemKnownReplicate :: SNat n -> Dict KnownShX (Replicate n Nothing) +lemKnownReplicate sn = lemKnownShX (ssxFromSNat sn) + +lemKnownShX :: StaticShX sh -> Dict KnownShX sh +lemKnownShX ZKX = Dict +lemKnownShX (SKnown SNat :!% ssh) | Dict <- lemKnownShX ssh = Dict +lemKnownShX (SUnknown () :!% ssh) | Dict <- lemKnownShX ssh = Dict + +lemKnownMapJust :: forall sh. KnownShS sh => Proxy sh -> Dict KnownShX (MapJust sh) +lemKnownMapJust _ = lemKnownShX (go (knownShS @sh)) + where + go :: ShS sh' -> StaticShX (MapJust sh') + go ZSS = ZKX + go (n :$$ sh) = SKnown n :!% go sh + +-- ** Rank + +lemRankApp :: forall sh1 sh2. + StaticShX sh1 -> StaticShX sh2 + -> Rank (sh1 ++ sh2) :~: Rank sh1 + Rank sh2 +lemRankApp ZKX _ = Refl +lemRankApp (_ :!% (ssh1 :: StaticShX sh1T)) ssh2 + = lem (Proxy @(Rank sh1T)) Proxy Proxy $ + sym (lemRankApp ssh1 ssh2) + where + lem :: proxy a -> proxy b -> proxy c + -> (a + b :~: c) + -> c + 1 :~: (a + 1 + b) + lem _ _ _ Refl = Refl + +lemRankAppComm :: proxy sh1 -> proxy sh2 + -> Rank (sh1 ++ sh2) :~: Rank (sh2 ++ sh1) +lemRankAppComm _ _ = unsafeCoerceRefl + +lemRankReplicate :: proxy n -> Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = unsafeCoerceRefl + +lemRankMapJust :: ShS sh -> Rank (MapJust sh) :~: Rank sh +lemRankMapJust ZSS = Refl +lemRankMapJust (_ :$$ sh') | Refl <- lemRankMapJust sh' = Refl + +-- ** Related to MapJust and/or Permutation + +lemTakeLenMapJust :: Perm is -> ShS sh -> TakeLen is (MapJust sh) :~: MapJust (TakeLen is sh) +lemTakeLenMapJust PNil _ = Refl +lemTakeLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemTakeLenMapJust is sh = Refl +lemTakeLenMapJust (_ `PCons` _) ZSS = error "TakeLen of empty" + +lemDropLenMapJust :: Perm is -> ShS sh -> DropLen is (MapJust sh) :~: MapJust (DropLen is sh) +lemDropLenMapJust PNil _ = Refl +lemDropLenMapJust (_ `PCons` is) (_ :$$ sh) | Refl <- lemDropLenMapJust is sh = Refl +lemDropLenMapJust (_ `PCons` _) ZSS = error "DropLen of empty" + +lemIndexMapJust :: SNat i -> ShS sh -> Index i (MapJust sh) :~: Just (Index i sh) +lemIndexMapJust SZ (_ :$$ _) = Refl +lemIndexMapJust (SS (i :: SNat i')) ((_ :: SNat n) :$$ (sh :: ShS sh')) + | Refl <- lemIndexMapJust i sh + , Refl <- lemIndexSucc (Proxy @i') (Proxy @(Just n)) (Proxy @(MapJust sh')) + , Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = Refl +lemIndexMapJust _ ZSS = error "Index of empty" + +lemPermuteMapJust :: Perm is -> ShS sh -> Permute is (MapJust sh) :~: MapJust (Permute is sh) +lemPermuteMapJust PNil _ = Refl +lemPermuteMapJust (i `PCons` is) sh + | Refl <- lemPermuteMapJust is sh + , Refl <- lemIndexMapJust i sh + = Refl + +lemMapJustApp :: ShS sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustApp ZSS _ = Refl +lemMapJustApp (_ :$$ sh) p | Refl <- lemMapJustApp sh p = Refl diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs new file mode 100644 index 0000000..144230e --- /dev/null +++ b/src/Data/Array/Nested/Mixed.hs @@ -0,0 +1,936 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Mixed where + +import Prelude hiding (mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad (forM_, when) +import Control.Monad.ST +import Data.Array.RankedS qualified as S +import Data.Bifunctor (bimap) +import Data.Coerce +import Data.Foldable (toList) +import Data.Int +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty qualified as NE +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Data.Vector.Storable.Mutable qualified as VSM +import Foreign.C.Types (CInt) +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X +import Data.Bag + + +-- TODO: +-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +-- rminIndex1 :: Ranked (n + 1) a -> Ranked n Int +-- gather/scatter-like things (most generally, the higher-order variants: accelerate's backpermute/permute) +-- After benchmarking: matmul and matvec + + + +-- Invariant in the API +-- ==================== +-- +-- In the underlying XArray, there is some shape for elements of an empty +-- array. For example, for this array: +-- +-- arr :: Ranked I3 (Ranked I2 Int, Ranked I1 Float) +-- rshape arr == 0 :.: 0 :.: 0 :.: ZIR +-- +-- the two underlying XArrays have a shape, and those shapes might be anything. +-- The invariant is that these element shapes are unobservable in the API. +-- (This is possible because you ought to not be able to get to such an element +-- without indexing out of bounds.) +-- +-- Note, though, that the converse situation may arise: the outer array might +-- be nonempty but then the inner arrays might. This is fine, an invariant only +-- applies if the _outer_ array is empty. +-- +-- TODO: can we enforce that the elements of an empty (nested) array have +-- all-zero shape? +-- -> no, because mlift and also any kind of internals probing from outsiders + + +-- Primitive element types +-- ======================= +-- +-- There are a few primitive element types; arrays containing elements of such +-- type are a newtype over an XArray, which it itself a newtype over a Vector. +-- Unfortunately, the setup of the library requires us to list these primitive +-- element types multiple times; to aid in extending the list, all these lists +-- have been marked with [PRIMITIVE ELEMENT TYPES LIST]. + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. +newtype Primitive a = Primitive a + deriving (Show) + +-- | Element types that are primitive; arrays of these types are just a newtype +-- wrapper over an array. +class (Storable a, Elt a) => PrimElt a where + fromPrimitive :: Mixed sh (Primitive a) -> Mixed sh a + toPrimitive :: Mixed sh a -> Mixed sh (Primitive a) + + default fromPrimitive :: Coercible (Mixed sh a) (Mixed sh (Primitive a)) => Mixed sh (Primitive a) -> Mixed sh a + fromPrimitive = coerce + + default toPrimitive :: Coercible (Mixed sh (Primitive a)) (Mixed sh a) => Mixed sh a -> Mixed sh (Primitive a) + toPrimitive = coerce + +-- [PRIMITIVE ELEMENT TYPES LIST] +instance PrimElt Bool +instance PrimElt Int +instance PrimElt Int64 +instance PrimElt Int32 +instance PrimElt CInt +instance PrimElt Float +instance PrimElt Double +instance PrimElt () + + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a data family so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'mtranspose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a +-- NOTE: When opening up the Mixed abstraction, you might see dimension sizes +-- that you're not supposed to see. In particular, you might see (nonempty) +-- sizes of the elements of an empty array, which is information that should +-- ostensibly not exist; the full array is still empty. + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +#define ANDSHOW , Show +#else +#define ANDSHOW +#endif + +data instance Mixed sh (Primitive a) = M_Primitive !(IShX sh) !(XArray sh a) + deriving (Eq, Ord, Generic ANDSHOW) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance Mixed sh Bool = M_Bool (Mixed sh (Primitive Bool)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int = M_Int (Mixed sh (Primitive Int)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int64 = M_Int64 (Mixed sh (Primitive Int64)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Int32 = M_Int32 (Mixed sh (Primitive Int32)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh CInt = M_CInt (Mixed sh (Primitive CInt)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Float = M_Float (Mixed sh (Primitive Float)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh Double = M_Double (Mixed sh (Primitive Double)) deriving (Eq, Ord, Generic ANDSHOW) +newtype instance Mixed sh () = M_Nil (Mixed sh (Primitive ())) deriving (Eq, Ord, Generic ANDSHOW) -- no content, orthotope optimises this (via Vector) +-- etc. + +data instance Mixed sh (a, b) = M_Tup2 !(Mixed sh a) !(Mixed sh b) deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (Show (Mixed sh a), Show (Mixed sh b)) => Show (Mixed sh (a, b)) +#endif +-- etc., larger tuples (perhaps use generics to allow arbitrary product types) + +deriving instance (Eq (Mixed sh a), Eq (Mixed sh b)) => Eq (Mixed sh (a, b)) +deriving instance (Ord (Mixed sh a), Ord (Mixed sh b)) => Ord (Mixed sh (a, b)) + +data instance Mixed sh1 (Mixed sh2 a) = M_Nest !(IShX sh1) !(Mixed (sh1 ++ sh2) a) deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (Show (Mixed (sh1 ++ sh2) a)) => Show (Mixed sh1 (Mixed sh2 a)) +#endif + +deriving instance Eq (Mixed (sh1 ++ sh2) a) => Eq (Mixed sh1 (Mixed sh2 a)) +deriving instance Ord (Mixed (sh1 ++ sh2) a) => Ord (Mixed sh1 (Mixed sh2 a)) + + +-- | Internal helper data family mirroring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VS.MVector s a) + +-- [PRIMITIVE ELEMENT TYPES LIST] +newtype instance MixedVecs s sh Bool = MV_Bool (VS.MVector s Bool) +newtype instance MixedVecs s sh Int = MV_Int (VS.MVector s Int) +newtype instance MixedVecs s sh Int64 = MV_Int64 (VS.MVector s Int64) +newtype instance MixedVecs s sh Int32 = MV_Int32 (VS.MVector s Int32) +newtype instance MixedVecs s sh CInt = MV_CInt (VS.MVector s CInt) +newtype instance MixedVecs s sh Double = MV_Double (VS.MVector s Double) +newtype instance MixedVecs s sh Float = MV_Float (VS.MVector s Float) +newtype instance MixedVecs s sh () = MV_Nil (VS.MVector s ()) -- no content, MVector optimises this +-- etc. + +data instance MixedVecs s sh (a, b) = MV_Tup2 !(MixedVecs s sh a) !(MixedVecs s sh b) +-- etc. + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest !(IShX sh2) !(MixedVecs s (sh1 ++ sh2) a) + + +showsMixedArray :: (Show a, Elt a) + => String -- ^ fromList prefix: e.g. @rfromListLinear [2,3]@ + -> String -- ^ replicate prefix: e.g. @rreplicate [2,3]@ + -> Int -> Mixed sh a -> ShowS +showsMixedArray fromlistPrefix replicatePrefix d arr = + showParen (d > 10) $ + -- TODO: to avoid ambiguity, we should type-apply the shape to mfromListLinear here + case mtoListLinear arr of + hd : _ : _ + | all (all (== 0) . take (shxLength (mshape arr))) (marrayStrides arr) -> + showString replicatePrefix . showString " " . showsPrec 11 hd + _ -> + showString fromlistPrefix . showString " " . shows (mtoListLinear arr) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Mixed sh a) where + showsPrec d arr = + let sh = show (shxToList (mshape arr)) + in showsMixedArray ("mfromListLinear " ++ sh) ("mreplicate " ++ sh) d arr +#endif + +instance Elt a => NFData (Mixed sh a) where + rnf = mrnf + + +mliftNumElt1 :: (PrimElt a, PrimElt b) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b) + -> Mixed sh a -> Mixed sh b +mliftNumElt1 f (toPrimitive -> M_Primitive sh (XArray arr)) = fromPrimitive $ M_Primitive sh (XArray (f (shxRank sh) arr)) + +mliftNumElt2 :: (PrimElt a, PrimElt b, PrimElt c) + => (SNat (Rank sh) -> S.Array (Rank sh) a -> S.Array (Rank sh) b -> S.Array (Rank sh) c) + -> Mixed sh a -> Mixed sh b -> Mixed sh c +mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_Primitive sh2 (XArray arr2)) + | sh1 == sh2 = fromPrimitive $ M_Primitive sh1 (XArray (f (shxRank sh1) arr1 arr2)) + | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 + +instance (NumElt a, PrimElt a) => Num (Mixed sh a) where + (+) = mliftNumElt2 (liftO2 . numEltAdd) + (-) = mliftNumElt2 (liftO2 . numEltSub) + (*) = mliftNumElt2 (liftO2 . numEltMul) + negate = mliftNumElt1 (liftO1 . numEltNeg) + abs = mliftNumElt1 (liftO1 . numEltAbs) + signum = mliftNumElt1 (liftO1 . numEltSignum) + -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS + fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where + fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" + recip = mliftNumElt1 (liftO1 . floatEltRecip) + (/) = mliftNumElt2 (liftO2 . floatEltDiv) + +instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" + exp = mliftNumElt1 (liftO1 . floatEltExp) + log = mliftNumElt1 (liftO1 . floatEltLog) + sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) + + (**) = mliftNumElt2 (liftO2 . floatEltPow) + logBase = mliftNumElt2 (liftO2 . floatEltLogbase) + + sin = mliftNumElt1 (liftO1 . floatEltSin) + cos = mliftNumElt1 (liftO1 . floatEltCos) + tan = mliftNumElt1 (liftO1 . floatEltTan) + asin = mliftNumElt1 (liftO1 . floatEltAsin) + acos = mliftNumElt1 (liftO1 . floatEltAcos) + atan = mliftNumElt1 (liftO1 . floatEltAtan) + sinh = mliftNumElt1 (liftO1 . floatEltSinh) + cosh = mliftNumElt1 (liftO1 . floatEltCosh) + tanh = mliftNumElt1 (liftO1 . floatEltTanh) + asinh = mliftNumElt1 (liftO1 . floatEltAsinh) + acosh = mliftNumElt1 (liftO1 . floatEltAcosh) + atanh = mliftNumElt1 (liftO1 . floatEltAtanh) + log1p = mliftNumElt1 (liftO1 . floatEltLog1p) + expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) + log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) + log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) + +mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +mquotArray = mliftNumElt2 (liftO2 . intEltQuot) +mremArray = mliftNumElt2 (liftO2 . intEltRem) + +matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a +matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) + +-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. +class Elt a where + -- ====== PUBLIC METHODS ====== -- + + mshape :: Mixed sh a -> IShX sh + mindex :: Mixed sh a -> IIxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IIxX sh -> Mixed sh' a + mscalar :: a -> Mixed '[] a + + -- | All arrays in the list, even subarrays inside @a@, must have the same + -- shape; if they do not, a runtime error will be thrown. See the + -- documentation of 'mgenerate' for more information about this restriction. + -- Furthermore, the length of the list must correspond with @n@: if @n@ is + -- @Just m@ and @m@ does not equal the length of the list, a runtime error is + -- thrown. + -- + -- Consider also 'mfromListPrim', which can avoid intermediate arrays. + mfromListOuter :: forall sh. NonEmpty (Mixed sh a) -> Mixed (Nothing : sh) a + + mtoListOuter :: Mixed (n : sh) a -> [Mixed sh a] + + -- | Note: this library makes no particular guarantees about the shapes of + -- arrays "inside" an empty array. With 'mlift', 'mlift2' and 'mliftL' you can see the + -- full 'XArray' and as such you can distinguish different empty arrays by + -- the "shapes" of their elements. This information is meaningless, so you + -- should not use it. + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a + + -- | See the documentation for 'mlift'. + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a -> Mixed sh3 a + + -- TODO: mliftL is currently unused. + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 a) -> NonEmpty (Mixed sh2 a) + + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') a -> Mixed (sh2 ++ sh') a + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh a -> Mixed (PermutePrefix is sh) a + + -- | All arrays in the input must have equal shapes, including subarrays + -- inside their elements. + mconcat :: NonEmpty (Mixed (Nothing : sh) a) -> Mixed (Nothing : sh) a + + mrnf :: Mixed sh a -> () + + -- ====== PRIVATE METHODS ====== -- + + -- | Tree giving the shape of every array component. + type ShapeTree a + + mshapeTree :: a -> ShapeTree a + + mshapeTreeEq :: Proxy a -> ShapeTree a -> ShapeTree a -> Bool + + mshapeTreeEmpty :: Proxy a -> ShapeTree a -> Bool + + mshowShapeTree :: Proxy a -> ShapeTree a -> String + + -- | Returns the stride vector of each underlying component array making up + -- this mixed array. + marrayStrides :: Mixed sh a -> Bag [Int] + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWrite :: IShX sh -> IIxX sh -> a -> MixedVecs s sh a -> ST s () + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWritePartial :: IShX (sh ++ sh') -> IIxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsFreeze :: IShX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + + +-- | Element types for which we have evidence of the (static part of the) shape +-- in a type class constraint. Compare the instance contexts of the instances +-- of this class with those of 'Elt': some instances have an additional +-- "known-shape" constraint. +-- +-- This class is (currently) only required for `memptyArray` and 'mgenerate'. +class Elt a => KnownElt a where + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArrayUnsafe :: IShX sh -> Mixed sh a + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. + mvecsUnsafeNew :: IShX sh -> a -> ST s (MixedVecs s sh a) + + mvecsNewEmpty :: Proxy a -> ST s (MixedVecs s sh a) + + +-- Arrays of scalars are basically just arrays of scalars. +instance Storable a => Elt (Primitive a) where + mshape (M_Primitive sh _) = sh + mindex (M_Primitive _ a) i = Primitive (X.index a i) + mindexPartial (M_Primitive sh a) i = M_Primitive (shxDropIx i sh) (X.indexPartial a i) + mscalar (Primitive x) = M_Primitive ZSX (X.scalar x) + mfromListOuter l@(arr1 :| _) = + let sh = SUnknown (length l) :$% mshape arr1 + in M_Primitive sh (X.fromListOuter (ssxFromShX sh) (map (\(M_Primitive _ a) -> a) (toList l))) + mtoListOuter (M_Primitive sh arr) = map (M_Primitive (shxTail sh)) (X.toListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) + mlift ssh2 f (M_Primitive _ a) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , let result = f ZKX a + = M_Primitive (X.shape ssh2 result) result + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (StaticShX '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a -> XArray (sh3 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) -> Mixed sh3 (Primitive a) + mlift2 ssh3 f (M_Primitive _ a) (M_Primitive _ b) + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + , Refl <- lemAppNil @sh3 + , let result = f ZKX a b + = M_Primitive (X.shape ssh3 result) result + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Primitive a)) -> NonEmpty (Mixed sh2 (Primitive a)) + mliftL ssh2 f l + | Refl <- lemAppNil @sh1 + , Refl <- lemAppNil @sh2 + = fmap (\arr -> M_Primitive (X.shape ssh2 arr) arr) $ + f ZKX (fmap (\(M_Primitive _ arr) -> arr) l) + + mcastPartial :: forall sh1 sh2 sh'. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy sh' -> Mixed (sh1 ++ sh') (Primitive a) -> Mixed (sh2 ++ sh') (Primitive a) + mcastPartial ssh1 ssh2 _ (M_Primitive sh1' arr) = + let (sh1, sh') = shxSplitApp (Proxy @sh') ssh1 sh1' + sh2 = shxCast' ssh2 sh1 + in M_Primitive (shxAppend sh2 sh') (X.cast ssh1 sh2 (ssxFromShX sh') arr) + + mtranspose perm (M_Primitive sh arr) = + M_Primitive (shxPermutePrefix perm sh) + (X.transpose (ssxFromShX sh) perm arr) + + mconcat :: forall sh. NonEmpty (Mixed (Nothing : sh) (Primitive a)) -> Mixed (Nothing : sh) (Primitive a) + mconcat l@(M_Primitive (_ :$% sh) _ :| _) = + let result = X.concat (ssxFromShX sh) (fmap (\(M_Primitive _ arr) -> arr) l) + in M_Primitive (X.shape (SUnknown () :!% ssxFromShX sh) result) result + + mrnf (M_Primitive sh a) = rnf sh `seq` rnf a + + type ShapeTree (Primitive a) = () + mshapeTree _ = () + mshapeTreeEq _ () () = True + mshapeTreeEmpty _ () = False + mshowShapeTree _ () = "()" + marrayStrides (M_Primitive _ arr) = BOne (X.arrayStrides arr) + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VSM.write v (ixxToLinear sh i) x + + -- TODO: this use of toVector is suboptimal + mvecsWritePartial + :: forall sh' sh s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartial sh i (M_Primitive sh' arr) (MV_Primitive v) = do + let arrsh = X.shape (ssxFromShX sh') arr + offset = ixxToLinear sh (ixxAppend i (ixxZero' arrsh)) + VS.copy (VSM.slice offset (shxSize arrsh) v) (X.toVector arr) + + mvecsFreeze sh (MV_Primitive v) = M_Primitive sh . X.fromVector sh <$> VS.freeze v + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Bool instance Elt Bool +deriving via Primitive Int instance Elt Int +deriving via Primitive Int64 instance Elt Int64 +deriving via Primitive Int32 instance Elt Int32 +deriving via Primitive CInt instance Elt CInt +deriving via Primitive Double instance Elt Double +deriving via Primitive Float instance Elt Float +deriving via Primitive () instance Elt () + +instance Storable a => KnownElt (Primitive a) where + memptyArrayUnsafe sh = M_Primitive sh (X.empty sh) + mvecsUnsafeNew sh _ = MV_Primitive <$> VSM.unsafeNew (shxSize sh) + mvecsNewEmpty _ = MV_Primitive <$> VSM.unsafeNew 0 + +-- [PRIMITIVE ELEMENT TYPES LIST] +deriving via Primitive Bool instance KnownElt Bool +deriving via Primitive Int instance KnownElt Int +deriving via Primitive Int64 instance KnownElt Int64 +deriving via Primitive Int32 instance KnownElt Int32 +deriving via Primitive CInt instance KnownElt CInt +deriving via Primitive Double instance KnownElt Double +deriving via Primitive Float instance KnownElt Float +deriving via Primitive () instance KnownElt () + +-- Arrays of pairs are pairs of arrays. +instance (Elt a, Elt b) => Elt (a, b) where + mshape (M_Tup2 a _) = mshape a + mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mscalar (x, y) = M_Tup2 (mscalar x) (mscalar y) + mfromListOuter l = + M_Tup2 (mfromListOuter ((\(M_Tup2 x _) -> x) <$> l)) + (mfromListOuter ((\(M_Tup2 _ y) -> y) <$> l)) + mtoListOuter (M_Tup2 a b) = zipWith M_Tup2 (mtoListOuter a) (mtoListOuter b) + mlift ssh2 f (M_Tup2 a b) = M_Tup2 (mlift ssh2 f a) (mlift ssh2 f b) + mlift2 ssh3 f (M_Tup2 a b) (M_Tup2 x y) = M_Tup2 (mlift2 ssh3 f a x) (mlift2 ssh3 f b y) + mliftL ssh2 f = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry (NE.zipWith M_Tup2) . bimap (mliftL ssh2 f) (mliftL ssh2 f) . unzipT2 + + mcastPartial ssh1 sh2 psh' (M_Tup2 a b) = + M_Tup2 (mcastPartial ssh1 sh2 psh' a) (mcastPartial ssh1 sh2 psh' b) + + mtranspose perm (M_Tup2 a b) = M_Tup2 (mtranspose perm a) (mtranspose perm b) + mconcat = + let unzipT2l [] = ([], []) + unzipT2l (M_Tup2 a b : l) = let (l1, l2) = unzipT2l l in (a : l1, b : l2) + unzipT2 (M_Tup2 a b :| l) = let (l1, l2) = unzipT2l l in (a :| l1, b :| l2) + in uncurry M_Tup2 . bimap mconcat mconcat . unzipT2 + + mrnf (M_Tup2 a b) = mrnf a `seq` mrnf b + + type ShapeTree (a, b) = (ShapeTree a, ShapeTree b) + mshapeTree (x, y) = (mshapeTree x, mshapeTree y) + mshapeTreeEq _ (t1, t2) (t1', t2') = mshapeTreeEq (Proxy @a) t1 t1' && mshapeTreeEq (Proxy @b) t2 t2' + mshapeTreeEmpty _ (t1, t2) = mshapeTreeEmpty (Proxy @a) t1 && mshapeTreeEmpty (Proxy @b) t2 + mshowShapeTree _ (t1, t2) = "(" ++ mshowShapeTree (Proxy @a) t1 ++ ", " ++ mshowShapeTree (Proxy @b) t2 ++ ")" + marrayStrides (M_Tup2 a b) = marrayStrides a <> marrayStrides b + mvecsWrite sh i (x, y) (MV_Tup2 a b) = do + mvecsWrite sh i x a + mvecsWrite sh i y b + mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartial sh i x a + mvecsWritePartial sh i y b + mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +instance (KnownElt a, KnownElt b) => KnownElt (a, b) where + memptyArrayUnsafe sh = M_Tup2 (memptyArrayUnsafe sh) (memptyArrayUnsafe sh) + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsNewEmpty _ = MV_Tup2 <$> mvecsNewEmpty (Proxy @a) <*> mvecsNewEmpty (Proxy @b) + +-- Arrays of arrays are just arrays, but with more dimensions. +instance Elt a => Elt (Mixed sh' a) where + -- TODO: this is quadratic in the nesting depth because it repeatedly + -- truncates the shape vector to one a little shorter. Fix with a + -- moverlongShape method, a prefix of which is mshape. + mshape :: forall sh. Mixed sh (Mixed sh' a) -> IShX sh + mshape (M_Nest sh arr) + = fst (shxSplitApp (Proxy @sh') (ssxFromShX sh) (mshape arr)) + + mindex :: Mixed sh (Mixed sh' a) -> IIxX sh -> Mixed sh' a + mindex (M_Nest _ arr) = mindexPartial arr + + mindexPartial :: forall sh1 sh2. + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + mindexPartial (M_Nest sh arr) i + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (shxDropIx i sh) (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + + mscalar = M_Nest ZSX + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mfromListOuter l@(arr :| _) = + M_Nest (SUnknown (length l) :$% mshape arr) + (mfromListOuter ((\(M_Nest _ a) -> a) <$> l)) + + mtoListOuter (M_Nest sh arr) = map (M_Nest (shxTail sh)) (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) + mlift ssh2 f (M_Nest sh1 arr) = + let result = mlift (ssxAppend ssh2 ssh') f' arr + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape result) + in M_Nest sh2 result + where + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall shT b. Storable b => StaticShX shT -> XArray (sh1 ++ shT) b -> XArray (sh2 ++ shT) b -> XArray (sh3 ++ shT) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) -> Mixed sh3 (Mixed sh' a) + mlift2 ssh3 f (M_Nest sh1 arr1) (M_Nest _ arr2) = + let result = mlift2 (ssxAppend ssh3 ssh') f' arr1 arr2 + (sh3, _) = shxSplitApp (Proxy @sh') ssh3 (mshape result) + in M_Nest sh3 result + where + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> XArray ((sh1 ++ sh') ++ shT) b -> XArray ((sh2 ++ sh') ++ shT) b -> XArray ((sh3 ++ sh') ++ shT) b + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh3) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray (sh1 ++ shT) b) -> NonEmpty (XArray (sh2 ++ shT) b)) + -> NonEmpty (Mixed sh1 (Mixed sh' a)) -> NonEmpty (Mixed sh2 (Mixed sh' a)) + mliftL ssh2 f l@(M_Nest sh1 arr1 :| _) = + let result = mliftL (ssxAppend ssh2 ssh') f' (fmap (\(M_Nest _ arr) -> arr) l) + (sh2, _) = shxSplitApp (Proxy @sh') ssh2 (mshape (NE.head result)) + in fmap (M_Nest sh2) result + where + ssh' = ssxFromShX (snd (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape arr1))) + + f' :: forall shT b. Storable b => StaticShX shT -> NonEmpty (XArray ((sh1 ++ sh') ++ shT) b) -> NonEmpty (XArray ((sh2 ++ sh') ++ shT) b) + f' sshT + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @shT) + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @shT) + = f (ssxAppend ssh' sshT) + + mcastPartial :: forall sh1 sh2 shT. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> StaticShX sh2 -> Proxy shT -> Mixed (sh1 ++ shT) (Mixed sh' a) -> Mixed (sh2 ++ shT) (Mixed sh' a) + mcastPartial ssh1 ssh2 _ (M_Nest sh1T arr) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @shT) (Proxy @sh') + , Refl <- lemAppAssoc (Proxy @sh2) (Proxy @shT) (Proxy @sh') + = let (sh1, shT) = shxSplitApp (Proxy @shT) ssh1 sh1T + sh2 = shxCast' ssh2 sh1 + in M_Nest (shxAppend sh2 shT) (mcastPartial ssh1 ssh2 (Proxy @(shT ++ sh')) arr) + + mtranspose :: forall is sh. (IsPermutation is, Rank is <= Rank sh) + => Perm is -> Mixed sh (Mixed sh' a) + -> Mixed (PermutePrefix is sh) (Mixed sh' a) + mtranspose perm (M_Nest sh arr) + | let sh' = shxDropSh @sh @sh' sh (mshape arr) + , Refl <- lemRankApp (ssxFromShX sh) (ssxFromShX sh') + , Refl <- lemLeqPlus (Proxy @(Rank is)) (Proxy @(Rank sh)) (Proxy @(Rank sh')) + , Refl <- lemAppAssoc (Proxy @(Permute is (TakeLen is (sh ++ sh')))) (Proxy @(DropLen is sh)) (Proxy @sh') + , Refl <- lemDropLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + , Refl <- lemTakeLenApp (Proxy @is) (Proxy @sh) (Proxy @sh') + = M_Nest (shxPermutePrefix perm sh) + (mtranspose perm arr) + + mconcat :: NonEmpty (Mixed (Nothing : sh) (Mixed sh' a)) -> Mixed (Nothing : sh) (Mixed sh' a) + mconcat l@(M_Nest sh1 _ :| _) = + let result = mconcat (fmap (\(M_Nest _ arr) -> arr) l) + in M_Nest (fst (shxSplitApp (Proxy @sh') (ssxFromShX sh1) (mshape result))) result + + mrnf (M_Nest sh arr) = rnf sh `seq` mrnf arr + + type ShapeTree (Mixed sh' a) = (IShX sh', ShapeTree a) + + mshapeTree :: Mixed sh' a -> ShapeTree (Mixed sh' a) + mshapeTree arr = (mshape arr, mshapeTree (mindex arr (ixxZero (ssxFromShX (mshape arr))))) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shxSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Nest _ arr) = marrayStrides arr + + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (shxAppend sh sh') idx val vecs + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartial sh12 idx (M_Nest _ arr) (MV_Nest sh' vecs) + | Refl <- lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial (shxAppend sh12 sh') idx arr vecs + + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest sh <$> mvecsFreeze (shxAppend sh sh') vecs + +instance (KnownShX sh', KnownElt a) => KnownElt (Mixed sh' a) where + memptyArrayUnsafe sh = M_Nest sh (memptyArrayUnsafe (shxAppend sh (shxCompleteZeros (knownShX @sh')))) + + mvecsUnsafeNew sh example + | shxSize sh' == 0 = mvecsNewEmpty (Proxy @(Mixed sh' a)) + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (shxAppend sh sh') (mindex example (ixxZero (ssxFromShX sh'))) + where + sh' = mshape example + + mvecsNewEmpty _ = MV_Nest (shxCompleteZeros (knownShX @sh')) <$> mvecsNewEmpty (Proxy @a) + + +memptyArray :: KnownElt a => IShX sh -> Mixed (Just 0 : sh) a +memptyArray sh = memptyArrayUnsafe (SKnown SNat :$% sh) + +mrank :: Elt a => Mixed sh a -> SNat (Rank sh) +mrank = shxRank . mshape + +-- | The total number of elements in the array. +msize :: Elt a => Mixed sh a -> Int +msize = shxSize . mshape + +-- | Create an array given a size and a function that computes the element at a +-- given index. +-- +-- __WARNING__: It is required that every @a@ returned by the argument to +-- 'mgenerate' has the same shape. For example, the following will throw a +-- runtime error: +-- +-- > foo :: Mixed [Nothing] (Mixed [Nothing] Double) +-- > foo = mgenerate (10 :.: ZIR) $ \(i :.: ZIR) -> +-- > mgenerate (i :.: ZIR) $ \(j :.: ZIR) -> +-- > ... +-- +-- because the size of the inner 'mgenerate' is not always the same (it depends +-- on @i@). Nested arrays in @ox-arrays@ are always stored fully flattened, so +-- the entire hierarchy (after distributing out tuples) must be a rectangular +-- array. The type of 'mgenerate' allows this requirement to be broken very +-- easily, hence the runtime check. +mgenerate :: forall sh a. KnownElt a => IShX sh -> (IIxX sh -> a) -> Mixed sh a +mgenerate sh f = case shxEnum sh of + [] -> memptyArrayUnsafe sh + firstidx : restidxs -> + let firstelem = f (ixxZero' sh) + shapetree = mshapeTree firstelem + in if mshapeTreeEmpty (Proxy @a) shapetree + then memptyArrayUnsafe sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this array copying inefficient. Should improve this. + forM_ restidxs $ \idx -> do + let val = f idx + when (not (mshapeTreeEq (Proxy @a) (mshapeTree val) shapetree)) $ + error "Data.Array.Nested mgenerate: generated values do not have equal shapes" + mvecsWrite sh idx val vecs + mvecsFreeze sh vecs + +msumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Mixed (n : sh) (Primitive a) -> Mixed sh (Primitive a) +msumOuter1P (M_Primitive (n :$% sh) arr) = + let nssh = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ZKX + in M_Primitive sh (X.sumOuter nssh (ssxFromShX sh) arr) + +msumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Mixed (n : sh) a -> Mixed sh a +msumOuter1 = fromPrimitive . msumOuter1P @sh @n @a . toPrimitive + +msumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a +msumAllPrim (toPrimitive -> M_Primitive sh arr) = X.sumFull (ssxFromShX sh) arr + +mappend :: forall n m sh a. Elt a + => Mixed (n : sh) a -> Mixed (m : sh) a -> Mixed (AddMaybe n m : sh) a +mappend arr1 arr2 = mlift2 (snm :!% ssh) f arr1 arr2 + where + sn :$% sh = mshape arr1 + sm :$% _ = mshape arr2 + ssh = ssxFromShX sh + snm :: SMayNat () SNat (AddMaybe n m) + snm = case (sn, sm) of + (SUnknown{}, _) -> SUnknown () + (SKnown{}, SUnknown{}) -> SUnknown () + (SKnown n, SKnown m) -> SKnown (snatPlus n m) + + f :: forall sh' b. Storable b + => StaticShX sh' -> XArray (n : sh ++ sh') b -> XArray (m : sh ++ sh') b -> XArray (AddMaybe n m : sh ++ sh') b + f ssh' = X.append (ssxAppend ssh ssh') + +mfromVectorP :: forall sh a. Storable a => IShX sh -> VS.Vector a -> Mixed sh (Primitive a) +mfromVectorP sh v = M_Primitive sh (X.fromVector sh v) + +mfromVector :: forall sh a. PrimElt a => IShX sh -> VS.Vector a -> Mixed sh a +mfromVector sh v = fromPrimitive (mfromVectorP sh v) + +mtoVectorP :: Storable a => Mixed sh (Primitive a) -> VS.Vector a +mtoVectorP (M_Primitive _ v) = X.toVector v + +mtoVector :: PrimElt a => Mixed sh a -> VS.Vector a +mtoVector arr = mtoVectorP (toPrimitive arr) + +mfromList1 :: Elt a => NonEmpty a -> Mixed '[Nothing] a +mfromList1 = mfromListOuter . fmap mscalar -- TODO: optimise? + +-- This forall is there so that a simple type application can constrain the +-- shape, in case the user wants to use OverloadedLists for the shape. +mfromListLinear :: forall sh a. Elt a => IShX sh -> NonEmpty a -> Mixed sh a +mfromListLinear sh l = mreshape sh (mfromList1 l) + +mfromListPrim :: PrimElt a => [a] -> Mixed '[Nothing] a +mfromListPrim l = + let ssh = SUnknown () :!% ZKX + xarr = X.fromList1 ssh l + in fromPrimitive $ M_Primitive (X.shape ssh xarr) xarr + +mfromListPrimLinear :: PrimElt a => IShX sh -> [a] -> Mixed sh a +mfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in fromPrimitive $ M_Primitive sh (X.reshape (SUnknown () :!% ZKX) sh xarr) + +mtoList :: Elt a => Mixed '[n] a -> [a] +mtoList = map munScalar . mtoListOuter + +mtoListLinear :: Elt a => Mixed sh a -> [a] +mtoListLinear arr = map (mindex arr) (shxEnum (mshape arr)) -- TODO: optimise + +munScalar :: Elt a => Mixed '[] a -> a +munScalar arr = mindex arr ZIX + +mnest :: forall sh sh' a. Elt a => StaticShX sh -> Mixed (sh ++ sh') a -> Mixed sh (Mixed sh' a) +mnest ssh arr = M_Nest (fst (shxSplitApp (Proxy @sh') ssh (mshape arr))) arr + +munNest :: Mixed sh (Mixed sh' a) -> Mixed (sh ++ sh') a +munNest (M_Nest _ arr) = arr + +-- | The arguments must have equal shapes. If they do not, an error is raised. +mzip :: (Elt a, Elt b) => Mixed sh a -> Mixed sh b -> Mixed sh (a, b) +mzip a b + | Just Refl <- shxEqual (mshape a) (mshape b) = M_Tup2 a b + | otherwise = error "mzip: unequal shapes" + +munzip :: Mixed sh (a, b) -> (Mixed sh a, Mixed sh b) +munzip (M_Tup2 a b) = (a, b) + +mrerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive b)) + -> Mixed (sh ++ sh1) (Primitive a) -> Mixed (sh ++ sh2) (Primitive b) +mrerankP ssh sh2 f (M_Primitive sh arr) = + let sh1 = shxDropSSX ssh sh + in M_Primitive (shxAppend (shxTakeSSX (Proxy @sh1) ssh sh) sh2) + (X.rerank ssh (ssxFromShX sh1) (ssxFromShX sh2) + (\a -> let M_Primitive _ r = f (M_Primitive sh1 a) in r) + arr) + +-- | See the caveats at 'Data.Array.XArray.rerank'. +mrerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => StaticShX sh -> IShX sh2 + -> (Mixed sh1 a -> Mixed sh2 b) + -> Mixed (sh ++ sh1) a -> Mixed (sh ++ sh2) b +mrerank ssh sh2 f (toPrimitive -> arr) = + fromPrimitive $ mrerankP ssh sh2 (toPrimitive . f . fromPrimitive) arr + +mreplicate :: forall sh sh' a. Elt a + => IShX sh -> Mixed sh' a -> Mixed (sh ++ sh') a +mreplicate sh arr = + let ssh' = ssxFromShX (mshape arr) + in mlift (ssxAppend (ssxFromShX sh) ssh') + (\(sshT :: StaticShX shT) -> + case lemAppAssoc (Proxy @sh) (Proxy @sh') (Proxy @shT) of + Refl -> X.replicate sh (ssxAppend ssh' sshT)) + arr + +mreplicateScalP :: forall sh a. Storable a => IShX sh -> a -> Mixed sh (Primitive a) +mreplicateScalP sh x = M_Primitive sh (X.replicateScal sh x) + +mreplicateScal :: forall sh a. PrimElt a + => IShX sh -> a -> Mixed sh a +mreplicateScal sh x = fromPrimitive (mreplicateScalP sh x) + +mslice :: Elt a => SNat i -> SNat n -> Mixed (Just (i + n + k) : sh) a -> Mixed (Just n : sh) a +mslice i n arr = + let _ :$% sh = mshape arr + in mlift (SKnown n :!% ssxFromShX sh) (\_ -> X.slice i n) arr + +msliceU :: Elt a => Int -> Int -> Mixed (Nothing : sh) a -> Mixed (Nothing : sh) a +msliceU i n arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.sliceU i n) arr + +mrev1 :: Elt a => Mixed (n : sh) a -> Mixed (n : sh) a +mrev1 arr = mlift (ssxFromShX (mshape arr)) (\_ -> X.rev1) arr + +mreshape :: forall sh sh' a. Elt a => IShX sh' -> Mixed sh a -> Mixed sh' a +mreshape sh' arr = + mlift (ssxFromShX sh') + (\sshIn -> X.reshapePartial (ssxFromShX (mshape arr)) sshIn sh') + arr + +mflatten :: Elt a => Mixed sh a -> Mixed '[Flatten sh] a +mflatten arr = mreshape (shxFlatten (mshape arr) :$% ZSX) arr + +miota :: (Enum a, PrimElt a) => SNat n -> Mixed '[Just n] a +miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) + +-- | Throws if the array is empty. +mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShX sh) (numEltMinIndex (shxRank sh) (fromO arr)) + +-- | Throws if the array is empty. +mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh +mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = + ixxFromList (ssxFromShX sh) (numEltMaxIndex (shxRank sh) (fromO arr)) + +mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a +mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primitive sh2 (XArray b)) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sh1 of + _ :$% _ + | sh1 == sh2 + , Refl <- lemRankApp (ssxInit (ssxFromShX sh1)) (ssxLast (ssxFromShX sh1) :!% ZKX) -> + fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) + | otherwise -> error $ "mdot1Inner: Unequal shapes (" ++ show sh1 ++ " and " ++ show sh2 ++ ")" + ZSX -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'mdot1Inner' if applicable. +mdot :: (PrimElt a, NumElt a) => Mixed sh a -> Mixed sh a -> a +mdot a b = + munScalar $ + mdot1Inner Proxy (fromPrimitive (mflatten (toPrimitive a))) + (fromPrimitive (mflatten (toPrimitive b))) + +mtoXArrayPrimP :: Mixed sh (Primitive a) -> (IShX sh, XArray sh a) +mtoXArrayPrimP (M_Primitive sh arr) = (sh, arr) + +mtoXArrayPrim :: PrimElt a => Mixed sh a -> (IShX sh, XArray sh a) +mtoXArrayPrim = mtoXArrayPrimP . toPrimitive + +mfromXArrayPrimP :: StaticShX sh -> XArray sh a -> Mixed sh (Primitive a) +mfromXArrayPrimP ssh arr = M_Primitive (X.shape ssh arr) arr + +mfromXArrayPrim :: PrimElt a => StaticShX sh -> XArray sh a -> Mixed sh a +mfromXArrayPrim = (fromPrimitive .) . mfromXArrayPrimP + +mliftPrim :: (PrimElt a, PrimElt b) + => (a -> b) + -> Mixed sh a -> Mixed sh b +mliftPrim f (toPrimitive -> M_Primitive sh (X.XArray arr)) = fromPrimitive $ M_Primitive sh (X.XArray (S.mapA f arr)) + +mliftPrim2 :: (PrimElt a, PrimElt b, PrimElt c) + => (a -> b -> c) + -> Mixed sh a -> Mixed sh b -> Mixed sh c +mliftPrim2 f (toPrimitive -> M_Primitive sh (X.XArray arr1)) (toPrimitive -> M_Primitive _ (X.XArray arr2)) = + fromPrimitive $ M_Primitive sh (X.XArray (S.zipWithA f arr1 arr2)) diff --git a/src/Data/Array/Nested/Mixed/Shape.hs b/src/Data/Array/Nested/Mixed/Shape.hs new file mode 100644 index 0000000..852dd5e --- /dev/null +++ b/src/Data/Array/Nested/Mixed/Shape.hs @@ -0,0 +1,644 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Mixed.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Bifunctor (first) +import Data.Coerce +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Functor.Product +import Data.Kind (Constraint, Type) +import Data.Monoid (Sum(..)) +import Data.Type.Equality +import GHC.Exts (withDict) +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Types + + +-- | The length of a type-level list. If the argument is a shape, then the +-- result is the rank of that shape. +type family Rank sh where + Rank '[] = 0 + Rank (_ : sh) = Rank sh + 1 + + +-- * Mixed lists + +type role ListX nominal representational +type ListX :: [Maybe Nat] -> (Maybe Nat -> Type) -> Type +data ListX sh f where + ZX :: ListX '[] f + (::%) :: f n -> ListX sh f -> ListX (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListX sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListX sh f) +infixr 3 ::% + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListX sh f) +#else +instance (forall n. Show (f n)) => Show (ListX sh f) where + showsPrec _ = listxShow shows +#endif + +instance (forall n. NFData (f n)) => NFData (ListX sh f) where + rnf ZX = () + rnf (x ::% l) = rnf x `seq` rnf l + +data UnconsListXRes f sh1 = + forall n sh. (n : sh ~ sh1) => UnconsListXRes (ListX sh f) (f n) +listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1) +listxUncons (i ::% shl') = Just (UnconsListXRes shl' i) +listxUncons ZX = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqType ZX ZX = Just Refl +listxEqType (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , Just Refl <- listxEqType sh sh' + = Just Refl +listxEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh') +listxEqual ZX ZX = Just Refl +listxEqual (n ::% sh) (m ::% sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listxEqual sh sh' + = Just Refl +listxEqual _ _ = Nothing + +listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g +listxFmap _ ZX = ZX +listxFmap f (x ::% xs) = f x ::% listxFmap f xs + +listxFold :: Monoid m => (forall n. f n -> m) -> ListX sh f -> m +listxFold _ ZX = mempty +listxFold f (x ::% xs) = f x <> listxFold f xs + +listxLength :: ListX sh f -> Int +listxLength = getSum . listxFold (\_ -> Sum 1) + +listxRank :: ListX sh f -> SNat (Rank sh) +listxRank ZX = SNat +listxRank (_ ::% l) | SNat <- listxRank l = SNat + +listxShow :: forall sh f. (forall n. f n -> ShowS) -> ListX sh f -> ShowS +listxShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListX sh' f -> ShowS + go _ ZX = id + go prefix (x ::% xs) = showString prefix . f x . go "," xs + +listxFromList :: StaticShX sh -> [i] -> ListX sh (Const i) +listxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [i] -> ListX sh' (Const i) + go ZKX [] = ZX + go (_ :!% sh) (i : is) = Const i ::% go sh is + go _ _ = error $ "listxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +listxToList :: ListX sh' (Const i) -> [i] +listxToList ZX = [] +listxToList (Const i ::% is) = i : listxToList is + +listxHead :: ListX (mn ': sh) f -> f mn +listxHead (i ::% _) = i + +listxTail :: ListX (n : sh) i -> ListX sh i +listxTail (_ ::% sh) = sh + +listxAppend :: ListX sh f -> ListX sh' f -> ListX (sh ++ sh') f +listxAppend ZX idx' = idx' +listxAppend (i ::% idx) idx' = i ::% listxAppend idx idx' + +listxDrop :: forall f g sh sh'. ListX sh g -> ListX (sh ++ sh') f -> ListX sh' f +listxDrop ZX long = long +listxDrop (_ ::% short) long = case long of _ ::% long' -> listxDrop short long' + +listxInit :: forall f n sh. ListX (n : sh) f -> ListX (Init (n : sh)) f +listxInit (i ::% sh@(_ ::% _)) = i ::% listxInit sh +listxInit (_ ::% ZX) = ZX + +listxLast :: forall f n sh. ListX (n : sh) f -> f (Last (n : sh)) +listxLast (_ ::% sh@(_ ::% _)) = listxLast sh +listxLast (x ::% ZX) = x + +listxZip :: ListX sh f -> ListX sh g -> ListX sh (Product f g) +listxZip ZX ZX = ZX +listxZip (i ::% irest) (j ::% jrest) = + Pair i j ::% listxZip irest jrest + +listxZipWith :: (forall a. f a -> g a -> h a) -> ListX sh f -> ListX sh g + -> ListX sh h +listxZipWith _ ZX ZX = ZX +listxZipWith f (i ::% is) (j ::% js) = + f i j ::% listxZipWith f is js + + +-- * Mixed indices + +-- | An index into a mixed-typed array. +type role IxX nominal representational +type IxX :: [Maybe Nat] -> Type -> Type +newtype IxX sh i = IxX (ListX sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIX :: forall sh i. () => sh ~ '[] => IxX sh i +pattern ZIX = IxX ZX + +pattern (:.%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => i -> IxX sh i -> IxX sh1 i +pattern i :.% shl <- IxX (listxUncons -> Just (UnconsListXRes (IxX -> shl) (getConst -> i))) + where i :.% IxX shl = IxX (Const i ::% shl) +infixr 3 :.% + +{-# COMPLETE ZIX, (:.%) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxX sh = IxX sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxX sh i) +#else +instance Show i => Show (IxX sh i) where + showsPrec _ (IxX l) = listxShow (shows . getConst) l +#endif + +instance Functor (IxX sh) where + fmap f (IxX l) = IxX (listxFmap (Const . f . getConst) l) + +instance Foldable (IxX sh) where + foldMap f (IxX l) = listxFold (f . getConst) l + +instance NFData i => NFData (IxX sh i) + +ixxLength :: IxX sh i -> Int +ixxLength (IxX l) = listxLength l + +ixxRank :: IxX sh i -> SNat (Rank sh) +ixxRank (IxX l) = listxRank l + +ixxZero :: StaticShX sh -> IIxX sh +ixxZero ZKX = ZIX +ixxZero (_ :!% ssh) = 0 :.% ixxZero ssh + +ixxZero' :: IShX sh -> IIxX sh +ixxZero' ZSX = ZIX +ixxZero' (_ :$% sh) = 0 :.% ixxZero' sh + +ixxFromList :: forall sh i. StaticShX sh -> [i] -> IxX sh i +ixxFromList = coerce (listxFromList @_ @i) + +ixxHead :: IxX (n : sh) i -> i +ixxHead (IxX list) = getConst (listxHead list) + +ixxTail :: IxX (n : sh) i -> IxX sh i +ixxTail (IxX list) = IxX (listxTail list) + +ixxAppend :: forall sh sh' i. IxX sh i -> IxX sh' i -> IxX (sh ++ sh') i +ixxAppend = coerce (listxAppend @_ @(Const i)) + +ixxDrop :: forall sh sh' i. IxX sh i -> IxX (sh ++ sh') i -> IxX sh' i +ixxDrop = coerce (listxDrop @(Const i) @(Const i)) + +ixxInit :: forall n sh i. IxX (n : sh) i -> IxX (Init (n : sh)) i +ixxInit = coerce (listxInit @(Const i)) + +ixxLast :: forall n sh i. IxX (n : sh) i -> i +ixxLast = coerce (listxLast @(Const i)) + +ixxCast :: StaticShX sh' -> IxX sh i -> IxX sh' i +ixxCast ZKX ZIX = ZIX +ixxCast (_ :!% sh) (i :.% idx) = i :.% ixxCast sh idx +ixxCast _ _ = error "ixxCast: ranks don't match" + +ixxZip :: IxX sh i -> IxX sh j -> IxX sh (i, j) +ixxZip ZIX ZIX = ZIX +ixxZip (i :.% is) (j :.% js) = (i, j) :.% ixxZip is js + +ixxZipWith :: (i -> j -> k) -> IxX sh i -> IxX sh j -> IxX sh k +ixxZipWith _ ZIX ZIX = ZIX +ixxZipWith f (i :.% is) (j :.% js) = f i j :.% ixxZipWith f is js + +ixxFromLinear :: IShX sh -> Int -> IIxX sh +ixxFromLinear = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "ixxFromLinear: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IShX sh -> Int -> (IIxX sh, Int) + go ZSX i = (ZIX, i) + go (n :$% sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` fromSMayNat' n + in (locali :.% idx, upi) + +ixxToLinear :: IShX sh -> IIxX sh -> Int +ixxToLinear = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IShX sh -> IIxX sh -> (Int, Int) + go ZSX ZIX = (0, 1) + go (n :$% sh) (i :.% ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, fromSMayNat' n * sz) + + +-- * Mixed shapes + +data SMayNat i f n where + SUnknown :: i -> SMayNat i f Nothing + SKnown :: f n -> SMayNat i f (Just n) +deriving instance (Show i, forall m. Show (f m)) => Show (SMayNat i f n) +deriving instance (Eq i, forall m. Eq (f m)) => Eq (SMayNat i f n) +deriving instance (Ord i, forall m. Ord (f m)) => Ord (SMayNat i f n) + +instance (NFData i, forall m. NFData (f m)) => NFData (SMayNat i f n) where + rnf (SUnknown i) = rnf i + rnf (SKnown x) = rnf x + +instance TestEquality f => TestEquality (SMayNat i f) where + testEquality SUnknown{} SUnknown{} = Just Refl + testEquality (SKnown n) (SKnown m) | Just Refl <- testEquality n m = Just Refl + testEquality _ _ = Nothing + +fromSMayNat :: (n ~ Nothing => i -> r) + -> (forall m. n ~ Just m => f m -> r) + -> SMayNat i f n -> r +fromSMayNat f _ (SUnknown i) = f i +fromSMayNat _ g (SKnown s) = g s + +fromSMayNat' :: SMayNat Int SNat n -> Int +fromSMayNat' = fromSMayNat id fromSNat' + +type family AddMaybe n m where + AddMaybe Nothing _ = Nothing + AddMaybe (Just _) Nothing = Nothing + AddMaybe (Just n) (Just m) = Just (n + m) + +smnAddMaybe :: SMayNat Int SNat n -> SMayNat Int SNat m -> SMayNat Int SNat (AddMaybe n m) +smnAddMaybe (SUnknown n) m = SUnknown (n + fromSMayNat' m) +smnAddMaybe (SKnown n) (SUnknown m) = SUnknown (fromSNat' n + m) +smnAddMaybe (SKnown n) (SKnown m) = SKnown (snatPlus n m) + + +-- | This is a newtype over 'ListX'. +type role ShX nominal representational +type ShX :: [Maybe Nat] -> Type -> Type +newtype ShX sh i = ShX (ListX sh (SMayNat i SNat)) + deriving (Eq, Ord, Generic) + +pattern ZSX :: forall sh i. () => sh ~ '[] => ShX sh i +pattern ZSX = ShX ZX + +pattern (:$%) + :: forall {sh1} {i}. + forall n sh. (n : sh ~ sh1) + => SMayNat i SNat n -> ShX sh i -> ShX sh1 i +pattern i :$% shl <- ShX (listxUncons -> Just (UnconsListXRes (ShX -> shl) i)) + where i :$% ShX shl = ShX (i ::% shl) +infixr 3 :$% + +{-# COMPLETE ZSX, (:$%) #-} + +type IShX sh = ShX sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShX sh i) +#else +instance Show i => Show (ShX sh i) where + showsPrec _ (ShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +#endif + +instance Functor (ShX sh) where + fmap f (ShX l) = ShX (listxFmap (fromSMayNat (SUnknown . f) SKnown) l) + +instance NFData i => NFData (ShX sh i) where + rnf (ShX ZX) = () + rnf (ShX (SUnknown i ::% l)) = rnf i `seq` rnf (ShX l) + rnf (ShX (SKnown SNat ::% l)) = rnf (ShX l) + +-- | This checks only whether the types are equal; unknown dimensions might +-- still differ. This corresponds to 'testEquality', except on the penultimate +-- type parameter. +shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqType ZSX ZSX = Just Refl +shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh') + | Just Refl <- shxEqType sh sh' + = Just Refl +shxEqType _ _ = Nothing + +-- | This checks whether all dimensions have the same value. This is more than +-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the +-- @some@ package (except on the penultimate type parameter). +shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh') +shxEqual ZSX ZSX = Just Refl +shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh') + | Just Refl <- sameNat n m + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual (SUnknown i :$% sh) (SUnknown j :$% sh') + | i == j + , Just Refl <- shxEqual sh sh' + = Just Refl +shxEqual _ _ = Nothing + +shxLength :: ShX sh i -> Int +shxLength (ShX l) = listxLength l + +shxRank :: ShX sh i -> SNat (Rank sh) +shxRank (ShX l) = listxRank l + +-- | The number of elements in an array described by this shape. +shxSize :: IShX sh -> Int +shxSize ZSX = 1 +shxSize (n :$% sh) = fromSMayNat' n * shxSize sh + +shxFromList :: StaticShX sh -> [Int] -> IShX sh +shxFromList topssh topl = go topssh topl + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (SKnown sn :!% sh) (i : is) + | i == fromSNat' sn = SKnown sn :$% go sh is + | otherwise = error $ "shxFromList: Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go (SUnknown () :!% sh) (i : is) = SUnknown i :$% go sh is + go _ _ = error $ "shxFromList: Mismatched list length (type says " + ++ show (ssxLength topssh) ++ ", list has length " + ++ show (length topl) ++ ")" + +shxToList :: IShX sh -> [Int] +shxToList ZSX = [] +shxToList (smn :$% sh) = fromSMayNat' smn : shxToList sh + +shxFromSSX :: StaticShX (MapJust sh) -> ShX (MapJust sh) i +shxFromSSX ZKX = ZSX +shxFromSSX (SKnown n :!% sh :: StaticShX (MapJust sh)) + | Refl <- lemMapJustCons @sh Refl + = SKnown n :$% shxFromSSX sh +shxFromSSX (SUnknown _ :!% _) = error "unreachable" + +-- | This may fail if @sh@ has @Nothing@s in it. +shxFromSSX2 :: StaticShX sh -> Maybe (ShX sh i) +shxFromSSX2 ZKX = Just ZSX +shxFromSSX2 (SKnown n :!% sh) = (SKnown n :$%) <$> shxFromSSX2 sh +shxFromSSX2 (SUnknown _ :!% _) = Nothing + +shxAppend :: forall sh sh' i. ShX sh i -> ShX sh' i -> ShX (sh ++ sh') i +shxAppend = coerce (listxAppend @_ @(SMayNat i SNat)) + +shxHead :: ShX (n : sh) i -> SMayNat i SNat n +shxHead (ShX list) = listxHead list + +shxTail :: ShX (n : sh) i -> ShX sh i +shxTail (ShX list) = ShX (listxTail list) + +shxDropSSX :: forall sh sh' i. StaticShX sh -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSSX = coerce (listxDrop @(SMayNat i SNat) @(SMayNat () SNat)) + +shxDropIx :: forall sh sh' i j. IxX sh j -> ShX (sh ++ sh') i -> ShX sh' i +shxDropIx = coerce (listxDrop @(SMayNat i SNat) @(Const j)) + +shxDropSh :: forall sh sh' i. ShX sh i -> ShX (sh ++ sh') i -> ShX sh' i +shxDropSh = coerce (listxDrop @(SMayNat i SNat) @(SMayNat i SNat)) + +shxInit :: forall n sh i. ShX (n : sh) i -> ShX (Init (n : sh)) i +shxInit = coerce (listxInit @(SMayNat i SNat)) + +shxLast :: forall n sh i. ShX (n : sh) i -> SMayNat i SNat (Last (n : sh)) +shxLast = coerce (listxLast @(SMayNat i SNat)) + +shxTakeSSX :: forall sh sh' i proxy. proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> ShX sh i +shxTakeSSX _ ZKX _ = ZSX +shxTakeSSX p (_ :!% ssh1) (n :$% sh) = n :$% shxTakeSSX p ssh1 sh + +shxZipWith :: (forall n. SMayNat i SNat n -> SMayNat j SNat n -> SMayNat k SNat n) + -> ShX sh i -> ShX sh j -> ShX sh k +shxZipWith _ ZSX ZSX = ZSX +shxZipWith f (i :$% is) (j :$% js) = f i j :$% shxZipWith f is js + +-- This is a weird operation, so it has a long name +shxCompleteZeros :: StaticShX sh -> IShX sh +shxCompleteZeros ZKX = ZSX +shxCompleteZeros (SUnknown () :!% ssh) = SUnknown 0 :$% shxCompleteZeros ssh +shxCompleteZeros (SKnown n :!% ssh) = SKnown n :$% shxCompleteZeros ssh + +shxSplitApp :: proxy sh' -> StaticShX sh -> ShX (sh ++ sh') i -> (ShX sh i, ShX sh' i) +shxSplitApp _ ZKX idx = (ZSX, idx) +shxSplitApp p (_ :!% ssh) (i :$% idx) = first (i :$%) (shxSplitApp p ssh idx) + +shxEnum :: IShX sh -> [IIxX sh] +shxEnum = \sh -> go sh id [] + where + go :: IShX sh -> (IIxX sh -> a) -> [a] -> [a] + go ZSX f = (f ZIX :) + go (n :$% sh) f = foldr (.) id [go sh (f . (i :.%)) | i <- [0 .. fromSMayNat' n - 1]] + +shxCast :: StaticShX sh' -> IShX sh -> Maybe (IShX sh') +shxCast ZKX ZSX = Just ZSX +shxCast (SKnown m :!% ssh) (SKnown n :$% sh) | Just Refl <- testEquality n m = (SKnown n :$%) <$> shxCast ssh sh +shxCast (SKnown m :!% ssh) (SUnknown n :$% sh) | n == fromSNat' m = (SKnown m :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SKnown n :$% sh) = (SUnknown (fromSNat' n) :$%) <$> shxCast ssh sh +shxCast (SUnknown () :!% ssh) (SUnknown n :$% sh) = (SUnknown n :$%) <$> shxCast ssh sh +shxCast _ _ = Nothing + +-- | Partial version of 'shxCast'. +shxCast' :: StaticShX sh' -> IShX sh -> IShX sh' +shxCast' ssh sh = case shxCast ssh sh of + Just sh' -> sh' + Nothing -> error $ "shxCast': Mismatch: (" ++ show sh ++ ") does not match (" ++ show ssh ++ ")" + + +-- * Static mixed shapes + +-- | The part of a shape that is statically known. (A newtype over 'ListX'.) +type StaticShX :: [Maybe Nat] -> Type +newtype StaticShX sh = StaticShX (ListX sh (SMayNat () SNat)) + deriving (Eq, Ord) + +pattern ZKX :: forall sh. () => sh ~ '[] => StaticShX sh +pattern ZKX = StaticShX ZX + +pattern (:!%) + :: forall {sh1}. + forall n sh. (n : sh ~ sh1) + => SMayNat () SNat n -> StaticShX sh -> StaticShX sh1 +pattern i :!% shl <- StaticShX (listxUncons -> Just (UnconsListXRes (StaticShX -> shl) i)) + where i :!% StaticShX shl = StaticShX (i ::% shl) +infixr 3 :!% + +{-# COMPLETE ZKX, (:!%) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (StaticShX sh) +#else +instance Show (StaticShX sh) where + showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l +#endif + +instance NFData (StaticShX sh) where + rnf (StaticShX ZX) = () + rnf (StaticShX (SUnknown () ::% l)) = rnf (StaticShX l) + rnf (StaticShX (SKnown SNat ::% l)) = rnf (StaticShX l) + +instance TestEquality StaticShX where + testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2 + +ssxLength :: StaticShX sh -> Int +ssxLength (StaticShX l) = listxLength l + +ssxRank :: StaticShX sh -> SNat (Rank sh) +ssxRank (StaticShX l) = listxRank l + +-- | @ssxEqType = 'testEquality'@. Provided for consistency. +ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh') +ssxEqType = testEquality + +ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh') +ssxAppend ZKX sh' = sh' +ssxAppend (n :!% sh) sh' = n :!% ssxAppend sh sh' + +ssxHead :: StaticShX (n : sh) -> SMayNat () SNat n +ssxHead (StaticShX list) = listxHead list + +ssxTail :: StaticShX (n : sh) -> StaticShX sh +ssxTail (_ :!% ssh) = ssh + +ssxDropSSX :: forall sh sh'. StaticShX sh -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSSX = coerce (listxDrop @(SMayNat () SNat) @(SMayNat () SNat)) + +ssxDropIx :: forall sh sh' i. IxX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropIx = coerce (listxDrop @(SMayNat () SNat) @(Const i)) + +ssxDropSh :: forall sh sh' i. ShX sh i -> StaticShX (sh ++ sh') -> StaticShX sh' +ssxDropSh = coerce (listxDrop @(SMayNat () SNat) @(SMayNat i SNat)) + +ssxInit :: forall n sh. StaticShX (n : sh) -> StaticShX (Init (n : sh)) +ssxInit = coerce (listxInit @(SMayNat () SNat)) + +ssxLast :: forall n sh. StaticShX (n : sh) -> SMayNat () SNat (Last (n : sh)) +ssxLast = coerce (listxLast @(SMayNat () SNat)) + +ssxReplicate :: SNat n -> StaticShX (Replicate n Nothing) +ssxReplicate SZ = ZKX +ssxReplicate (SS (n :: SNat n')) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n' + = SUnknown () :!% ssxReplicate n + +ssxIotaFrom :: StaticShX sh -> Int -> [Int] +ssxIotaFrom ZKX _ = [] +ssxIotaFrom (_ :!% ssh) i = i : ssxIotaFrom ssh (i+1) + +ssxFromShX :: ShX sh i -> StaticShX sh +ssxFromShX ZSX = ZKX +ssxFromShX (n :$% sh) = fromSMayNat (\_ -> SUnknown ()) SKnown n :!% ssxFromShX sh + +ssxFromSNat :: SNat n -> StaticShX (Replicate n Nothing) +ssxFromSNat SZ = ZKX +ssxFromSNat (SS (n :: SNat nm1)) | Refl <- lemReplicateSucc @(Nothing @Nat) @nm1 = SUnknown () :!% ssxFromSNat n + + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShX :: [Maybe Nat] -> Constraint +class KnownShX sh where knownShX :: StaticShX sh +instance KnownShX '[] where knownShX = ZKX +instance (KnownNat n, KnownShX sh) => KnownShX (Just n : sh) where knownShX = SKnown natSing :!% knownShX +instance KnownShX sh => KnownShX (Nothing : sh) where knownShX = SUnknown () :!% knownShX + +withKnownShX :: forall sh r. StaticShX sh -> (KnownShX sh => r) -> r +withKnownShX = withDict @(KnownShX sh) + + +-- * Flattening + +type Flatten sh = Flatten' 1 sh + +type family Flatten' acc sh where + Flatten' acc '[] = Just acc + Flatten' acc (Nothing : sh) = Nothing + Flatten' acc (Just n : sh) = Flatten' (acc * n) sh + +-- This function is currently unused +ssxFlatten :: StaticShX sh -> SMayNat () SNat (Flatten sh) +ssxFlatten = go (SNat @1) + where + go :: SNat acc -> StaticShX sh -> SMayNat () SNat (Flatten' acc sh) + go acc ZKX = SKnown acc + go _ (SUnknown () :!% _) = SUnknown () + go acc (SKnown sn :!% sh) = go (snatMul acc sn) sh + +shxFlatten :: IShX sh -> SMayNat Int SNat (Flatten sh) +shxFlatten = go (SNat @1) + where + go :: SNat acc -> IShX sh -> SMayNat Int SNat (Flatten' acc sh) + go acc ZSX = SKnown acc + go acc (SUnknown n :$% sh) = SUnknown (goUnknown (fromSNat' acc * n) sh) + go acc (SKnown sn :$% sh) = go (snatMul acc sn) sh + + goUnknown :: Int -> IShX sh -> Int + goUnknown acc ZSX = acc + goUnknown acc (SUnknown n :$% sh) = goUnknown (acc * n) sh + goUnknown acc (SKnown sn :$% sh) = goUnknown (acc * fromSNat' sn) sh + + +-- | Very untyped: only length is checked (at runtime). +instance KnownShX sh => IsList (ListX sh (Const i)) where + type Item (ListX sh (Const i)) = i + fromList = listxFromList (knownShX @sh) + toList = listxToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShX sh => IsList (IxX sh i) where + type Item (IxX sh i) = i + fromList = IxX . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and known dimensions are checked (at runtime). +instance KnownShX sh => IsList (ShX sh Int) where + type Item (ShX sh Int) = Int + fromList = shxFromList (knownShX @sh) + toList = shxToList diff --git a/src/Data/Array/Nested/Permutation.hs b/src/Data/Array/Nested/Permutation.hs new file mode 100644 index 0000000..03d1640 --- /dev/null +++ b/src/Data/Array/Nested/Permutation.hs @@ -0,0 +1,283 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Permutation where + +import Data.Coerce (coerce) +import Data.Functor.Const +import Data.List (sort) +import Data.Maybe (fromMaybe) +import Data.Proxy +import Data.Type.Bool +import Data.Type.Equality +import Data.Type.Ord +import GHC.Exts (withDict) +import GHC.TypeError +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Types + + +-- * Permutations + +-- | A "backward" permutation of a dimension list. The operation on the +-- dimension list is most similar to @backpermute@ in the @vector@ package; see +-- 'Permute' for code that implements this. +data Perm list where + PNil :: Perm '[] + PCons :: SNat a -> Perm l -> Perm (a : l) +infixr 5 `PCons` +deriving instance Show (Perm list) +deriving instance Eq (Perm list) + +instance TestEquality Perm where + testEquality PNil PNil = Just Refl + testEquality (x `PCons` xs) (y `PCons` ys) + | Just Refl <- testEquality x y + , Just Refl <- testEquality xs ys = Just Refl + testEquality _ _ = Nothing + +permRank :: Perm list -> SNat (Rank list) +permRank PNil = SNat +permRank (_ `PCons` l) | SNat <- permRank l = SNat + +permFromList :: [Int] -> (forall list. Perm list -> r) -> r +permFromList [] k = k PNil +permFromList (x : xs) k = withSomeSNat (fromIntegral x) $ \case + Just sn -> permFromList xs $ \list -> k (sn `PCons` list) + Nothing -> error $ "Data.Array.Mixed.permFromList: negative number in list: " ++ show x + +permToList :: Perm list -> [Natural] +permToList PNil = mempty +permToList (x `PCons` l) = TN.fromSNat x : permToList l + +permToList' :: Perm list -> [Int] +permToList' = map fromIntegral . permToList + +-- | When called as @permCheckPermutation p k@, if @p@ is a permutation of +-- @[0 .. 'length' ('permToList' p) - 1]@, @Just k@ is returned. If it isn't, +-- then @Nothing@ is returned. +permCheckPermutation :: forall r list. Perm list -> (IsPermutation list => r) -> Maybe r +permCheckPermutation = \p k -> + let n = permRank p + in case (provePerm1 (Proxy @list) n p, provePerm2 (SNat @0) n p) of + (Just Refl, Just Refl) -> Just k + _ -> Nothing + where + lemElemCount :: (0 <= n, Compare n m ~ LT) + => proxy n -> proxy m -> Elem n (Count 0 m) :~: True + lemElemCount _ _ = unsafeCoerceRefl + + lemCount :: (OrdCond (Compare i n) True False True ~ True) + => proxy i -> proxy n -> Count i n :~: i : Count (i + 1) n + lemCount _ _ = unsafeCoerceRefl + + lemElem :: Elem x ys ~ True => proxy x -> proxy' (y : ys) -> Elem x (y : ys) :~: True + lemElem _ _ = unsafeCoerceRefl + + provePerm1 :: Proxy isTop -> SNat (Rank isTop) -> Perm is' + -> Maybe (AllElem' is' (Count 0 (Rank isTop)) :~: True) + provePerm1 _ _ PNil = Just Refl + provePerm1 p rtop@SNat (PCons sn@SNat perm) + | Just Refl <- provePerm1 p rtop perm + = case (cmpNat (SNat @0) sn, cmpNat sn rtop) of + (LTI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + (EQI, LTI) | Refl <- lemElemCount sn rtop -> Just Refl + _ -> Nothing + | otherwise + = Nothing + + provePerm2 :: SNat i -> SNat n -> Perm is' + -> Maybe (AllElem' (Count i n) is' :~: True) + provePerm2 = \i@(SNat :: SNat i) n@SNat perm -> + case cmpNat i n of + EQI -> Just Refl + LTI | Refl <- lemCount i n + , Just Refl <- provePerm2 (SNat @(i + 1)) n perm + -> checkElem i perm + | otherwise -> Nothing + GTI -> error "unreachable" + where + checkElem :: SNat i -> Perm is' -> Maybe (Elem i is' :~: True) + checkElem _ PNil = Nothing + checkElem i@SNat (PCons k@SNat perm :: Perm is') = + case sameNat i k of + Just Refl -> Just Refl + Nothing | Just Refl <- checkElem i perm, Refl <- lemElem i (Proxy @is') -> Just Refl + | otherwise -> Nothing + +-- | Utility class for generating permutations from type class information. +class KnownPerm l where makePerm :: Perm l +instance KnownPerm '[] where makePerm = PNil +instance (KnownNat n, KnownPerm l) => KnownPerm (n : l) where makePerm = natSing `PCons` makePerm + +withKnownPerm :: forall l r. Perm l -> (KnownPerm l => r) -> r +withKnownPerm = withDict @(KnownPerm l) + +-- | Untyped permutations for ranked arrays +type PermR = [Int] + + +-- ** Applying permutations + +type family Elem x l where + Elem x '[] = 'False + Elem x (x : _) = 'True + Elem x (_ : ys) = Elem x ys + +type family AllElem' as bs where + AllElem' '[] bs = 'True + AllElem' (a : as) bs = Elem a bs && AllElem' as bs + +type AllElem as bs = Assert (AllElem' as bs) + (TypeError (Text "The elements of " :<>: ShowType as :<>: Text " are not all in " :<>: ShowType bs)) + +type family Count i n where + Count n n = '[] + Count i n = i : Count (i + 1) n + +type IsPermutation as = (AllElem as (Count 0 (Rank as)), AllElem (Count 0 (Rank as)) as) + +type family Index i sh where + Index 0 (n : sh) = n + Index i (_ : sh) = Index (i - 1) sh + +type family Permute is sh where + Permute '[] sh = '[] + Permute (i : is) sh = Index i sh : Permute is sh + +type PermutePrefix is sh = Permute is (TakeLen is sh) ++ DropLen is sh + +type family TakeLen ref l where + TakeLen '[] l = '[] + TakeLen (_ : ref) (x : xs) = x : TakeLen ref xs + +type family DropLen ref l where + DropLen '[] l = l + DropLen (_ : ref) (_ : xs) = DropLen ref xs + +listxTakeLen :: forall f is sh. Perm is -> ListX sh f -> ListX (TakeLen is sh) f +listxTakeLen PNil _ = ZX +listxTakeLen (_ `PCons` is) (n ::% sh) = n ::% listxTakeLen is sh +listxTakeLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxDropLen :: forall f is sh. Perm is -> ListX sh f -> ListX (DropLen is sh) f +listxDropLen PNil sh = sh +listxDropLen (_ `PCons` is) (_ ::% sh) = listxDropLen is sh +listxDropLen (_ `PCons` _) ZX = error "Permutation longer than shape" + +listxPermute :: forall f is sh. Perm is -> ListX sh f -> ListX (Permute is sh) f +listxPermute PNil _ = ZX +listxPermute (i `PCons` (is :: Perm is')) (sh :: ListX sh f) = + listxIndex (Proxy @is') (Proxy @sh) i sh ::% listxPermute is sh + +listxIndex :: forall f is shT i sh. Proxy is -> Proxy shT -> SNat i -> ListX sh f -> f (Index i sh) +listxIndex _ _ SZ (n ::% _) = n +listxIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::% (sh :: ListX sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listxIndex p pT i sh +listxIndex _ _ _ ZX = error "Index into empty shape" + +listxPermutePrefix :: forall f is sh. Perm is -> ListX sh f -> ListX (PermutePrefix is sh) f +listxPermutePrefix perm sh = listxAppend (listxPermute perm (listxTakeLen perm sh)) (listxDropLen perm sh) + +ixxPermutePrefix :: forall i is sh. Perm is -> IxX sh i -> IxX (PermutePrefix is sh) i +ixxPermutePrefix = coerce (listxPermutePrefix @(Const i)) + +ssxTakeLen :: forall is sh. Perm is -> StaticShX sh -> StaticShX (TakeLen is sh) +ssxTakeLen = coerce (listxTakeLen @(SMayNat () SNat)) + +ssxDropLen :: Perm is -> StaticShX sh -> StaticShX (DropLen is sh) +ssxDropLen = coerce (listxDropLen @(SMayNat () SNat)) + +ssxPermute :: Perm is -> StaticShX sh -> StaticShX (Permute is sh) +ssxPermute = coerce (listxPermute @(SMayNat () SNat)) + +ssxIndex :: Proxy is -> Proxy shT -> SNat i -> StaticShX sh -> SMayNat () SNat (Index i sh) +ssxIndex p1 p2 = coerce (listxIndex @(SMayNat () SNat) p1 p2) + +ssxPermutePrefix :: Perm is -> StaticShX sh -> StaticShX (PermutePrefix is sh) +ssxPermutePrefix = coerce (listxPermutePrefix @(SMayNat () SNat)) + +shxPermutePrefix :: Perm is -> IShX sh -> IShX (PermutePrefix is sh) +shxPermutePrefix = coerce (listxPermutePrefix @(SMayNat Int SNat)) + + +-- * Operations on permutations + +permInverse :: Perm is + -> (forall is'. + IsPermutation is' + => Perm is' + -> (forall sh. Rank sh ~ Rank is => StaticShX sh -> Permute is' (Permute is sh) :~: sh) + -> r) + -> r +permInverse = \perm k -> + genPerm perm $ \(invperm :: Perm is') -> + fromMaybe + (error $ "permInverse: did not generate permutation? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm) + (permCheckPermutation invperm + (k invperm + (\ssh -> case permCheckInverse perm invperm ssh of + Just eq -> eq + Nothing -> error $ "permInverse: did not generate inverse? perm = " ++ show perm + ++ " ; invperm = " ++ show invperm))) + where + genPerm :: Perm is -> (forall is'. Perm is' -> r) -> r + genPerm perm = + let permList = permToList' perm + in toHList $ map snd (sort (zip permList [0..])) + where + toHList :: [Natural] -> (forall is'. Perm is' -> r) -> r + toHList [] k = k PNil + toHList (n : ns) k = toHList ns $ \l -> TN.withSomeSNat n $ \sn -> k (PCons sn l) + + permCheckInverse :: Perm is -> Perm is' -> StaticShX sh + -> Maybe (Permute is' (Permute is sh) :~: sh) + permCheckInverse perm perminv ssh = + ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh + +type family MapSucc is where + MapSucc '[] = '[] + MapSucc (i : is) = i + 1 : MapSucc is + +permShift1 :: Perm l -> Perm (0 : MapSucc l) +permShift1 = (SNat @0 `PCons`) . permMapSucc + where + permMapSucc :: Perm l -> Perm (MapSucc l) + permMapSucc PNil = PNil + permMapSucc ((SNat :: SNat i) `PCons` ns) = SNat @(i + 1) `PCons` permMapSucc ns + + +-- * Lemmas + +lemRankPermute :: Proxy sh -> Perm is -> Rank (Permute is sh) :~: Rank is +lemRankPermute _ PNil = Refl +lemRankPermute p (_ `PCons` is) | Refl <- lemRankPermute p is = Refl + +lemRankDropLen :: forall is sh. (Rank is <= Rank sh) + => StaticShX sh -> Perm is -> Rank (DropLen is sh) :~: Rank sh - Rank is +lemRankDropLen ZKX PNil = Refl +lemRankDropLen (_ :!% sh) (_ `PCons` is) | Refl <- lemRankDropLen sh is = Refl +lemRankDropLen (_ :!% _) PNil = Refl +lemRankDropLen ZKX (_ `PCons` _) = error "1 <= 0" + +lemIndexSucc :: Proxy i -> Proxy a -> Proxy l + -> Index (i + 1) (a : l) :~: Index i l +lemIndexSucc _ _ _ = unsafeCoerceRefl diff --git a/src/Data/Array/Nested/Ranked.hs b/src/Data/Array/Nested/Ranked.hs new file mode 100644 index 0000000..9778c54 --- /dev/null +++ b/src/Data/Array/Nested/Ranked.hs @@ -0,0 +1,323 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked ( + Ranked(Ranked), + rquotArray, rremArray, ratan2Array, + rshape, rrank, + module Data.Array.Nested.Ranked, + liftRanked1, liftRanked2, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.RankedS qualified as S +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Ranked.Base +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) +import Data.Array.XArray qualified as X + + +remptyArray :: KnownElt a => Ranked 1 a +remptyArray = mtoRanked (memptyArray ZSX) + +-- | The total number of elements in the array. +rsize :: Elt a => Ranked n a -> Int +rsize = shrSize . rshape + +rindex :: Elt a => Ranked n a -> IIxR n -> a +rindex (Ranked arr) idx = mindex arr (ixxFromIxR idx) + +rindexPartial :: forall n m a. Elt a => Ranked (n + m) a -> IIxR n -> Ranked m a +rindexPartial (Ranked arr) idx = + Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (castWith (subst2 (lemReplicatePlusApp (ixrRank idx) (Proxy @m) (Proxy @Nothing))) arr) + (ixxFromIxR idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +rgenerate :: forall n a. KnownElt a => IShR n -> (IIxR n -> a) -> Ranked n a +rgenerate sh f + | sn@SNat <- shrRank sh + , Dict <- lemKnownReplicate sn + , Refl <- lemRankReplicate sn + = Ranked (mgenerate (shxFromShR sh) (f . ixrFromIxX)) + +-- | See the documentation of 'mlift'. +rlift :: forall n1 n2 a. Elt a + => SNat n2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a +rlift sn2 f (Ranked arr) = Ranked (mlift (ssxFromSNat sn2) f arr) + +-- | See the documentation of 'mlift2'. +rlift2 :: forall n1 n2 n3 a. Elt a + => SNat n3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b -> XArray (Replicate n3 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a -> Ranked n3 a +rlift2 sn3 f (Ranked arr1) (Ranked arr2) = Ranked (mlift2 (ssxFromSNat sn3) f arr1 arr2) + +rsumOuter1P :: forall n a. + (Storable a, NumElt a) + => Ranked (n + 1) (Primitive a) -> Ranked n (Primitive a) +rsumOuter1P (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (msumOuter1P arr) + +rsumOuter1 :: forall n a. (NumElt a, PrimElt a) + => Ranked (n + 1) a -> Ranked n a +rsumOuter1 = rfromPrimitive . rsumOuter1P . rtoPrimitive + +rsumAllPrim :: (PrimElt a, NumElt a) => Ranked n a -> a +rsumAllPrim (Ranked arr) = msumAllPrim arr + +rtranspose :: forall n a. Elt a => PermR -> Ranked n a -> Ranked n a +rtranspose perm arr + | sn@SNat <- rrank arr + , Dict <- lemKnownReplicate sn + , length perm <= fromIntegral (natVal (Proxy @n)) + = rlift sn + (\ssh' -> X.transposeUntyped (natSing @n) ssh' perm) + arr + | otherwise + = error "Data.Array.Nested.rtranspose: Permutation longer than rank of array" + +rconcat :: forall n a. Elt a => NonEmpty (Ranked (n + 1) a) -> Ranked (n + 1) a +rconcat + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce mconcat + +rappend :: forall n a. Elt a + => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked (n + 1) a +rappend arr1 arr2 + | sn@SNat <- rrank arr1 + , Dict <- lemKnownReplicate sn + , Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mappend @Nothing @Nothing @(Replicate n Nothing)) + arr1 arr2 + +rscalar :: Elt a => a -> Ranked 0 a +rscalar x = Ranked (mscalar x) + +rfromVectorP :: forall n a. Storable a => IShR n -> VS.Vector a -> Ranked n (Primitive a) +rfromVectorP sh v + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mfromVectorP (shxFromShR sh) v) + +rfromVector :: forall n a. PrimElt a => IShR n -> VS.Vector a -> Ranked n a +rfromVector sh v = rfromPrimitive (rfromVectorP sh v) + +rtoVectorP :: Storable a => Ranked n (Primitive a) -> VS.Vector a +rtoVectorP = coerce mtoVectorP + +rtoVector :: PrimElt a => Ranked n a -> VS.Vector a +rtoVector = coerce mtoVector + +rfromList1 :: Elt a => NonEmpty a -> Ranked 1 a +rfromList1 l = Ranked (mfromList1 l) + +rfromListOuter :: forall n a. Elt a => NonEmpty (Ranked n a) -> Ranked (n + 1) a +rfromListOuter l + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = Ranked (mfromListOuter (coerce l :: NonEmpty (Mixed (Replicate n Nothing) a))) + +rfromListLinear :: forall n a. Elt a => IShR n -> NonEmpty a -> Ranked n a +rfromListLinear sh l = rreshape sh (rfromList1 l) + +rfromListPrim :: PrimElt a => [a] -> Ranked 1 a +rfromListPrim l = Ranked (mfromListPrim l) + +rfromListPrimLinear :: PrimElt a => IShR n -> [a] -> Ranked n a +rfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Ranked $ fromPrimitive $ M_Primitive (shxFromShR sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShR sh) xarr) + +rtoList :: Elt a => Ranked 1 a -> [a] +rtoList = map runScalar . rtoListOuter + +rtoListOuter :: forall n a. Elt a => Ranked (n + 1) a -> [Ranked n a] +rtoListOuter (Ranked arr) + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = coerce (mtoListOuter @a @Nothing @(Replicate n Nothing) arr) + +rtoListLinear :: Elt a => Ranked n a -> [a] +rtoListLinear (Ranked arr) = mtoListLinear arr + +rfromOrthotope :: PrimElt a => SNat n -> S.Array n a -> Ranked n a +rfromOrthotope sn arr + | Refl <- lemRankReplicate sn + = let xarr = XArray arr + in Ranked (fromPrimitive (M_Primitive (X.shape (ssxFromSNat sn) xarr) xarr)) + +rtoOrthotope :: PrimElt a => Ranked n a -> S.Array n a +rtoOrthotope (rtoPrimitive -> Ranked (M_Primitive sh (XArray arr))) + | Refl <- lemRankReplicate (shrRank $ shrFromShX2 sh) + = arr + +runScalar :: Elt a => Ranked 0 a -> a +runScalar arr = rindex arr ZIR + +rnest :: forall n m a. Elt a => SNat n -> Ranked (n + m) a -> Ranked n (Ranked m a) +rnest n arr + | Refl <- lemReplicatePlusApp n (Proxy @m) (Proxy @(Nothing @Nat)) + = coerce (mnest (ssxFromSNat n) (coerce arr)) + +runNest :: forall n m a. Elt a => Ranked n (Ranked m a) -> Ranked (n + m) a +runNest rarr@(Ranked (M_Ranked (M_Nest _ arr))) + | Refl <- lemReplicatePlusApp (rrank rarr) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked arr + +rzip :: (Elt a, Elt b) => Ranked n a -> Ranked n b -> Ranked n (a, b) +rzip = coerce mzip + +runzip :: Ranked n (a, b) -> (Ranked n a, Ranked n b) +runzip = coerce munzip + +rrerankP :: forall n1 n2 n a b. (Storable a, Storable b) + => SNat n -> IShR n2 + -> (Ranked n1 (Primitive a) -> Ranked n2 (Primitive b)) + -> Ranked (n + n1) (Primitive a) -> Ranked (n + n2) (Primitive b) +rrerankP sn sh2 f (Ranked arr) + | Refl <- lemReplicatePlusApp sn (Proxy @n1) (Proxy @(Nothing @Nat)) + , Refl <- lemReplicatePlusApp sn (Proxy @n2) (Proxy @(Nothing @Nat)) + = Ranked (mrerankP (ssxFromSNat sn) (shxFromShR sh2) + (\a -> let Ranked r = f (Ranked a) in r) + arr) + +-- | If there is a zero-sized dimension in the @n@-prefix of the shape of the +-- input array, then there is no way to deduce the full shape of the output +-- array (more precisely, the @n2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the @n2@ part of the output shape with zeros. +-- +-- For example, if: +-- +-- @ +-- arr :: Ranked 5 Int -- of shape [3, 0, 4, 2, 21] +-- f :: Ranked 2 Int -> Ranked 3 Float +-- @ +-- +-- then: +-- +-- @ +-- rrerank _ _ _ f arr :: Ranked 6 Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 0, 0, 0]@. Note that the +-- "reranked" part (the last 3 entries) are zero; we don't know if @f@ intended +-- to return an array with shape all-0 here (it probably didn't), but there is +-- no better number to put here absent a subarray of the input to pass to @f@. +rrerank :: forall n1 n2 n a b. (PrimElt a, PrimElt b) + => SNat n -> IShR n2 + -> (Ranked n1 a -> Ranked n2 b) + -> Ranked (n + n1) a -> Ranked (n + n2) b +rrerank sn sh2 f (rtoPrimitive -> arr) = + rfromPrimitive $ rrerankP sn sh2 (rtoPrimitive . f . rfromPrimitive) arr + +rreplicate :: forall n m a. Elt a + => IShR n -> Ranked m a -> Ranked (n + m) a +rreplicate sh (Ranked arr) + | Refl <- lemReplicatePlusApp (shrRank sh) (Proxy @m) (Proxy @(Nothing @Nat)) + = Ranked (mreplicate (shxFromShR sh) arr) + +rreplicateScalP :: forall n a. Storable a => IShR n -> a -> Ranked n (Primitive a) +rreplicateScalP sh x + | Dict <- lemKnownReplicate (shrRank sh) + = Ranked (mreplicateScalP (shxFromShR sh) x) + +rreplicateScal :: forall n a. PrimElt a + => IShR n -> a -> Ranked n a +rreplicateScal sh x = rfromPrimitive (rreplicateScalP sh x) + +rslice :: forall n a. Elt a => Int -> Int -> Ranked (n + 1) a -> Ranked (n + 1) a +rslice i n arr + | Refl <- lemReplicateSucc @(Nothing @Nat) @n + = rlift (rrank arr) + (\_ -> X.sliceU i n) + arr + +rrev1 :: forall n a. Elt a => Ranked (n + 1) a -> Ranked (n + 1) a +rrev1 arr = + rlift (rrank arr) + (\(_ :: StaticShX sh') -> + case lemReplicateSucc @(Nothing @Nat) @n of + Refl -> X.rev1 @Nothing @(Replicate n Nothing ++ sh')) + arr + +rreshape :: forall n n' a. Elt a + => IShR n' -> Ranked n a -> Ranked n' a +rreshape sh' rarr@(Ranked arr) + | Dict <- lemKnownReplicate (rrank rarr) + , Dict <- lemKnownReplicate (shrRank sh') + = Ranked (mreshape (shxFromShR sh') arr) + +rflatten :: Elt a => Ranked n a -> Ranked 1 a +rflatten (Ranked arr) = mtoRanked (mflatten arr) + +riota :: (Enum a, PrimElt a) => Int -> Ranked 1 a +riota n = TN.withSomeSNat (fromIntegral n) $ mtoRanked . miota + +-- | Throws if the array is empty. +rminIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rminIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixrFromIxX (mminIndexPrim arr) + +-- | Throws if the array is empty. +rmaxIndexPrim :: (PrimElt a, NumElt a) => Ranked n a -> IIxR n +rmaxIndexPrim rarr@(Ranked arr) + | Refl <- lemRankReplicate (rrank (rtoPrimitive rarr)) + = ixrFromIxX (mmaxIndexPrim arr) + +rdot1Inner :: forall n a. (PrimElt a, NumElt a) => Ranked (n + 1) a -> Ranked (n + 1) a -> Ranked n a +rdot1Inner arr1 arr2 + | SNat <- rrank arr1 + , Refl <- lemReplicatePlusApp (SNat @n) (Proxy @1) (Proxy @(Nothing @Nat)) + = coerce (mdot1Inner (Proxy @(Nothing @Nat))) arr1 arr2 + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'rdot1Inner' if applicable. +rdot :: (PrimElt a, NumElt a) => Ranked n a -> Ranked n a -> a +rdot = coerce mdot + +rtoXArrayPrimP :: Ranked n (Primitive a) -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrimP (Ranked arr) = first shrFromShX2 (mtoXArrayPrimP arr) + +rtoXArrayPrim :: PrimElt a => Ranked n a -> (IShR n, XArray (Replicate n Nothing) a) +rtoXArrayPrim (Ranked arr) = first shrFromShX2 (mtoXArrayPrim arr) + +rfromXArrayPrimP :: SNat n -> XArray (Replicate n Nothing) a -> Ranked n (Primitive a) +rfromXArrayPrimP sn arr = Ranked (mfromXArrayPrimP (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) + +rfromXArrayPrim :: PrimElt a => SNat n -> XArray (Replicate n Nothing) a -> Ranked n a +rfromXArrayPrim sn arr = Ranked (mfromXArrayPrim (ssxFromShX (X.shape (ssxFromSNat sn) arr)) arr) + +rfromPrimitive :: PrimElt a => Ranked n (Primitive a) -> Ranked n a +rfromPrimitive (Ranked arr) = Ranked (fromPrimitive arr) + +rtoPrimitive :: PrimElt a => Ranked n a -> Ranked n (Primitive a) +rtoPrimitive (Ranked arr) = Ranked (toPrimitive arr) diff --git a/src/Data/Array/Nested/Ranked/Base.hs b/src/Data/Array/Nested/Ranked/Base.hs new file mode 100644 index 0000000..babc809 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Base.hs @@ -0,0 +1,268 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Ranked.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +import Data.Foldable (toList) +#endif + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Ranked.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray(..)) + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (Replicate n Nothing) a) => Show (Ranked n a) +#endif +deriving instance Eq (Mixed (Replicate n Nothing) a) => Eq (Ranked n a) +deriving instance Ord (Mixed (Replicate n Nothing) a) => Ord (Ranked n a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Ranked n a) where + showsPrec d arr@(Ranked marr) = + let sh = show (toList (rshape arr)) + in showsMixedArray ("rfromListLinear " ++ sh) ("rreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Ranked n a) where + rnf (Ranked arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (Replicate n Nothing) a)) => Show (Mixed sh (Ranked n a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (Replicate n Nothing) a)) => Eq (Mixed sh (Ranked n a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance Elt a => Elt (Ranked n a) where + mshape (M_Ranked arr) = mshape arr + mindex (M_Ranked arr) i = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IIxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i = + coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + mscalar (Ranked x) = M_Ranked (M_Nest ZSX x) + + mfromListOuter :: forall sh. NonEmpty (Mixed sh (Ranked n a)) -> Mixed (Nothing : sh) (Ranked n a) + mfromListOuter l = M_Ranked (mfromListOuter (coerce l)) + + mtoListOuter :: forall m sh. Mixed (m : sh) (Ranked n a) -> [Mixed sh (Ranked n a)] + mtoListOuter (M_Ranked arr) = + coerce @[Mixed sh (Mixed (Replicate n 'Nothing) a)] @[Mixed sh (Ranked n a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) + mlift ssh2 f (M_Ranked arr) = + coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) -> Mixed sh3 (Ranked n a) + mlift2 ssh3 f (M_Ranked arr1) (M_Ranked arr2) = + coerce @(Mixed sh3 (Mixed (Replicate n Nothing) a)) @(Mixed sh3 (Ranked n a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Ranked n a)) -> NonEmpty (Mixed sh2 (Ranked n a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (Replicate n Nothing) a))) + @(NonEmpty (Mixed sh2 (Ranked n a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Ranked arr) = M_Ranked (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Ranked arr) = M_Ranked (mtranspose perm arr) + + mconcat l = M_Ranked (mconcat (coerce l)) + + mrnf (M_Ranked arr) = mrnf arr + + type ShapeTree (Ranked n a) = (IShR n, ShapeTree a) + + mshapeTree (Ranked arr) = first shrFromShX2 (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shrSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Ranked arr) = marrayStrides arr + + mvecsWrite :: forall sh s. IShX sh -> IIxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. + IShX (sh ++ sh') -> IIxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IShX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + +instance (KnownNat n, KnownElt a) => KnownElt (Ranked n a) where + memptyArrayUnsafe :: forall sh. IShX sh -> Mixed sh (Ranked n a) + memptyArrayUnsafe i + | Dict <- lemKnownReplicate (SNat @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownReplicate (SNat @n) + = MV_Ranked <$> mvecsNewEmpty (Proxy @(Mixed (Replicate n Nothing) a)) + + +liftRanked1 :: forall n a b. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b) + -> Ranked n a -> Ranked n b +liftRanked1 = coerce + +liftRanked2 :: forall n a b c. + (Mixed (Replicate n Nothing) a -> Mixed (Replicate n Nothing) b -> Mixed (Replicate n Nothing) c) + -> Ranked n a -> Ranked n b -> Ranked n c +liftRanked2 = coerce + +instance (NumElt a, PrimElt a) => Num (Ranked n a) where + (+) = liftRanked2 (+) + (-) = liftRanked2 (-) + (*) = liftRanked2 (*) + negate = liftRanked1 negate + abs = liftRanked1 abs + signum = liftRanked1 signum + fromInteger = error "Data.Array.Nested(Ranked).fromInteger: No singletons available, use explicit rreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Ranked n a) where + fromRational _ = error "Data.Array.Nested(Ranked).fromRational: No singletons available, use explicit rreplicateScal" + recip = liftRanked1 recip + (/) = liftRanked2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Ranked n a) where + pi = error "Data.Array.Nested(Ranked).pi: No singletons available, use explicit rreplicateScal" + exp = liftRanked1 exp + log = liftRanked1 log + sqrt = liftRanked1 sqrt + (**) = liftRanked2 (**) + logBase = liftRanked2 logBase + sin = liftRanked1 sin + cos = liftRanked1 cos + tan = liftRanked1 tan + asin = liftRanked1 asin + acos = liftRanked1 acos + atan = liftRanked1 atan + sinh = liftRanked1 sinh + cosh = liftRanked1 cosh + tanh = liftRanked1 tanh + asinh = liftRanked1 asinh + acosh = liftRanked1 acosh + atanh = liftRanked1 atanh + log1p = liftRanked1 GHC.Float.log1p + expm1 = liftRanked1 GHC.Float.expm1 + log1pexp = liftRanked1 GHC.Float.log1pexp + log1mexp = liftRanked1 GHC.Float.log1mexp + +rquotArray, rremArray :: (IntElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +rquotArray = liftRanked2 mquotArray +rremArray = liftRanked2 mremArray + +ratan2Array :: (FloatElt a, PrimElt a) => Ranked n a -> Ranked n a -> Ranked n a +ratan2Array = liftRanked2 matan2Array + + +rshape :: Elt a => Ranked n a -> IShR n +rshape (Ranked arr) = shrFromShX2 (mshape arr) + +rrank :: Elt a => Ranked n a -> SNat n +rrank = shrRank . rshape + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shrFromShX :: forall sh. IShX sh -> IShR (Rank sh) +shrFromShX ZSX = ZSR +shrFromShX (n :$% idx) = fromSMayNat' n :$: shrFromShX idx + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +-- | Convenience wrapper around 'shrFromShX' that applies 'lemRankReplicate'. +shrFromShX2 :: forall n. IShX (Replicate n Nothing) -> IShR n +shrFromShX2 sh + | Refl <- lemRankReplicate (Proxy @n) + = shrFromShX sh diff --git a/src/Data/Array/Nested/Ranked/Shape.hs b/src/Data/Array/Nested/Ranked/Shape.hs new file mode 100644 index 0000000..8b670e5 --- /dev/null +++ b/src/Data/Array/Nested/Ranked/Shape.hs @@ -0,0 +1,369 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Ranked.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Kind (Type) +import Data.Proxy +import Data.Type.Equality +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits +import GHC.TypeNats qualified as TN + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Types + + +-- * Ranked lists + +type role ListR nominal representational +type ListR :: Nat -> Type -> Type +data ListR n i where + ZR :: ListR 0 i + (:::) :: forall n {i}. i -> ListR n i -> ListR (n + 1) i +deriving instance Eq i => Eq (ListR n i) +deriving instance Ord i => Ord (ListR n i) +deriving instance Functor (ListR n) +deriving instance Foldable (ListR n) +infixr 3 ::: + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ListR n i) +#else +instance Show i => Show (ListR n i) where + showsPrec _ = listrShow shows +#endif + +instance NFData i => NFData (ListR n i) where + rnf ZR = () + rnf (x ::: l) = rnf x `seq` rnf l + +data UnconsListRRes i n1 = + forall n. (n + 1 ~ n1) => UnconsListRRes (ListR n i) i +listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1) +listrUncons (i ::: sh') = Just (UnconsListRRes sh' i) +listrUncons ZR = Nothing + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqRank ZR ZR = Just Refl +listrEqRank (_ ::: sh) (_ ::: sh') + | Just Refl <- listrEqRank sh sh' + = Just Refl +listrEqRank _ _ = Nothing + +-- | This compares the lists for value equality. +listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n') +listrEqual ZR ZR = Just Refl +listrEqual (i ::: sh) (j ::: sh') + | Just Refl <- listrEqual sh sh' + , i == j + = Just Refl +listrEqual _ _ = Nothing + +listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS +listrShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListR n' i -> ShowS + go _ ZR = id + go prefix (x ::: xs) = showString prefix . f x . go "," xs + +listrLength :: ListR n i -> Int +listrLength = length + +listrRank :: ListR n i -> SNat n +listrRank ZR = SNat +listrRank (_ ::: sh) = snatSucc (listrRank sh) + +listrAppend :: ListR n i -> ListR m i -> ListR (n + m) i +listrAppend ZR sh = sh +listrAppend (x ::: xs) sh = x ::: listrAppend xs sh + +listrFromList :: [i] -> (forall n. ListR n i -> r) -> r +listrFromList [] k = k ZR +listrFromList (x : xs) k = listrFromList xs $ \l -> k (x ::: l) + +listrHead :: ListR (n + 1) i -> i +listrHead (i ::: _) = i +listrHead ZR = error "unreachable" + +listrTail :: ListR (n + 1) i -> ListR n i +listrTail (_ ::: sh) = sh +listrTail ZR = error "unreachable" + +listrInit :: ListR (n + 1) i -> ListR n i +listrInit (n ::: sh@(_ ::: _)) = n ::: listrInit sh +listrInit (_ ::: ZR) = ZR +listrInit ZR = error "unreachable" + +listrLast :: ListR (n + 1) i -> i +listrLast (_ ::: sh@(_ ::: _)) = listrLast sh +listrLast (n ::: ZR) = n +listrLast ZR = error "unreachable" + +-- | Performs a runtime check that the lengths are identical. +listrCast :: SNat n' -> ListR n i -> ListR n' i +listrCast = listrCastWithName "listrCast" + +listrIndex :: forall k n i. (k + 1 <= n) => SNat k -> ListR n i -> i +listrIndex SZ (x ::: _) = x +listrIndex (SS i) (_ ::: xs) | Refl <- lemLeqSuccSucc (Proxy @k) (Proxy @n) = listrIndex i xs +listrIndex _ ZR = error "k + 1 <= 0" + +listrZip :: ListR n i -> ListR n j -> ListR n (i, j) +listrZip ZR ZR = ZR +listrZip (i ::: irest) (j ::: jrest) = (i, j) ::: listrZip irest jrest +listrZip _ _ = error "listrZip: impossible pattern needlessly required" + +listrZipWith :: (i -> j -> k) -> ListR n i -> ListR n j -> ListR n k +listrZipWith _ ZR ZR = ZR +listrZipWith f (i ::: irest) (j ::: jrest) = + f i j ::: listrZipWith f irest jrest +listrZipWith _ _ _ = + error "listrZipWith: impossible pattern needlessly required" + +listrPermutePrefix :: forall i n. [Int] -> ListR n i -> ListR n i +listrPermutePrefix = \perm sh -> + listrFromList perm $ \sperm -> + case (listrRank sperm, listrRank sh) of + (permlen@SNat, shlen@SNat) -> case cmpNat permlen shlen of + LTI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + EQI -> let (pre, post) = listrSplitAt permlen sh in listrAppend (applyPermRFull permlen sperm pre) post + GTI -> error $ "Length of permutation (" ++ show (fromSNat' permlen) ++ ")" + ++ " > length of shape (" ++ show (fromSNat' shlen) ++ ")" + where + listrSplitAt :: m <= n' => SNat m -> ListR n' i -> (ListR m i, ListR (n' - m) i) + listrSplitAt SZ sh = (ZR, sh) + listrSplitAt (SS m) (n ::: sh) = (\(pre, post) -> (n ::: pre, post)) (listrSplitAt m sh) + listrSplitAt SS{} ZR = error "m' + 1 <= 0" + + applyPermRFull :: SNat m -> ListR k Int -> ListR m i -> ListR k i + applyPermRFull _ ZR _ = ZR + applyPermRFull sm@SNat (i ::: perm) l = + TN.withSomeSNat (fromIntegral i) $ \si@(SNat :: SNat idx) -> + case cmpNat (SNat @(idx + 1)) sm of + LTI -> listrIndex si l ::: applyPermRFull sm perm l + EQI -> listrIndex si l ::: applyPermRFull sm perm l + GTI -> error "listrPermutePrefix: Index in permutation out of range" + + +-- * Ranked indices + +-- | An index into a rank-typed array. +type role IxR nominal representational +type IxR :: Nat -> Type -> Type +newtype IxR n i = IxR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZIR :: forall n i. () => n ~ 0 => IxR n i +pattern ZIR = IxR ZR + +pattern (:.:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> IxR n i -> IxR n1 i +pattern i :.: sh <- IxR (listrUncons -> Just (UnconsListRRes (IxR -> sh) i)) + where i :.: IxR sh = IxR (i ::: sh) +infixr 3 :.: + +{-# COMPLETE ZIR, (:.:) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxR n = IxR n Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxR n i) +#else +instance Show i => Show (IxR n i) where + showsPrec _ (IxR l) = listrShow shows l +#endif + +instance NFData i => NFData (IxR sh i) + +ixrLength :: IxR sh i -> Int +ixrLength (IxR l) = listrLength l + +ixrRank :: IxR n i -> SNat n +ixrRank (IxR sh) = listrRank sh + +ixrZero :: SNat n -> IIxR n +ixrZero SZ = ZIR +ixrZero (SS n) = 0 :.: ixrZero n + +ixrHead :: IxR (n + 1) i -> i +ixrHead (IxR list) = listrHead list + +ixrTail :: IxR (n + 1) i -> IxR n i +ixrTail (IxR list) = IxR (listrTail list) + +ixrInit :: IxR (n + 1) i -> IxR n i +ixrInit (IxR list) = IxR (listrInit list) + +ixrLast :: IxR (n + 1) i -> i +ixrLast (IxR list) = listrLast list + +-- | Performs a runtime check that the lengths are identical. +ixrCast :: SNat n' -> IxR n i -> IxR n' i +ixrCast n (IxR idx) = IxR (listrCastWithName "ixrCast" n idx) + +ixrAppend :: forall n m i. IxR n i -> IxR m i -> IxR (n + m) i +ixrAppend = coerce (listrAppend @_ @i) + +ixrZip :: IxR n i -> IxR n j -> IxR n (i, j) +ixrZip (IxR l1) (IxR l2) = IxR $ listrZip l1 l2 + +ixrZipWith :: (i -> j -> k) -> IxR n i -> IxR n j -> IxR n k +ixrZipWith f (IxR l1) (IxR l2) = IxR $ listrZipWith f l1 l2 + +ixrPermutePrefix :: forall n i. [Int] -> IxR n i -> IxR n i +ixrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- * Ranked shapes + +type role ShR nominal representational +type ShR :: Nat -> Type -> Type +newtype ShR n i = ShR (ListR n i) + deriving (Eq, Ord, Generic) + deriving newtype (Functor, Foldable) + +pattern ZSR :: forall n i. () => n ~ 0 => ShR n i +pattern ZSR = ShR ZR + +pattern (:$:) + :: forall {n1} {i}. + forall n. (n + 1 ~ n1) + => i -> ShR n i -> ShR n1 i +pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i)) + where i :$: ShR sh = ShR (i ::: sh) +infixr 3 :$: + +{-# COMPLETE ZSR, (:$:) #-} + +type IShR n = ShR n Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (ShR n i) +#else +instance Show i => Show (ShR n i) where + showsPrec _ (ShR l) = listrShow shows l +#endif + +instance NFData i => NFData (ShR sh i) + +-- | This checks only whether the ranks are equal, not whether the actual +-- values are. +shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh' + +-- | This compares the shapes for value equality. +shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n') +shrEqual (ShR sh) (ShR sh') = listrEqual sh sh' + +shrLength :: ShR sh i -> Int +shrLength (ShR l) = listrLength l + +-- | This function can also be used to conjure up a 'KnownNat' dictionary; +-- pattern matching on the returned 'SNat' with the 'pattern SNat' pattern +-- synonym yields 'KnownNat' evidence. +shrRank :: ShR n i -> SNat n +shrRank (ShR sh) = listrRank sh + +-- | The number of elements in an array described by this shape. +shrSize :: IShR n -> Int +shrSize ZSR = 1 +shrSize (n :$: sh) = n * shrSize sh + +shrHead :: ShR (n + 1) i -> i +shrHead (ShR list) = listrHead list + +shrTail :: ShR (n + 1) i -> ShR n i +shrTail (ShR list) = ShR (listrTail list) + +shrInit :: ShR (n + 1) i -> ShR n i +shrInit (ShR list) = ShR (listrInit list) + +shrLast :: ShR (n + 1) i -> i +shrLast (ShR list) = listrLast list + +-- | Performs a runtime check that the lengths are identical. +shrCast :: SNat n' -> ShR n i -> ShR n' i +shrCast n (ShR sh) = ShR (listrCastWithName "shrCast" n sh) + +shrAppend :: forall n m i. ShR n i -> ShR m i -> ShR (n + m) i +shrAppend = coerce (listrAppend @_ @i) + +shrZip :: ShR n i -> ShR n j -> ShR n (i, j) +shrZip (ShR l1) (ShR l2) = ShR $ listrZip l1 l2 + +shrZipWith :: (i -> j -> k) -> ShR n i -> ShR n j -> ShR n k +shrZipWith f (ShR l1) (ShR l2) = ShR $ listrZipWith f l1 l2 + +shrPermutePrefix :: forall n i. [Int] -> ShR n i -> ShR n i +shrPermutePrefix = coerce (listrPermutePrefix @i) + + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ListR n i) where + type Item (ListR n i) = i + fromList topl = go (SNat @n) topl + where + go :: SNat n' -> [i] -> ListR n' i + go SZ [] = ZR + go (SS n) (i : is) = i ::: go n is + go _ _ = error $ "IsList(ListR): Mismatched list length (type says " + ++ show (fromSNat (SNat @n)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (IxR n i) where + type Item (IxR n i) = i + fromList = IxR . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length is checked at runtime. +instance KnownNat n => IsList (ShR n i) where + type Item (ShR n i) = i + fromList = ShR . IsList.fromList + toList = Foldable.toList + + +-- * Internal helper functions + +listrCastWithName :: String -> SNat n' -> ListR n i -> ListR n' i +listrCastWithName _ SZ ZR = ZR +listrCastWithName name (SS n) (i ::: idx) = i ::: listrCastWithName name n idx +listrCastWithName name _ _ = error $ name ++ ": ranks don't match" diff --git a/src/Data/Array/Nested/Shaped.hs b/src/Data/Array/Nested/Shaped.hs new file mode 100644 index 0000000..198a068 --- /dev/null +++ b/src/Data/Array/Nested/Shaped.hs @@ -0,0 +1,272 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +module Data.Array.Nested.Shaped ( + Shaped(Shaped), + squotArray, sremArray, satan2Array, + sshape, + module Data.Array.Nested.Shaped, + liftShaped1, liftShaped2, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS +import Data.Array.Internal.ShapedG qualified as SG +import Data.Array.Internal.ShapedS qualified as SS +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.TypeLits + +import Data.Array.Nested.Convert +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Shaped.Base +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) +import Data.Array.XArray qualified as X + + +semptyArray :: KnownElt a => ShS sh -> Shaped (0 : sh) a +semptyArray sh = Shaped (memptyArray (shxFromShS sh)) + +srank :: Elt a => Shaped sh a -> SNat (Rank sh) +srank = shsRank . sshape + +-- | The total number of elements in the array. +ssize :: Elt a => Shaped sh a -> Int +ssize = shsSize . sshape + +sindex :: Elt a => Shaped sh a -> IIxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixxFromIxS idx) + +shsTakeIx :: Proxy sh' -> ShS (sh ++ sh') -> IIxS sh -> ShS sh +shsTakeIx _ _ ZIS = ZSS +shsTakeIx p sh (_ :.$ idx) = case sh of n :$$ sh' -> n :$$ shsTakeIx p sh' idx + +sindexPartial :: forall sh1 sh2 a. Elt a => Shaped (sh1 ++ sh2) a -> IIxS sh1 -> Shaped sh2 a +sindexPartial sarr@(Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (castWith (subst2 (lemMapJustApp (shsTakeIx (Proxy @sh2) (sshape sarr) idx) (Proxy @sh2))) arr) + (ixxFromIxS idx)) + +-- | __WARNING__: All values returned from the function must have equal shape. +-- See the documentation of 'mgenerate' for more details. +sgenerate :: forall sh a. KnownElt a => ShS sh -> (IIxS sh -> a) -> Shaped sh a +sgenerate sh f = Shaped (mgenerate (shxFromShS sh) (f . ixsFromIxX sh)) + +-- | See the documentation of 'mlift'. +slift :: forall sh1 sh2 a. Elt a + => ShS sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift sh2 f (Shaped arr) = Shaped (mlift (ssxFromShX (shxFromShS sh2)) f arr) + +-- | See the documentation of 'mlift'. +slift2 :: forall sh1 sh2 sh3 a. Elt a + => ShS sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b -> XArray (MapJust sh3 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a -> Shaped sh3 a +slift2 sh3 f (Shaped arr1) (Shaped arr2) = Shaped (mlift2 (ssxFromShX (shxFromShS sh3)) f arr1 arr2) + +ssumOuter1P :: forall sh n a. (Storable a, NumElt a) + => Shaped (n : sh) (Primitive a) -> Shaped sh (Primitive a) +ssumOuter1P (Shaped arr) = Shaped (msumOuter1P arr) + +ssumOuter1 :: forall sh n a. (NumElt a, PrimElt a) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 = sfromPrimitive . ssumOuter1P . stoPrimitive + +ssumAllPrim :: (PrimElt a, NumElt a) => Shaped n a -> a +ssumAllPrim (Shaped arr) = msumAllPrim arr + +stranspose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh, Elt a) + => Perm is -> Shaped sh a -> Shaped (PermutePrefix is sh) a +stranspose perm sarr@(Shaped arr) + | Refl <- lemRankMapJust (sshape sarr) + , Refl <- lemTakeLenMapJust perm (sshape sarr) + , Refl <- lemDropLenMapJust perm (sshape sarr) + , Refl <- lemPermuteMapJust perm (shsTakeLen perm (sshape sarr)) + , Refl <- lemMapJustApp (shsPermute perm (shsTakeLen perm (sshape sarr))) (Proxy @(DropLen is sh)) + = Shaped (mtranspose perm arr) + +sappend :: Elt a => Shaped (n : sh) a -> Shaped (m : sh) a -> Shaped (n + m : sh) a +sappend = coerce mappend + +sscalar :: Elt a => a -> Shaped '[] a +sscalar x = Shaped (mscalar x) + +sfromVectorP :: Storable a => ShS sh -> VS.Vector a -> Shaped sh (Primitive a) +sfromVectorP sh v = Shaped (mfromVectorP (shxFromShS sh) v) + +sfromVector :: PrimElt a => ShS sh -> VS.Vector a -> Shaped sh a +sfromVector sh v = sfromPrimitive (sfromVectorP sh v) + +stoVectorP :: Storable a => Shaped sh (Primitive a) -> VS.Vector a +stoVectorP = coerce mtoVectorP + +stoVector :: PrimElt a => Shaped sh a -> VS.Vector a +stoVector = coerce mtoVector + +sfromList1 :: Elt a => SNat n -> NonEmpty a -> Shaped '[n] a +sfromList1 sn = Shaped . mcast (SKnown sn :!% ZKX) . mfromList1 + +sfromListOuter :: Elt a => SNat n -> NonEmpty (Shaped sh a) -> Shaped (n : sh) a +sfromListOuter sn l = Shaped (mcastPartial (SUnknown () :!% ZKX) (SKnown sn :!% ZKX) Proxy $ mfromListOuter (coerce l)) + +sfromListLinear :: forall sh a. Elt a => ShS sh -> NonEmpty a -> Shaped sh a +sfromListLinear sh l = Shaped (mfromListLinear (shxFromShS sh) l) + +sfromListPrim :: forall n a. PrimElt a => SNat n -> [a] -> Shaped '[n] a +sfromListPrim sn l + | Refl <- lemAppNil @'[Just n] + = let ssh = SUnknown () :!% ZKX + xarr = X.cast ssh (SKnown sn :$% ZSX) ZKX (X.fromList1 ssh l) + in Shaped $ fromPrimitive $ M_Primitive (X.shape (SKnown sn :!% ZKX) xarr) xarr + +sfromListPrimLinear :: PrimElt a => ShS sh -> [a] -> Shaped sh a +sfromListPrimLinear sh l = + let M_Primitive _ xarr = toPrimitive (mfromListPrim l) + in Shaped $ fromPrimitive $ M_Primitive (shxFromShS sh) (X.reshape (SUnknown () :!% ZKX) (shxFromShS sh) xarr) + +stoList :: Elt a => Shaped '[n] a -> [a] +stoList = map sunScalar . stoListOuter + +stoListOuter :: Elt a => Shaped (n : sh) a -> [Shaped sh a] +stoListOuter (Shaped arr) = coerce (mtoListOuter arr) + +stoListLinear :: Elt a => Shaped sh a -> [a] +stoListLinear (Shaped arr) = mtoListLinear arr + +sfromOrthotope :: PrimElt a => ShS sh -> SS.Array sh a -> Shaped sh a +sfromOrthotope sh (SS.A (SG.A arr)) = + Shaped (fromPrimitive (M_Primitive (shxFromShS sh) (X.XArray (RS.A (RG.A (shsToList sh) arr))))) + +stoOrthotope :: PrimElt a => Shaped sh a -> SS.Array sh a +stoOrthotope (stoPrimitive -> Shaped (M_Primitive _ (X.XArray (RS.A (RG.A _ arr))))) = SS.A (SG.A arr) + +sunScalar :: Elt a => Shaped '[] a -> a +sunScalar arr = sindex arr ZIS + +snest :: forall sh sh' a. Elt a => ShS sh -> Shaped (sh ++ sh') a -> Shaped sh (Shaped sh' a) +snest sh arr + | Refl <- lemMapJustApp sh (Proxy @sh') + = coerce (mnest (ssxFromShX (shxFromShS sh)) (coerce arr)) + +sunNest :: forall sh sh' a. Elt a => Shaped sh (Shaped sh' a) -> Shaped (sh ++ sh') a +sunNest sarr@(Shaped (M_Shaped (M_Nest _ arr))) + | Refl <- lemMapJustApp (sshape sarr) (Proxy @sh') + = Shaped arr + +szip :: (Elt a, Elt b) => Shaped sh a -> Shaped sh b -> Shaped sh (a, b) +szip = coerce mzip + +sunzip :: Shaped sh (a, b) -> (Shaped sh a, Shaped sh b) +sunzip = coerce munzip + +srerankP :: forall sh1 sh2 sh a b. (Storable a, Storable b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 (Primitive a) -> Shaped sh2 (Primitive b)) + -> Shaped (sh ++ sh1) (Primitive a) -> Shaped (sh ++ sh2) (Primitive b) +srerankP sh sh2 f sarr@(Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh1) + , Refl <- lemMapJustApp sh (Proxy @sh2) + = Shaped (mrerankP (ssxFromShX (shxTakeSSX (Proxy @(MapJust sh1)) (ssxFromShX (shxFromShS sh)) (shxFromShS (sshape sarr)))) + (shxFromShS sh2) + (\a -> let Shaped r = f (Shaped a) in r) + arr) + +-- | See the caveats at 'Data.Array.XArray.rerank'. +srerank :: forall sh1 sh2 sh a b. (PrimElt a, PrimElt b) + => ShS sh -> ShS sh2 + -> (Shaped sh1 a -> Shaped sh2 b) + -> Shaped (sh ++ sh1) a -> Shaped (sh ++ sh2) b +srerank sh sh2 f (stoPrimitive -> arr) = + sfromPrimitive $ srerankP sh sh2 (stoPrimitive . f . sfromPrimitive) arr + +sreplicate :: forall sh sh' a. Elt a => ShS sh -> Shaped sh' a -> Shaped (sh ++ sh') a +sreplicate sh (Shaped arr) + | Refl <- lemMapJustApp sh (Proxy @sh') + = Shaped (mreplicate (shxFromShS sh) arr) + +sreplicateScalP :: forall sh a. Storable a => ShS sh -> a -> Shaped sh (Primitive a) +sreplicateScalP sh x = Shaped (mreplicateScalP (shxFromShS sh) x) + +sreplicateScal :: PrimElt a => ShS sh -> a -> Shaped sh a +sreplicateScal sh x = sfromPrimitive (sreplicateScalP sh x) + +sslice :: Elt a => SNat i -> SNat n -> Shaped (i + n + k : sh) a -> Shaped (n : sh) a +sslice i n@SNat arr = + let _ :$$ sh = sshape arr + in slift (n :$$ sh) (\_ -> X.slice i n) arr + +srev1 :: Elt a => Shaped (n : sh) a -> Shaped (n : sh) a +srev1 arr = slift (sshape arr) (\_ -> X.rev1) arr + +sreshape :: (Elt a, Product sh ~ Product sh') => ShS sh' -> Shaped sh a -> Shaped sh' a +sreshape sh' (Shaped arr) = Shaped (mreshape (shxFromShS sh') arr) + +sflatten :: Elt a => Shaped sh a -> Shaped '[Product sh] a +sflatten arr = + case shsProduct (sshape arr) of -- TODO: simplify when removing the KnownNat stuff + n@SNat -> sreshape (n :$$ ZSS) arr + +siota :: (Enum a, PrimElt a) => SNat n -> Shaped '[n] a +siota sn = Shaped (miota sn) + +-- | Throws if the array is empty. +sminIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +sminIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mminIndexPrim arr) + +-- | Throws if the array is empty. +smaxIndexPrim :: (PrimElt a, NumElt a) => Shaped sh a -> IIxS sh +smaxIndexPrim sarr@(Shaped arr) = ixsFromIxX (sshape (stoPrimitive sarr)) (mmaxIndexPrim arr) + +sdot1Inner :: forall sh n a. (PrimElt a, NumElt a) + => Proxy n -> Shaped (sh ++ '[n]) a -> Shaped (sh ++ '[n]) a -> Shaped sh a +sdot1Inner Proxy sarr1@(Shaped arr1) (Shaped arr2) + | Refl <- lemInitApp (Proxy @sh) (Proxy @n) + , Refl <- lemLastApp (Proxy @sh) (Proxy @n) + = case sshape sarr1 of + _ :$$ _ + | Refl <- lemMapJustApp (shsInit (sshape sarr1)) (Proxy @'[n]) + -> Shaped (mdot1Inner (Proxy @(Just n)) arr1 arr2) + _ -> error "unreachable" + +-- | This has a temporary, suboptimal implementation in terms of 'mflatten'. +-- Prefer 'sdot1Inner' if applicable. +sdot :: (PrimElt a, NumElt a) => Shaped sh a -> Shaped sh a -> a +sdot = coerce mdot + +stoXArrayPrimP :: Shaped sh (Primitive a) -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrimP (Shaped arr) = first shsFromShX (mtoXArrayPrimP arr) + +stoXArrayPrim :: PrimElt a => Shaped sh a -> (ShS sh, XArray (MapJust sh) a) +stoXArrayPrim (Shaped arr) = first shsFromShX (mtoXArrayPrim arr) + +sfromXArrayPrimP :: ShS sh -> XArray (MapJust sh) a -> Shaped sh (Primitive a) +sfromXArrayPrimP sh arr = Shaped (mfromXArrayPrimP (ssxFromShX (shxFromShS sh)) arr) + +sfromXArrayPrim :: PrimElt a => ShS sh -> XArray (MapJust sh) a -> Shaped sh a +sfromXArrayPrim sh arr = Shaped (mfromXArrayPrim (ssxFromShX (shxFromShS sh)) arr) + +sfromPrimitive :: PrimElt a => Shaped sh (Primitive a) -> Shaped sh a +sfromPrimitive (Shaped arr) = Shaped (fromPrimitive arr) + +stoPrimitive :: PrimElt a => Shaped sh a -> Shaped sh (Primitive a) +stoPrimitive (Shaped arr) = Shaped (toPrimitive arr) diff --git a/src/Data/Array/Nested/Shaped/Base.hs b/src/Data/Array/Nested/Shaped/Base.hs new file mode 100644 index 0000000..879e6b5 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Base.hs @@ -0,0 +1,255 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_HADDOCK not-home #-} +module Data.Array.Nested.Shaped.Base where + +import Prelude hiding (mappend, mconcat) + +import Control.DeepSeq (NFData(..)) +import Control.Monad.ST +import Data.Bifunctor (first) +import Data.Coerce (coerce) +import Data.Kind (Type) +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Foreign.Storable (Storable) +import GHC.Float qualified (expm1, log1mexp, log1p, log1pexp) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Shaped.Shape +import Data.Array.Nested.Types +import Data.Array.Strided.Arith +import Data.Array.XArray (XArray) + + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. Note that +-- these are "GHC.TypeLits" naturals, because we do not need induction over +-- them and we want very large arrays to be possible. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed (MapJust sh) a) => Show (Shaped sh a) +#endif +deriving instance Eq (Mixed (MapJust sh) a) => Eq (Shaped sh a) +deriving instance Ord (Mixed (MapJust sh) a) => Ord (Shaped sh a) + +#ifndef OXAR_DEFAULT_SHOW_INSTANCES +instance (Show a, Elt a) => Show (Shaped n a) where + showsPrec d arr@(Shaped marr) = + let sh = show (shsToList (sshape arr)) + in showsMixedArray ("sfromListLinear " ++ sh) ("sreplicate " ++ sh) d marr +#endif + +instance Elt a => NFData (Shaped sh a) where + rnf (Shaped arr) = rnf arr + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh') a)) + deriving (Generic) +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (Mixed sh (Mixed (MapJust sh') a)) => Show (Mixed sh (Shaped sh' a)) +#endif + +deriving instance Eq (Mixed sh (Mixed (MapJust sh') a)) => Eq (Mixed sh (Shaped sh' a)) + +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh') a)) + +instance Elt a => Elt (Shaped sh a) where + mshape (M_Shaped arr) = mshape arr + mindex (M_Shaped arr) i = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mscalar (Shaped x) = M_Shaped (M_Nest ZSX x) + + mfromListOuter :: forall sh'. NonEmpty (Mixed sh' (Shaped sh a)) -> Mixed (Nothing : sh') (Shaped sh a) + mfromListOuter l = M_Shaped (mfromListOuter (coerce l)) + + mtoListOuter :: forall n sh'. Mixed (n : sh') (Shaped sh a) -> [Mixed sh' (Shaped sh a)] + mtoListOuter (M_Shaped arr) + = coerce @[Mixed sh' (Mixed (MapJust sh) a)] @[Mixed sh' (Shaped sh a)] (mtoListOuter arr) + + mlift :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift ssh2 f (M_Shaped arr) = + coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift ssh2 f arr + + mlift2 :: forall sh1 sh2 sh3. + StaticShX sh3 + -> (forall sh' b. Storable b => StaticShX sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b -> XArray (sh3 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) -> Mixed sh3 (Shaped sh a) + mlift2 ssh3 f (M_Shaped arr1) (M_Shaped arr2) = + coerce @(Mixed sh3 (Mixed (MapJust sh) a)) @(Mixed sh3 (Shaped sh a)) $ + mlift2 ssh3 f arr1 arr2 + + mliftL :: forall sh1 sh2. + StaticShX sh2 + -> (forall sh' b. Storable b => StaticShX sh' -> NonEmpty (XArray (sh1 ++ sh') b) -> NonEmpty (XArray (sh2 ++ sh') b)) + -> NonEmpty (Mixed sh1 (Shaped sh a)) -> NonEmpty (Mixed sh2 (Shaped sh a)) + mliftL ssh2 f l = + coerce @(NonEmpty (Mixed sh2 (Mixed (MapJust sh) a))) + @(NonEmpty (Mixed sh2 (Shaped sh a))) $ + mliftL ssh2 f (coerce l) + + mcastPartial ssh1 ssh2 psh' (M_Shaped arr) = M_Shaped (mcastPartial ssh1 ssh2 psh' arr) + + mtranspose perm (M_Shaped arr) = M_Shaped (mtranspose perm arr) + + mconcat l = M_Shaped (mconcat (coerce l)) + + mrnf (M_Shaped arr) = mrnf arr + + type ShapeTree (Shaped sh a) = (ShS sh, ShapeTree a) + + mshapeTree (Shaped arr) = first shsFromShX (mshapeTree arr) + + mshapeTreeEq _ (sh1, t1) (sh2, t2) = sh1 == sh2 && mshapeTreeEq (Proxy @a) t1 t2 + + mshapeTreeEmpty _ (sh, t) = shsSize sh == 0 && mshapeTreeEmpty (Proxy @a) t + + mshowShapeTree _ (sh, t) = "(" ++ show sh ++ ", " ++ mshowShapeTree (Proxy @a) t ++ ")" + + marrayStrides (M_Shaped arr) = marrayStrides arr + + mvecsWrite :: forall sh' s. IShX sh' -> IIxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs = + mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. + IShX (sh1 ++ sh2) -> IIxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs = + mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IShX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs = + coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + +instance (KnownShS sh, KnownElt a) => KnownElt (Shaped sh a) where + memptyArrayUnsafe :: forall sh'. IShX sh' -> Mixed sh' (Shaped sh a) + memptyArrayUnsafe i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArrayUnsafe i + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsNewEmpty _ + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsNewEmpty (Proxy @(Mixed (MapJust sh) a)) + + +liftShaped1 :: forall sh a b. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b) + -> Shaped sh a -> Shaped sh b +liftShaped1 = coerce + +liftShaped2 :: forall sh a b c. + (Mixed (MapJust sh) a -> Mixed (MapJust sh) b -> Mixed (MapJust sh) c) + -> Shaped sh a -> Shaped sh b -> Shaped sh c +liftShaped2 = coerce + +instance (NumElt a, PrimElt a) => Num (Shaped sh a) where + (+) = liftShaped2 (+) + (-) = liftShaped2 (-) + (*) = liftShaped2 (*) + negate = liftShaped1 negate + abs = liftShaped1 abs + signum = liftShaped1 signum + fromInteger = error "Data.Array.Nested.fromInteger: No singletons available, use explicit sreplicateScal" + +instance (FloatElt a, PrimElt a) => Fractional (Shaped sh a) where + fromRational = error "Data.Array.Nested.fromRational: No singletons available, use explicit sreplicateScal" + recip = liftShaped1 recip + (/) = liftShaped2 (/) + +instance (FloatElt a, PrimElt a) => Floating (Shaped sh a) where + pi = error "Data.Array.Nested.pi: No singletons available, use explicit sreplicateScal" + exp = liftShaped1 exp + log = liftShaped1 log + sqrt = liftShaped1 sqrt + (**) = liftShaped2 (**) + logBase = liftShaped2 logBase + sin = liftShaped1 sin + cos = liftShaped1 cos + tan = liftShaped1 tan + asin = liftShaped1 asin + acos = liftShaped1 acos + atan = liftShaped1 atan + sinh = liftShaped1 sinh + cosh = liftShaped1 cosh + tanh = liftShaped1 tanh + asinh = liftShaped1 asinh + acosh = liftShaped1 acosh + atanh = liftShaped1 atanh + log1p = liftShaped1 GHC.Float.log1p + expm1 = liftShaped1 GHC.Float.expm1 + log1pexp = liftShaped1 GHC.Float.log1pexp + log1mexp = liftShaped1 GHC.Float.log1mexp + +squotArray, sremArray :: (IntElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +squotArray = liftShaped2 mquotArray +sremArray = liftShaped2 mremArray + +satan2Array :: (FloatElt a, PrimElt a) => Shaped sh a -> Shaped sh a -> Shaped sh a +satan2Array = liftShaped2 matan2Array + + +sshape :: forall sh a. Elt a => Shaped sh a -> ShS sh +sshape (Shaped arr) = shsFromShX (mshape arr) + +-- Needed already here, but re-exported in Data.Array.Nested.Convert. +shsFromShX :: forall sh i. ShX (MapJust sh) i -> ShS sh +shsFromShX ZSX = castWith (subst1 (unsafeCoerceRefl :: '[] :~: sh)) ZSS +shsFromShX (SKnown n@SNat :$% (idx :: ShX mjshT i)) = + castWith (subst1 (sym (lemMapJustCons Refl))) $ + n :$$ shsFromShX @(Tail sh) (castWith (subst2 (unsafeCoerceRefl :: mjshT :~: MapJust (Tail sh))) + idx) +shsFromShX (SUnknown _ :$% _) = error "impossible" diff --git a/src/Data/Array/Nested/Shaped/Shape.hs b/src/Data/Array/Nested/Shaped/Shape.hs new file mode 100644 index 0000000..5f9ba79 --- /dev/null +++ b/src/Data/Array/Nested/Shaped/Shape.hs @@ -0,0 +1,425 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Shaped.Shape where + +import Control.DeepSeq (NFData(..)) +import Data.Array.Shape qualified as O +import Data.Coerce (coerce) +import Data.Foldable qualified as Foldable +import Data.Functor.Const +import Data.Functor.Product qualified as Fun +import Data.Kind (Constraint, Type) +import Data.Monoid (Sum(..)) +import Data.Proxy +import Data.Type.Equality +import GHC.Exts (withDict) +import GHC.Generics (Generic) +import GHC.IsList (IsList) +import GHC.IsList qualified as IsList +import GHC.TypeLits + +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types + + +-- * Shaped lists + +-- | Note: The 'KnownNat' constraint on '(::$)' is deprecated and should be +-- removed in a future release. +type role ListS nominal representational +type ListS :: [Nat] -> (Nat -> Type) -> Type +data ListS sh f where + ZS :: ListS '[] f + -- TODO: when the KnownNat constraint is removed, restore listsIndex to sanity + (::$) :: forall n sh {f}. KnownNat n => f n -> ListS sh f -> ListS (n : sh) f +deriving instance (forall n. Eq (f n)) => Eq (ListS sh f) +deriving instance (forall n. Ord (f n)) => Ord (ListS sh f) +infixr 3 ::$ + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance (forall n. Show (f n)) => Show (ListS sh f) +#else +instance (forall n. Show (f n)) => Show (ListS sh f) where + showsPrec _ = listsShow shows +#endif + +instance (forall m. NFData (f m)) => NFData (ListS n f) where + rnf ZS = () + rnf (x ::$ l) = rnf x `seq` rnf l + +data UnconsListSRes f sh1 = + forall n sh. (KnownNat n, n : sh ~ sh1) => UnconsListSRes (ListS sh f) (f n) +listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1) +listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x) +listsUncons ZS = Nothing + +-- | This checks only whether the types are equal; if the elements of the list +-- are not singletons, their values may still differ. This corresponds to +-- 'testEquality', except on the penultimate type parameter. +listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqType ZS ZS = Just Refl +listsEqType (n ::$ sh) (m ::$ sh') + | Just Refl <- testEquality n m + , Just Refl <- listsEqType sh sh' + = Just Refl +listsEqType _ _ = Nothing + +-- | This checks whether the two lists actually contain equal values. This is +-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ +-- in the @some@ package (except on the penultimate type parameter). +listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh') +listsEqual ZS ZS = Just Refl +listsEqual (n ::$ sh) (m ::$ sh') + | Just Refl <- testEquality n m + , n == m + , Just Refl <- listsEqual sh sh' + = Just Refl +listsEqual _ _ = Nothing + +listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g +listsFmap _ ZS = ZS +listsFmap f (x ::$ xs) = f x ::$ listsFmap f xs + +listsFold :: Monoid m => (forall n. f n -> m) -> ListS sh f -> m +listsFold _ ZS = mempty +listsFold f (x ::$ xs) = f x <> listsFold f xs + +listsShow :: forall sh f. (forall n. f n -> ShowS) -> ListS sh f -> ShowS +listsShow f l = showString "[" . go "" l . showString "]" + where + go :: String -> ListS sh' f -> ShowS + go _ ZS = id + go prefix (x ::$ xs) = showString prefix . f x . go "," xs + +listsLength :: ListS sh f -> Int +listsLength = getSum . listsFold (\_ -> Sum 1) + +listsRank :: ListS sh f -> SNat (Rank sh) +listsRank ZS = SNat +listsRank (_ ::$ sh) = snatSucc (listsRank sh) + +listsToList :: ListS sh (Const i) -> [i] +listsToList ZS = [] +listsToList (Const i ::$ is) = i : listsToList is + +listsHead :: ListS (n : sh) f -> f n +listsHead (i ::$ _) = i + +listsTail :: ListS (n : sh) f -> ListS sh f +listsTail (_ ::$ sh) = sh + +listsInit :: ListS (n : sh) f -> ListS (Init (n : sh)) f +listsInit (n ::$ sh@(_ ::$ _)) = n ::$ listsInit sh +listsInit (_ ::$ ZS) = ZS + +listsLast :: ListS (n : sh) f -> f (Last (n : sh)) +listsLast (_ ::$ sh@(_ ::$ _)) = listsLast sh +listsLast (n ::$ ZS) = n + +listsAppend :: ListS sh f -> ListS sh' f -> ListS (sh ++ sh') f +listsAppend ZS idx' = idx' +listsAppend (i ::$ idx) idx' = i ::$ listsAppend idx idx' + +listsZip :: ListS sh f -> ListS sh g -> ListS sh (Fun.Product f g) +listsZip ZS ZS = ZS +listsZip (i ::$ is) (j ::$ js) = + Fun.Pair i j ::$ listsZip is js + +listsZipWith :: (forall a. f a -> g a -> h a) -> ListS sh f -> ListS sh g + -> ListS sh h +listsZipWith _ ZS ZS = ZS +listsZipWith f (i ::$ is) (j ::$ js) = + f i j ::$ listsZipWith f is js + +listsTakeLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (TakeLen is sh) f +listsTakeLenPerm PNil _ = ZS +listsTakeLenPerm (_ `PCons` is) (n ::$ sh) = n ::$ listsTakeLenPerm is sh +listsTakeLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsDropLenPerm :: forall f is sh. Perm is -> ListS sh f -> ListS (DropLen is sh) f +listsDropLenPerm PNil sh = sh +listsDropLenPerm (_ `PCons` is) (_ ::$ sh) = listsDropLenPerm is sh +listsDropLenPerm (_ `PCons` _) ZS = error "Permutation longer than shape" + +listsPermute :: forall f is sh. Perm is -> ListS sh f -> ListS (Permute is sh) f +listsPermute PNil _ = ZS +listsPermute (i `PCons` (is :: Perm is')) (sh :: ListS sh f) = + case listsIndex (Proxy @is') (Proxy @sh) i sh of + (item, SNat) -> item ::$ listsPermute is sh + +-- TODO: remove this SNat when the KnownNat constaint in ListS is removed +listsIndex :: forall f i is sh shT. Proxy is -> Proxy shT -> SNat i -> ListS sh f -> (f (Index i sh), SNat (Index i sh)) +listsIndex _ _ SZ (n ::$ _) = (n, SNat) +listsIndex p pT (SS (i :: SNat i')) ((_ :: f n) ::$ (sh :: ListS sh' f)) + | Refl <- lemIndexSucc (Proxy @i') (Proxy @n) (Proxy @sh') + = listsIndex p pT i sh +listsIndex _ _ _ ZS = error "Index into empty shape" + +listsPermutePrefix :: forall f is sh. Perm is -> ListS sh f -> ListS (PermutePrefix is sh) f +listsPermutePrefix perm sh = listsAppend (listsPermute perm (listsTakeLenPerm perm sh)) (listsDropLenPerm perm sh) + +-- * Shaped indices + +-- | An index into a shape-typed array. +type role IxS nominal representational +type IxS :: [Nat] -> Type -> Type +newtype IxS sh i = IxS (ListS sh (Const i)) + deriving (Eq, Ord, Generic) + +pattern ZIS :: forall sh i. () => sh ~ '[] => IxS sh i +pattern ZIS = IxS ZS + +-- | Note: The 'KnownNat' constraint on '(:.$)' is deprecated and should be +-- removed in a future release. +pattern (:.$) + :: forall {sh1} {i}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => i -> IxS sh i -> IxS sh1 i +pattern i :.$ shl <- IxS (listsUncons -> Just (UnconsListSRes (IxS -> shl) (getConst -> i))) + where i :.$ IxS shl = IxS (Const i ::$ shl) +infixr 3 :.$ + +{-# COMPLETE ZIS, (:.$) #-} + +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). +type IIxS sh = IxS sh Int + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show i => Show (IxS sh i) +#else +instance Show i => Show (IxS sh i) where + showsPrec _ (IxS l) = listsShow (\(Const i) -> shows i) l +#endif + +instance Functor (IxS sh) where + fmap f (IxS l) = IxS (listsFmap (Const . f . getConst) l) + +instance Foldable (IxS sh) where + foldMap f (IxS l) = listsFold (f . getConst) l + +instance NFData i => NFData (IxS sh i) + +ixsLength :: IxS sh i -> Int +ixsLength (IxS l) = listsLength l + +ixsRank :: IxS sh i -> SNat (Rank sh) +ixsRank (IxS l) = listsRank l + +ixsZero :: ShS sh -> IIxS sh +ixsZero ZSS = ZIS +ixsZero (_ :$$ sh) = 0 :.$ ixsZero sh + +ixsHead :: IxS (n : sh) i -> i +ixsHead (IxS list) = getConst (listsHead list) + +ixsTail :: IxS (n : sh) i -> IxS sh i +ixsTail (IxS list) = IxS (listsTail list) + +ixsInit :: IxS (n : sh) i -> IxS (Init (n : sh)) i +ixsInit (IxS list) = IxS (listsInit list) + +ixsLast :: IxS (n : sh) i -> i +ixsLast (IxS list) = getConst (listsLast list) + +-- TODO: this takes a ShS because there are KnownNats inside IxS. +ixsCast :: ShS sh' -> IxS sh i -> IxS sh' i +ixsCast ZSS ZIS = ZIS +ixsCast (_ :$$ sh) (i :.$ idx) = i :.$ ixsCast sh idx +ixsCast _ _ = error "ixsCast: ranks don't match" + +ixsAppend :: forall sh sh' i. IxS sh i -> IxS sh' i -> IxS (sh ++ sh') i +ixsAppend = coerce (listsAppend @_ @(Const i)) + +ixsZip :: IxS n i -> IxS n j -> IxS n (i, j) +ixsZip ZIS ZIS = ZIS +ixsZip (i :.$ is) (j :.$ js) = (i, j) :.$ ixsZip is js + +ixsZipWith :: (i -> j -> k) -> IxS n i -> IxS n j -> IxS n k +ixsZipWith _ ZIS ZIS = ZIS +ixsZipWith f (i :.$ is) (j :.$ js) = f i j :.$ ixsZipWith f is js + +ixsPermutePrefix :: forall i is sh. Perm is -> IxS sh i -> IxS (PermutePrefix is sh) i +ixsPermutePrefix = coerce (listsPermutePrefix @(Const i)) + + +-- * Shaped shapes + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +-- +-- Note that because the shape of a shape-typed array is known statically, you +-- can also retrieve the array shape from a 'KnownShS' dictionary. +type role ShS nominal +type ShS :: [Nat] -> Type +newtype ShS sh = ShS (ListS sh SNat) + deriving (Eq, Ord, Generic) + +pattern ZSS :: forall sh. () => sh ~ '[] => ShS sh +pattern ZSS = ShS ZS + +pattern (:$$) + :: forall {sh1}. + forall n sh. (KnownNat n, n : sh ~ sh1) + => SNat n -> ShS sh -> ShS sh1 +pattern i :$$ shl <- ShS (listsUncons -> Just (UnconsListSRes (ShS -> shl) i)) + where i :$$ ShS shl = ShS (i ::$ shl) + +infixr 3 :$$ + +{-# COMPLETE ZSS, (:$$) #-} + +#ifdef OXAR_DEFAULT_SHOW_INSTANCES +deriving instance Show (ShS sh) +#else +instance Show (ShS sh) where + showsPrec _ (ShS l) = listsShow (shows . fromSNat) l +#endif + +instance NFData (ShS sh) where + rnf (ShS ZS) = () + rnf (ShS (SNat ::$ l)) = rnf (ShS l) + +instance TestEquality ShS where + testEquality (ShS l1) (ShS l2) = listsEqType l1 l2 + +-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are +-- equal if and only if values are equal.) +shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh') +shsEqual = testEquality + +shsLength :: ShS sh -> Int +shsLength (ShS l) = listsLength l + +shsRank :: ShS sh -> SNat (Rank sh) +shsRank (ShS l) = listsRank l + +shsSize :: ShS sh -> Int +shsSize ZSS = 1 +shsSize (n :$$ sh) = fromSNat' n * shsSize sh + +shsToList :: ShS sh -> [Int] +shsToList ZSS = [] +shsToList (sn :$$ sh) = fromSNat' sn : shsToList sh + +shsHead :: ShS (n : sh) -> SNat n +shsHead (ShS list) = listsHead list + +shsTail :: ShS (n : sh) -> ShS sh +shsTail (ShS list) = ShS (listsTail list) + +shsInit :: ShS (n : sh) -> ShS (Init (n : sh)) +shsInit (ShS list) = ShS (listsInit list) + +shsLast :: ShS (n : sh) -> SNat (Last (n : sh)) +shsLast (ShS list) = listsLast list + +shsAppend :: forall sh sh'. ShS sh -> ShS sh' -> ShS (sh ++ sh') +shsAppend = coerce (listsAppend @_ @SNat) + +shsTakeLen :: Perm is -> ShS sh -> ShS (TakeLen is sh) +shsTakeLen = coerce (listsTakeLenPerm @SNat) + +shsPermute :: Perm is -> ShS sh -> ShS (Permute is sh) +shsPermute = coerce (listsPermute @SNat) + +shsIndex :: Proxy is -> Proxy shT -> SNat i -> ShS sh -> SNat (Index i sh) +shsIndex pis pshT i sh = coerce (fst (listsIndex @SNat pis pshT i (coerce sh))) + +shsPermutePrefix :: forall is sh. Perm is -> ShS sh -> ShS (PermutePrefix is sh) +shsPermutePrefix = coerce (listsPermutePrefix @SNat) + +type family Product sh where + Product '[] = 1 + Product (n : ns) = n * Product ns + +shsProduct :: ShS sh -> SNat (Product sh) +shsProduct ZSS = SNat +shsProduct (n :$$ sh) = n `snatMul` shsProduct sh + +-- | Evidence for the static part of a shape. This pops up only when you are +-- polymorphic in the element type of an array. +type KnownShS :: [Nat] -> Constraint +class KnownShS sh where knownShS :: ShS sh +instance KnownShS '[] where knownShS = ZSS +instance (KnownNat n, KnownShS sh) => KnownShS (n : sh) where knownShS = natSing :$$ knownShS + +withKnownShS :: forall sh r. ShS sh -> (KnownShS sh => r) -> r +withKnownShS = withDict @(KnownShS sh) + +shsKnownShS :: ShS sh -> Dict KnownShS sh +shsKnownShS ZSS = Dict +shsKnownShS (SNat :$$ sh) | Dict <- shsKnownShS sh = Dict + +shsOrthotopeShape :: ShS sh -> Dict O.Shape sh +shsOrthotopeShape ZSS = Dict +shsOrthotopeShape (SNat :$$ sh) | Dict <- shsOrthotopeShape sh = Dict + +-- | This function is a hack made possible by the 'KnownNat' inside 'ListS'. +-- This function may be removed in a future release. +shsFromListS :: ListS sh f -> ShS sh +shsFromListS ZS = ZSS +shsFromListS (_ ::$ l) = SNat :$$ shsFromListS l + +-- | This function is a hack made possible by the 'KnownNat' inside 'IxS'. This +-- function may be removed in a future release. +shsFromIxS :: IxS sh i -> ShS sh +shsFromIxS (IxS l) = shsFromListS l + + +-- | Untyped: length is checked at runtime. +instance KnownShS sh => IsList (ListS sh (Const i)) where + type Item (ListS sh (Const i)) = i + fromList topl = go (knownShS @sh) topl + where + go :: ShS sh' -> [i] -> ListS sh' (Const i) + go ZSS [] = ZS + go (_ :$$ sh) (i : is) = Const i ::$ go sh is + go _ _ = error $ "IsList(ListS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = listsToList + +-- | Very untyped: only length is checked (at runtime), index bounds are __not checked__. +instance KnownShS sh => IsList (IxS sh i) where + type Item (IxS sh i) = i + fromList = IxS . IsList.fromList + toList = Foldable.toList + +-- | Untyped: length and values are checked at runtime. +instance KnownShS sh => IsList (ShS sh) where + type Item (ShS sh) = Int + fromList topl = ShS (go (knownShS @sh) topl) + where + go :: ShS sh' -> [Int] -> ListS sh' SNat + go ZSS [] = ZS + go (sn :$$ sh) (i : is) + | i == fromSNat' sn = sn ::$ go sh is + | otherwise = error $ "IsList(ShS): Value does not match typing (type says " + ++ show (fromSNat' sn) ++ ", list contains " ++ show i ++ ")" + go _ _ = error $ "IsList(ShS): Mismatched list length (type says " + ++ show (shsLength (knownShS @sh)) ++ ", list has length " + ++ show (length topl) ++ ")" + toList = shsToList diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs new file mode 100644 index 0000000..8a29aa5 --- /dev/null +++ b/src/Data/Array/Nested/Trace.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TemplateHaskell #-} +{-| +This module is API-compatible with "Data.Array.Nested", except that inputs and +outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods +also have additional 'Show' constraints. + +>>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7)) +>>> length (show res) `seq` () +oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))] +oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))] +oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))] +oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))] +oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))] +oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))] +>>> res +Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35])))) +-} +module Data.Array.Nested.Trace ( + -- * Traced variants + module Data.Array.Nested.Trace, + + -- * Re-exports from the plain "Data.Array.Nested" module + Ranked(Ranked), + ListR(ZR, (:::)), + IxR(..), IIxR, + ShR(..), IShR, + + Shaped(Shaped), + ListS(ZS, (::$)), + IxS(..), IIxS, + ShS(..), KnownShS(..), + + Mixed, + ListX(ZX, (::%)), + IxX(..), IIxX, + ShX(..), KnownShX(..), IShX, + StaticShX(..), + SMayNat(..), + Conversion(..), + + Elt, + PrimElt, + Primitive(..), + KnownElt, + + type (++), + Storable, + SNat, pattern SNat, + pattern SZ, pattern SS, + Perm(..), + IsPermutation, + KnownPerm(..), + NumElt, IntElt, FloatElt, + Rank, Product, + Replicate, + MapJust, +) where + +import Prelude hiding (mappend, mconcat) + +import Data.Array.Nested +import Data.Array.Nested.Trace.TH + + +$(concat <$> mapM convertFun + ['rshape, 'rrank, 'rsize, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rsumAllPrim, 'rtranspose, 'rappend, 'rconcat, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'remptyArray, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromList1, 'rfromListOuter, 'rfromListLinear, 'rfromListPrim, 'rfromListPrimLinear, 'rtoList, 'rtoListOuter, 'rtoListLinear, 'rslice, 'rrev1, 'rreshape, 'rflatten, 'riota, 'rminIndexPrim, 'rmaxIndexPrim, 'rdot1Inner, 'rdot, 'rnest, 'runNest, 'rzip, 'runzip, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rtoMixed, 'rcastToMixed, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'rquotArray, 'rremArray, 'ratan2Array, 'sshape, 'srank, 'ssize, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'ssumAllPrim, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'semptyArray, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromList1, 'sfromListOuter, 'sfromListLinear, 'sfromListPrim, 'sfromListPrimLinear, 'stoList, 'stoListOuter, 'stoListLinear, 'sslice, 'srev1, 'sreshape, 'sflatten, 'siota, 'sminIndexPrim, 'smaxIndexPrim, 'sdot1Inner, 'sdot, 'snest, 'sunNest, 'szip, 'sunzip, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoMixed, 'scastToMixed, 'stoRanked, 'sfromOrthotope, 'stoOrthotope, 'squotArray, 'sremArray, 'satan2Array, 'mshape, 'mrank, 'msize, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'msumAllPrim, 'mtranspose, 'mappend, 'mconcat, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'memptyArray, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromList1, 'mfromListOuter, 'mfromListLinear, 'mfromListPrim, 'mfromListPrimLinear, 'mtoList, 'mtoListOuter, 'mtoListLinear, 'mslice, 'mrev1, 'mreshape, 'mflatten, 'miota, 'mminIndexPrim, 'mmaxIndexPrim, 'mdot1Inner, 'mdot, 'mnest, 'munNest, 'mzip, 'munzip, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mcast, 'mcastToShaped, 'mtoRanked, 'convert, 'mquotArray, 'mremArray, 'matan2Array]) diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs new file mode 100644 index 0000000..4b388e3 --- /dev/null +++ b/src/Data/Array/Nested/Trace/TH.hs @@ -0,0 +1,98 @@ +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Trace.TH where + +import Control.Monad (zipWithM) +import Data.List (foldl', intersperse) +import Data.Maybe (isJust) +import Language.Haskell.TH hiding (cxt) + +import Debug.Trace qualified as Debug + +import Data.Array.Nested + + +splitFunTy :: Type -> ([TyVarBndr Specificity], Cxt, [Type], Type) +splitFunTy = \case + ArrowT `AppT` t1 `AppT` t2 -> + let (vars, cx, args, ret) = splitFunTy t2 + in (vars, cx, t1 : args, ret) + ForallT vs cx' t -> + let (vars, cx, args, ret) = splitFunTy t + in (vars ++ vs, cx ++ cx', args, ret) + t -> ([], [], [], t) + +data Arg = RRanked Type Arg + | RShaped Type Arg + | RMixed Type Arg + | RShowable Type + | ROther Type + deriving (Show) + +-- TODO: always returns Just +recognise :: Type -> Maybe Arg +recognise (ConT name `AppT` sht `AppT` ty) + | name == ''Ranked = RRanked sht <$> recognise ty + | name == ''Shaped = RShaped sht <$> recognise ty + | name == ''Mixed = RMixed sht <$> recognise ty +recognise ty@(ConT name `AppT` _) + | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] = + Just (RShowable ty) +recognise _ = Nothing + +realise :: Arg -> Type +realise (RRanked sht ty) = ConT ''Ranked `AppT` sht `AppT` realise ty +realise (RShaped sht ty) = ConT ''Shaped `AppT` sht `AppT` realise ty +realise (RMixed sht ty) = ConT ''Mixed `AppT` sht `AppT` realise ty +realise (RShowable ty) = ty +realise (ROther ty) = ty + +mkShow :: Arg -> Cxt +mkShow (RRanked _ ty) = mkShowElt ty +mkShow (RShaped _ ty) = mkShowElt ty +mkShow (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty)] +mkShow (RShowable _) = [] +mkShow (ROther ty) = [ConT ''Show `AppT` ty] + +mkShowElt :: Arg -> Cxt +mkShowElt (RRanked _ ty) = mkShowElt ty +mkShowElt (RShaped _ ty) = mkShowElt ty +mkShowElt (RMixed sht ty) = [ConT ''Show `AppT` realise (RMixed sht ty), ConT ''Elt `AppT` realise (RMixed sht ty)] +mkShowElt (RShowable _ty) = [] -- [ConT ''Elt `AppT` ty] +mkShowElt (ROther ty) = [ConT ''Show `AppT` ty, ConT ''Elt `AppT` ty] + +convertType :: Type -> Q (Type, [Bool], Bool) +convertType typ = + let (tybndrs, cxt, args, ret) = splitFunTy typ + argrels = map recognise args + retrel = recognise ret + in return + (ForallT tybndrs + (cxt ++ [constr + | Just rel <- retrel : argrels + , constr <- mkShow rel]) + (foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args) + ,map isJust argrels + ,isJust retrel) + +convertFun :: Name -> Q [Dec] +convertFun funname = do + defname <- newName (nameBase funname) + (convty, argarrs, retarr) <- reifyType funname >>= convertType + names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..] + resname <- newName "res" + let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) + let ex = LetE [ValD (VarP resname) + (NormalB (foldl' AppE (VarE funname) (map VarE names))) + []] + (VarE 'Debug.trace + `AppE` (VarE 'concat `AppE` ListE + ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++ + intersperse (LitE (StringL ", ")) + (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++ + [LitE (StringL "]")])) + `AppE` VarE resname) + return + [SigD defname convty + ,FunD defname [Clause (map VarP names) (NormalB ex) []]] diff --git a/src/Data/Array/Nested/Types.hs b/src/Data/Array/Nested/Types.hs new file mode 100644 index 0000000..4444acd --- /dev/null +++ b/src/Data/Array/Nested/Types.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Nested.Types ( + -- * Reasoning helpers + subst1, subst2, + + -- * Reified evidence of a type class + Dict(..), + + -- * Type-level naturals + pattern SZ, pattern SS, + fromSNat', sameNat', + snatPlus, snatMinus, snatMul, + snatSucc, + + -- * Type-level lists + type (++), + Replicate, + lemReplicateSucc, + MapJust, + lemMapJustEmpty, lemMapJustCons, + Head, + Tail, + Init, + Last, + + -- * Unsafe + unsafeCoerceRefl, +) where + +import Data.Proxy +import Data.Type.Equality +import GHC.TypeLits +import GHC.TypeNats qualified as TN +import Unsafe.Coerce qualified + + +-- Reasoning helpers + +subst1 :: forall f a b. a :~: b -> f a :~: f b +subst1 Refl = Refl + +subst2 :: forall f c a b. a :~: b -> f a c :~: f b c +subst2 Refl = Refl + +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a + +fromSNat' :: SNat n -> Int +fromSNat' = fromIntegral . fromSNat + +sameNat' :: SNat n -> SNat m -> Maybe (n :~: m) +sameNat' n@SNat m@SNat = sameNat n m + +pattern SZ :: () => (n ~ 0) => SNat n +pattern SZ <- ((\sn -> testEquality sn (SNat @0)) -> Just Refl) + where SZ = SNat + +pattern SS :: forall np1. () => forall n. (n + 1 ~ np1) => SNat n -> SNat np1 +pattern SS sn <- (snatPred -> Just (SNatPredResult sn Refl)) + where SS = snatSucc + +{-# COMPLETE SZ, SS #-} + +snatSucc :: SNat n -> SNat (n + 1) +snatSucc SNat = SNat + +data SNatPredResult np1 = forall n. SNatPredResult (SNat n) (n + 1 :~: np1) +snatPred :: forall np1. SNat np1 -> Maybe (SNatPredResult np1) +snatPred snp1 = + withKnownNat snp1 $ + case cmpNat (Proxy @1) (Proxy @np1) of + LTI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + EQI -> Just (SNatPredResult (SNat @(np1 - 1)) Refl) + GTI -> Nothing + +-- This should be a function in base +snatPlus :: SNat n -> SNat m -> SNat (n + m) +snatPlus n m = TN.withSomeSNat (TN.fromSNat n + TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMinus :: SNat n -> SNat m -> SNat (n - m) +snatMinus n m = let res = TN.fromSNat n - TN.fromSNat m in res `seq` TN.withSomeSNat res Unsafe.Coerce.unsafeCoerce + +-- This should be a function in base +snatMul :: SNat n -> SNat m -> SNat (n * m) +snatMul n m = TN.withSomeSNat (TN.fromSNat n * TN.fromSNat m) Unsafe.Coerce.unsafeCoerce + + +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +type family Replicate n a where + Replicate 0 a = '[] + Replicate n a = a : Replicate (n - 1) a + +lemReplicateSucc :: (a : Replicate n a) :~: Replicate (n + 1) a +lemReplicateSucc = unsafeCoerceRefl + +type family MapJust l = r | r -> l where + MapJust '[] = '[] + MapJust (x : xs) = Just x : MapJust xs + +lemMapJustEmpty :: MapJust sh :~: '[] -> sh :~: '[] +lemMapJustEmpty Refl = unsafeCoerceRefl + +lemMapJustCons :: MapJust sh :~: Just n : sh' -> sh :~: n : Tail sh +lemMapJustCons Refl = unsafeCoerceRefl + +type family Head l where + Head (x : _) = x + +type family Tail l where + Tail (_ : xs) = xs + +type family Init l where + Init (x : y : xs) = x : Init (y : xs) + Init '[x] = '[] + +type family Last l where + Last (x : y : xs) = Last (y : xs) + Last '[x] = x + + +-- | This is just @'Unsafe.Coerce.unsafeCoerce' 'Refl'@, but specialised to +-- only typecheck for actual type equalities. One cannot, e.g. accidentally +-- write this: +-- +-- @ +-- foo :: Proxy a -> Proxy b -> a :~: b +-- foo = unsafeCoerceRefl +-- @ +-- +-- which would have been permitted with normal 'Unsafe.Coerce.unsafeCoerce', +-- but would have resulted in interesting memory errors at runtime. +unsafeCoerceRefl :: a :~: b +unsafeCoerceRefl = Unsafe.Coerce.unsafeCoerce Refl diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs new file mode 100644 index 0000000..5c38d14 --- /dev/null +++ b/src/Data/Array/Strided/Orthotope.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Array.Strided.Orthotope ( + module Data.Array.Strided.Orthotope, + module Data.Array.Strided.Arith, +) where + +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS + +import Data.Array.Strided qualified as AS +import Data.Array.Strided.Arith + +-- for liftVEltwise1 +import Data.Array.Strided.Arith.Internal (stridesDense) +import Data.Vector.Storable qualified as VS +import Foreign.Storable +import GHC.TypeLits + + +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec + +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) + +liftO1 :: (AS.Array n a -> AS.Array n' b) + -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO + +liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) + -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c +liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs new file mode 100644 index 0000000..bf47622 --- /dev/null +++ b/src/Data/Array/XArray.hs @@ -0,0 +1,348 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StrictData #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.XArray where + +import Control.DeepSeq (NFData) +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as ORG +import Data.Array.Internal.RankedS qualified as ORS +import Data.Array.Ranked qualified as ORB +import Data.Array.RankedS qualified as S +import Data.Coerce +import Data.Foldable (toList) +import Data.Kind +import Data.List.NonEmpty (NonEmpty) +import Data.Proxy +import Data.Type.Equality +import Data.Type.Ord +import Data.Vector.Storable qualified as VS +import Foreign.Storable (Storable) +import GHC.Generics (Generic) +import GHC.TypeLits + +import Data.Array.Nested.Lemmas +import Data.Array.Nested.Mixed.Shape +import Data.Array.Nested.Permutation +import Data.Array.Nested.Types +import Data.Array.Strided.Orthotope + + +type XArray :: [Maybe Nat] -> Type -> Type +newtype XArray sh a = XArray (S.Array (Rank sh) a) + deriving (Show, Eq, Ord, Generic) + +instance NFData (XArray sh a) + + +shape :: forall sh a. StaticShX sh -> XArray sh a -> IShX sh +shape = \ssh (XArray arr) -> go ssh (S.shapeL arr) + where + go :: StaticShX sh' -> [Int] -> IShX sh' + go ZKX [] = ZSX + go (n :!% ssh) (i : l) = fromSMayNat (\_ -> SUnknown i) SKnown n :$% go ssh l + go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. Storable a => IShX sh -> VS.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + = XArray (S.fromVector (shxToList sh) v) + +toVector :: Storable a => XArray sh a -> VS.Vector a +toVector (XArray arr) = S.toVector arr + +-- | This allows observing the strides in the underlying orthotope array. This +-- can be useful for optimisation, but should be considered an implementation +-- detail: strides may change in new versions of this library without notice. +arrayStrides :: XArray sh a -> [Int] +arrayStrides (XArray (ORS.A (ORG.A _ (OI.T strides _ _)))) = strides + +scalar :: Storable a => a -> XArray '[] a +scalar = XArray . S.scalar + +-- | Will throw if the array does not have the casted-to shape. +cast :: forall sh1 sh2 sh' a. Rank sh1 ~ Rank sh2 + => StaticShX sh1 -> IShX sh2 -> StaticShX sh' + -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +cast ssh1 sh2 ssh' (XArray arr) + | Refl <- lemRankApp ssh1 ssh' + , Refl <- lemRankApp (ssxFromShX sh2) ssh' + = let arrsh :: IShX sh1 + (arrsh, _) = shxSplitApp (Proxy @sh') ssh1 (shape (ssxAppend ssh1 ssh') (XArray arr)) + in if shxToList arrsh == shxToList sh2 + then XArray arr + else error $ "Data.Array.Mixed.cast: Cannot cast (" ++ show arrsh ++ ") to (" ++ show sh2 ++ ")" + +unScalar :: Storable a => XArray '[] a -> a +unScalar (XArray a) = S.unScalar a + +replicate :: forall sh sh' a. Storable a => IShX sh -> StaticShX sh' -> XArray sh' a -> XArray (sh ++ sh') a +replicate sh ssh' (XArray arr) + | Dict <- lemKnownNatRankSSX ssh' + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh) ssh') + , Refl <- lemRankApp (ssxFromShX sh) ssh' + = XArray (S.stretch (shxToList sh ++ S.shapeL arr) $ + S.reshape (map (const 1) (shxToList sh) ++ S.shapeL arr) + arr) + +replicateScal :: forall sh a. Storable a => IShX sh -> a -> XArray sh a +replicateScal sh x + | Dict <- lemKnownNatRank sh + = XArray (S.constant (shxToList sh) x) + +generate :: Storable a => IShX sh -> (IIxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VS.generate (shxSize sh) (f . ixxFromLinear sh) + +-- generateM :: (Monad m, Storable a) => IShX sh -> (IIxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +-- XArray . S.fromVector (shxShapeL sh) +-- <$> VS.generateM (shxSize sh) (f . ixxFromLinear sh) + +indexPartial :: Storable a => XArray (sh ++ sh') a -> IIxX sh -> XArray sh' a +indexPartial (XArray arr) ZIX = XArray arr +indexPartial (XArray arr) (i :.% idx) = indexPartial (XArray (S.index arr i)) idx + +index :: forall sh a. Storable a => XArray sh a -> IIxX sh -> a +index xarr i + | Refl <- lemAppNil @sh + = let XArray arr' = indexPartial xarr i :: XArray '[] a + in S.unScalar arr' + +append :: forall n m sh a. Storable a + => StaticShX sh -> XArray (n : sh) a -> XArray (m : sh) a -> XArray (AddMaybe n m : sh) a +append ssh (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.append a b) + +-- | All arrays must have the same shape, except possibly for the outermost +-- dimension. +concat :: Storable a + => StaticShX sh -> NonEmpty (XArray (Nothing : sh) a) -> XArray (Nothing : sh) a +concat ssh l + | Dict <- lemKnownNatRankSSX ssh + = XArray (S.concatOuter (coerce (toList l))) + +-- | If the prefix of the shape of the input array (@sh@) is empty (i.e. +-- contains a zero), then there is no way to deduce the full shape of the output +-- array (more precisely, the @sh2@ part): that could only come from calling +-- @f@, and there are no subarrays to call @f@ on. @orthotope@ errors out in +-- this case; we choose to fill the shape with zeros wherever we cannot deduce +-- what it should be. +-- +-- For example, if: +-- +-- @ +-- arr :: XArray '[Just 3, Just 0, Just 4, Just 2, Nothing] Int -- of shape [3, 0, 4, 2, 21] +-- f :: XArray '[Just 2, Nothing] Int -> XArray '[Just 5, Nothing, Just 17] Float +-- @ +-- +-- then: +-- +-- @ +-- rerank _ _ _ f arr :: XArray '[Just 3, Just 0, Just 4, Just 5, Nothing, Just 17] Float +-- @ +-- +-- and this result will have shape @[3, 0, 4, 5, 0, 17]@. Note the second @0@ +-- in this shape: we don't know if @f@ intended to return an array with shape 0 +-- here (it probably didn't), but there is no better number to put here absent +-- a subarray of the input to pass to @f@. +-- +-- In this particular case the fact that @sh@ is empty was evident from the +-- type-level information, but the same situation occurs when @sh@ consists of +-- @Nothing@s, and some of those happen to be zero at runtime. +rerank :: forall sh sh1 sh2 a b. + (Storable a, Storable b) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f xarr@(XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr) + in if 0 `elem` shxToList sh + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a -> let XArray r = f (XArray a) in r) + arr) + +rerankTop :: forall sh1 sh2 sh a b. + (Storable a, Storable b) + => StaticShX sh1 -> StaticShX sh2 -> StaticShX sh + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh1 ++ sh) a -> XArray (sh2 ++ sh) b +rerankTop ssh1 ssh2 ssh f = transpose2 ssh ssh2 . rerank ssh ssh1 ssh2 f . transpose2 ssh1 ssh + +-- | The caveat about empty arrays at @rerank@ applies here too. +rerank2 :: forall sh sh1 sh2 a b c. + (Storable a, Storable b, Storable c) + => StaticShX sh -> StaticShX sh1 -> StaticShX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f xarr1@(XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) + = let (sh, _) = shxSplitApp (Proxy @sh1) ssh (shape (ssxAppend ssh ssh1) xarr1) + in if 0 `elem` shxToList sh + then XArray (S.fromList (shxToList (shxAppend sh (shxCompleteZeros ssh2))) []) + else case () of + () | Dict <- lemKnownNatRankSSX ssh + , Dict <- lemKnownNatRankSSX ssh2 + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + -> XArray (S.rerank2 @(Rank sh) @(Rank sh1) @(Rank sh2) + (\a b -> let XArray r = f (XArray a) (XArray b) in r) + arr1 arr2) + +-- | The list argument gives indices into the original dimension list. +transpose :: forall is sh a. (IsPermutation is, Rank is <= Rank sh) + => StaticShX sh + -> Perm is + -> XArray sh a + -> XArray (PermutePrefix is sh) a +transpose ssh perm (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Refl <- lemRankApp (ssxPermute perm (ssxTakeLen perm ssh)) (ssxDropLen perm ssh) + , Refl <- lemRankPermute (Proxy @(TakeLen is sh)) perm + , Refl <- lemRankDropLen ssh perm + = XArray (S.transpose (permToList' perm) arr) + +-- | The list argument gives indices into the original dimension list. +-- +-- The permutation (the list) must have length <= @n@. If it is longer, this +-- function throws. +transposeUntyped :: forall n sh a. + SNat n -> StaticShX sh -> [Int] + -> XArray (Replicate n Nothing ++ sh) a -> XArray (Replicate n Nothing ++ sh) a +transposeUntyped sn ssh perm (XArray arr) + | length perm <= fromSNat' sn + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxReplicate sn) ssh) + = XArray (S.transpose perm arr) + | otherwise + = error "Data.Array.Mixed.transposeUntyped: Permutation longer than length of unshaped prefix of shape type" + +transpose2 :: forall sh1 sh2 a. + StaticShX sh1 -> StaticShX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (S.transpose (ssxIotaFrom ssh2 n1 ++ ssxIotaFrom ssh1 0) arr) + +sumFull :: (Storable a, NumElt a) => StaticShX sh -> XArray sh a -> a +sumFull _ (XArray arr) = + S.unScalar $ + liftO1 (numEltSum1Inner (SNat @0)) $ + S.fromVector [product (S.shapeL arr)] $ + S.toVector arr + +sumInner :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' arr + | Refl <- lemAppNil @sh + = let (_, sh') = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + sh'F = shxFlatten sh' :$% ZSX + ssh'F = ssxFromShX sh'F + + go :: XArray (sh ++ '[Flatten sh']) a -> XArray sh a + go (XArray arr') + | Refl <- lemRankApp ssh ssh'F + , let sn = listxRank (let StaticShX l = ssh in l) + = XArray (liftO1 (numEltSum1Inner sn) arr') + + in go $ + transpose2 ssh'F ssh $ + reshapePartial ssh' ssh sh'F $ + transpose2 ssh ssh' $ + arr + +sumOuter :: forall sh sh' a. (Storable a, NumElt a) + => StaticShX sh -> StaticShX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' arr + | Refl <- lemAppNil @sh + = let (sh, _) = shxSplitApp (Proxy @sh') ssh (shape (ssxAppend ssh ssh') arr) + shF = shxFlatten sh :$% ZSX + in sumInner ssh' (ssxFromShX shF) $ + transpose2 (ssxFromShX shF) ssh' $ + reshapePartial ssh ssh' shF $ + arr + +fromListOuter :: forall n sh a. Storable a + => StaticShX (n : sh) -> [XArray sh a] -> XArray (n : sh) a +fromListOuter ssh l + | Dict <- lemKnownNatRankSSX ssh + = case ssh of + SKnown m :!% _ | fromSNat' m /= length l -> + error $ "Data.Array.Mixed.fromListOuter: length of list (" ++ show (length l) ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.ravel (ORB.fromList [length l] (coerce @[XArray sh a] @[S.Array (Rank sh) a] l))) + +toListOuter :: Storable a => XArray (n : sh) a -> [XArray sh a] +toListOuter (XArray arr) = + case S.shapeL arr of + 0 : _ -> [] + _ -> coerce (ORB.toList (S.unravel arr)) + +fromList1 :: Storable a => StaticShX '[n] -> [a] -> XArray '[n] a +fromList1 ssh l = + let n = length l + in case ssh of + SKnown m :!% _ | fromSNat' m /= n -> + error $ "Data.Array.Mixed.fromList1: length of list (" ++ show n ++ ")" ++ + "does not match the type (" ++ show (fromSNat' m) ++ ")" + _ -> XArray (S.fromVector [n] (VS.fromListN n l)) + +toList1 :: Storable a => XArray '[n] a -> [a] +toList1 (XArray arr) = S.toList arr + +-- | Throws if the given shape is not, in fact, empty. +empty :: forall sh a. Storable a => IShX sh -> XArray sh a +empty sh + | Dict <- lemKnownNatRank sh + , shxSize sh == 0 + = XArray (S.fromVector (shxToList sh) VS.empty) + | otherwise + = error $ "Data.Array.Mixed.empty: shape was not empty: " ++ show sh + +slice :: SNat i -> SNat n -> XArray (Just (i + n + k) : sh) a -> XArray (Just n : sh) a +slice i n (XArray arr) = XArray (S.slice [(fromSNat' i, fromSNat' n)] arr) + +sliceU :: Int -> Int -> XArray (Nothing : sh) a -> XArray (Nothing : sh) a +sliceU i n (XArray arr) = XArray (S.slice [(i, n)] arr) + +rev1 :: XArray (n : sh) a -> XArray (n : sh) a +rev1 (XArray arr) = XArray (S.rev [0] arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshape :: forall sh1 sh2 a. Storable a => StaticShX sh1 -> IShX sh2 -> XArray sh1 a -> XArray sh2 a +reshape ssh1 sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX ssh1 + , Dict <- lemKnownNatRank sh2 + = XArray (S.reshape (shxToList sh2) arr) + +-- | Throws if the given array and the target shape do not have the same number of elements. +reshapePartial :: forall sh1 sh2 sh' a. Storable a => StaticShX sh1 -> StaticShX sh' -> IShX sh2 -> XArray (sh1 ++ sh') a -> XArray (sh2 ++ sh') a +reshapePartial ssh1 ssh' sh2 (XArray arr) + | Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh') + , Dict <- lemKnownNatRankSSX (ssxAppend (ssxFromShX sh2) ssh') + = XArray (S.reshape (shxToList sh2 ++ drop (ssxLength ssh1) (S.shapeL arr)) arr) + +-- this was benchmarked to be (slightly) faster than S.iota, S.generate and S.fromVector(VS.enumFromTo). +iota :: (Enum a, Storable a) => SNat n -> XArray '[Just n] a +iota sn = XArray (S.fromVector [fromSNat' sn] (VS.fromListN (fromSNat' sn) [toEnum 0 .. toEnum (fromSNat' sn - 1)])) diff --git a/src/Data/Bag.hs b/src/Data/Bag.hs new file mode 100644 index 0000000..b424857 --- /dev/null +++ b/src/Data/Bag.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE DeriveTraversable #-} +module Data.Bag where + + +-- | An ordered sequence that can be folded over. +data Bag a = BZero | BOne a | BTwo (Bag a) (Bag a) | BList [Bag a] + deriving (Show, Functor, Foldable, Traversable) + +-- Really only here for 'pure' +instance Applicative Bag where + pure = BOne + BZero <*> _ = BZero + BOne f <*> t = f <$> t + BTwo f1 f2 <*> t = BTwo (f1 <*> t) (f2 <*> t) + BList fs <*> t = BList [f <*> t | f <- fs] + +instance Semigroup (Bag a) where (<>) = BTwo +instance Monoid (Bag a) where mempty = BZero diff --git a/src/Data/INat.hs b/src/Data/INat.hs deleted file mode 100644 index af8f18b..0000000 --- a/src/Data/INat.hs +++ /dev/null @@ -1,121 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Data.INat where - -import Data.Proxy -import Data.Type.Equality ((:~:) (Refl)) -import Numeric.Natural -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - --- | Evidence for the constraint @c a@. -data Dict c a where - Dict :: c a => Dict c a - --- | An inductive peano natural number. Intended to be used at the type level. -data INat = Z | S INat - deriving (Show) - --- | Singleton for a 'INat'. -data SINat n where - SZ :: SINat Z - SS :: SINat n -> SINat (S n) -deriving instance Show (SINat n) - --- | A singleton 'SINat' corresponding to @n@. -class KnownINat n where inatSing :: SINat n -instance KnownINat Z where inatSing = SZ -instance KnownINat n => KnownINat (S n) where inatSing = SS inatSing - --- | Explicitly bidirectional pattern synonym that converts between a singleton --- 'SINat' and evidence of a 'KnownINat' constraint. Analogous to 'GHC.SNat'. -pattern SINat' :: () => KnownINat n => SINat n -pattern SINat' <- (snatKnown -> Dict) - where SINat' = inatSing - --- | A 'KnownINat' dictionary is just a singleton natural, so we can create --- evidence of 'KnownINat' given an 'SINat'. -snatKnown :: SINat n -> Dict KnownINat n -snatKnown SZ = Dict -snatKnown (SS n) | Dict <- snatKnown n = Dict - --- | Convert a 'INat' to a normal number. -fromINat :: INat -> Natural -fromINat Z = 0 -fromINat (S n) = 1 + fromINat n - --- | Convert an 'SINat' to a normal number. -fromSINat :: SINat n -> Natural -fromSINat SZ = 0 -fromSINat (SS n) = 1 + fromSINat n - --- | The value of a known inductive natural as a value-level integer. -inatVal :: forall n. KnownINat n => Proxy n -> Natural -inatVal _ = fromSINat (inatSing @n) - --- | Add two 'INat's -type family n +! m where - Z +! m = m - S n +! m = S (n +! m) - --- | Convert a 'INat' to a "GHC.TypeLits" 'G.Nat'. -type family FromINat n where - FromINat Z = 0 - FromINat (S n) = 1 + FromINat n - --- | Convert a "GHC.TypeLits" 'G.Nat' to a 'INat'. -type family ToINat (n :: Nat) where - ToINat 0 = Z - ToINat n = S (ToINat (n - 1)) - -lemInjectiveFromINat :: n :~: ToINat (FromINat n) -lemInjectiveFromINat = unsafeCoerce Refl - -lemSuccFromINat :: Proxy n -> 1 + FromINat n :~: FromINat (S n) -lemSuccFromINat _ = unsafeCoerce Refl - -lemAddFromINat :: Proxy m -> Proxy n - -> FromINat m + FromINat n :~: FromINat (m +! n) -lemAddFromINat _ = unsafeCoerce Refl - -lemInjectiveToINat :: n :~: FromINat (ToINat n) -lemInjectiveToINat = unsafeCoerce Refl - -lemSuccToINat :: Proxy n -> ToINat (1 + n) :~: S (ToINat n) -lemSuccToINat _ = unsafeCoerce Refl - -lemAddToINat :: Proxy m -> Proxy n -> ToINat (m + n) :~: ToINat m +! ToINat n -lemAddToINat _ _ = unsafeCoerce Refl - --- | If an inductive 'INat' is known, then the corresponding "GHC.TypeLits" --- 'G.Nat' is also known. -knownNatFromINat :: KnownINat n => Proxy n -> Dict KnownNat (FromINat n) -knownNatFromINat (Proxy @n) = go (SINat' @n) - where - go :: SINat m -> Dict KnownNat (FromINat m) - go SZ = Dict - go (SS n) | Dict <- go n = Dict - --- * Some type-level inductive naturals - -type I0 = Z -type I1 = S I0 -type I2 = S I1 -type I3 = S I2 -type I4 = S I3 -type I5 = S I4 -type I6 = S I5 -type I7 = S I6 -type I8 = S I7 -type I9 = S I8 |