From 92902c4f66db111b439f3b7eba9de50ad7c73f7b Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Wed, 3 Apr 2024 12:37:35 +0200 Subject: Reorganise, documentation --- ox-arrays.cabal | 9 +- src/Array.hs | 312 ------------------- src/Data/Array/Mixed.hs | 314 +++++++++++++++++++ src/Data/Array/Nested.hs | 40 +++ src/Data/Array/Nested/Internal.hs | 623 ++++++++++++++++++++++++++++++++++++++ src/Data/Nat.hs | 70 +++++ src/Fancy.hs | 598 ------------------------------------ src/Nats.hs | 58 ---- 8 files changed, 1052 insertions(+), 972 deletions(-) delete mode 100644 src/Array.hs create mode 100644 src/Data/Array/Mixed.hs create mode 100644 src/Data/Array/Nested.hs create mode 100644 src/Data/Array/Nested/Internal.hs create mode 100644 src/Data/Nat.hs delete mode 100644 src/Fancy.hs delete mode 100644 src/Nats.hs diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 2930ba0..5bdff7d 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -7,13 +7,14 @@ build-type: Simple library exposed-modules: - Array - Fancy - Nats + Data.Array.Mixed + Data.Array.Nested + Data.Array.Nested.Internal + Data.Nat build-depends: base >=4.18, ghc-typelits-knownnat, - ghc-typelits-natnormalise, + -- ghc-typelits-natnormalise, orthotope, vector hs-source-dirs: src diff --git a/src/Array.hs b/src/Array.hs deleted file mode 100644 index cbf04fc..0000000 --- a/src/Array.hs +++ /dev/null @@ -1,312 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -module Array where - -import qualified Data.Array.RankedU as U -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU -import qualified GHC.TypeLits as GHC -import Unsafe.Coerce (unsafeCoerce) - -import Nats - - -type family l1 ++ l2 where - '[] ++ l2 = l2 - (x : xs) ++ l2 = x : xs ++ l2 - -lemAppNil :: l ++ '[] :~: l -lemAppNil = unsafeCoerce Refl - -lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) -lemAppAssoc _ _ _ = unsafeCoerce Refl - - -type IxX :: [Maybe Nat] -> Type -data IxX sh where - IZX :: IxX '[] - (::@) :: Int -> IxX sh -> IxX (Just n : sh) - (::?) :: Int -> IxX sh -> IxX (Nothing : sh) -deriving instance Show (IxX sh) - -type StaticShapeX :: [Maybe Nat] -> Type -data StaticShapeX sh where - SZX :: StaticShapeX '[] - (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) - (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) -deriving instance Show (StaticShapeX sh) - -type KnownShapeX :: [Maybe Nat] -> Constraint -class KnownShapeX sh where - knownShapeX :: StaticShapeX sh -instance KnownShapeX '[] where - knownShapeX = SZX -instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where - knownShapeX = knownNat :$@ knownShapeX -instance KnownShapeX sh => KnownShapeX (Nothing : sh) where - knownShapeX = () :$? knownShapeX - -type family Rank sh where - Rank '[] = Z - Rank (_ : sh) = S (Rank sh) - -type XArray :: [Maybe Nat] -> Type -> Type -data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) - -zeroIdx :: StaticShapeX sh -> IxX sh -zeroIdx SZX = IZX -zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh -zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh - -zeroIdx' :: IxX sh -> IxX sh -zeroIdx' IZX = IZX -zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh -zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh - -ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh') -ixAppend IZX idx' = idx' -ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx' -ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx' - -ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh' -ixDrop sh IZX = sh -ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx -ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx - -ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') -ssxAppend SZX idx' = idx' -ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx' -ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx' - -shapeSize :: IxX sh -> Int -shapeSize IZX = 1 -shapeSize (n ::@ sh) = n * shapeSize sh -shapeSize (n ::? sh) = n * shapeSize sh - -fromLinearIdx :: IxX sh -> Int -> IxX sh -fromLinearIdx = \sh i -> case go sh i of - (idx, 0) -> idx - _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ - " in array of shape " ++ show sh ++ ")" - where - -- returns (index in subarray, remaining index in enclosing array) - go :: IxX sh -> Int -> (IxX sh, Int) - go IZX i = (IZX, i) - go (n ::@ sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali ::@ idx, upi) - go (n ::? sh) i = - let (idx, i') = go sh i - (upi, locali) = i' `quotRem` n - in (locali ::? idx, upi) - -toLinearIdx :: IxX sh -> IxX sh -> Int -toLinearIdx = \sh i -> fst (go sh i) - where - -- returns (index in subarray, size of subarray) - go :: IxX sh -> IxX sh -> (Int, Int) - go IZX IZX = (0, 1) - go (n ::@ sh) (i ::@ ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - go (n ::? sh) (i ::? ix) = - let (lidx, sz) = go sh ix - in (sz * i + lidx, n * sz) - -enumShape :: IxX sh -> [IxX sh] -enumShape = \sh -> go 0 sh id [] - where - go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a] - go _ IZX _ = id - go i (n ::@ sh) f - | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@)) - | otherwise = id - go i (n ::? sh) f - | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?)) - | otherwise = id - -shapeLshape :: IxX sh -> U.ShapeL -shapeLshape IZX = [] -shapeLshape (n ::@ sh) = n : shapeLshape sh -shapeLshape (n ::? sh) = n : shapeLshape sh - -ssxLength :: StaticShapeX sh -> Int -ssxLength SZX = 0 -ssxLength (_ :$@ ssh) = 1 + ssxLength ssh -ssxLength (_ :$? ssh) = 1 + ssxLength ssh - -ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] -ssxIotaFrom _ SZX = [] -ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh -ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh - -lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 - -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2) -lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this - -lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 - -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1)) -lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this - -lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh) -lemKnownNatRank IZX = Dict -lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict - -lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh) -lemKnownNatRankSSX SZX = Dict -lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict -lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict - -lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh -lemKnownShapeX SZX = Dict -lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict -lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict - -lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) -lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' -lemAppKnownShapeX (n :$@ ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - , Dict <- snatKnown n - = Dict -lemAppKnownShapeX (() :$? ssh) ssh' - | Dict <- lemAppKnownShapeX ssh ssh' - = Dict - -shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh -shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) - where - go :: StaticShapeX sh' -> [Int] -> IxX sh' - go SZX [] = IZX - go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l - go (() :$? ssh) (n : l) = n ::? go ssh l - go _ _ = error "Invalid shapeL" - -fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a -fromVector sh v - | Dict <- lemKnownNatRank sh - , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.fromVector (shapeLshape sh) v) - -toVector :: U.Unbox a => XArray sh a -> VU.Vector a -toVector (XArray arr) = U.toVector arr - -scalar :: U.Unbox a => a -> XArray '[] a -scalar = XArray . U.scalar - -unScalar :: U.Unbox a => XArray '[] a -> a -unScalar (XArray a) = U.unScalar a - -generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a -generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) - --- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownNatRank sh = --- XArray . U.fromVector (shapeLshape sh) --- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh) - -indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a -indexPartial (XArray arr) IZX = XArray arr -indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx -indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx - -index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a -index xarr i - | Refl <- lemAppNil @sh - = let XArray arr' = indexPartial xarr i :: XArray '[] a - in U.unScalar arr' - -append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a -append (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.append a b) - -rerank :: forall sh sh1 sh2 a b. - (U.Unbox a, U.Unbox b) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 - -> (XArray sh1 a -> XArray sh2 b) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b -rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownNatRankSSX ssh - , Dict <- gknownNat (Proxy @(Rank sh)) - , Dict <- lemKnownNatRankSSX ssh2 - , Dict <- gknownNat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) - (\a -> unXArray (f (XArray a))) - arr) - where - unXArray (XArray a) = a - -rerank2 :: forall sh sh1 sh2 a b c. - (U.Unbox a, U.Unbox b, U.Unbox c) - => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 - -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) - -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c -rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX ssh - , Dict <- gknownNat (Proxy @(Rank sh)) - , Dict <- lemKnownNatRankSSX ssh2 - , Dict <- gknownNat (Proxy @(Rank sh2)) - , Refl <- lemRankApp ssh ssh1 - , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough - = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) - (\a b -> unXArray (f (XArray a) (XArray b))) - arr1 arr2) - where - unXArray (XArray a) = a - --- | The list argument gives indices into the original dimension list. -transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a -transpose perm (XArray arr) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) - , Dict <- gknownNat (Proxy @(Rank sh)) - = XArray (U.transpose perm arr) - -transpose2 :: forall sh1 sh2 a. - StaticShapeX sh1 -> StaticShapeX sh2 - -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a -transpose2 ssh1 ssh2 (XArray arr) - | Refl <- lemRankApp ssh1 ssh2 - , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) - , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) - , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1))) - , Refl <- lemRankAppComm ssh1 ssh2 - , let n1 = ssxLength ssh1 - = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) - -sumFull :: (U.Unbox a, Num a) => XArray sh a -> a -sumFull (XArray arr) = U.sumA arr - -sumInner :: forall sh sh' a. (U.Unbox a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a -sumInner ssh ssh' - | Refl <- lemAppNil @sh - = rerank ssh ssh' SZX (scalar . sumFull) - -sumOuter :: forall sh sh' a. (U.Unbox a, Num a) - => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a -sumOuter ssh ssh' - | Refl <- lemAppNil @sh - = sumInner ssh' ssh . transpose2 ssh ssh' diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs new file mode 100644 index 0000000..e1e2d5a --- /dev/null +++ b/src/Data/Array/Mixed.hs @@ -0,0 +1,314 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Array.Mixed where + +import qualified Data.Array.RankedU as U +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import qualified Data.Vector.Unboxed as VU +import qualified GHC.TypeLits as GHC +import Unsafe.Coerce (unsafeCoerce) + +import Data.Nat + + +-- | Type-level list append. +type family l1 ++ l2 where + '[] ++ l2 = l2 + (x : xs) ++ l2 = x : xs ++ l2 + +lemAppNil :: l ++ '[] :~: l +lemAppNil = unsafeCoerce Refl + +lemAppAssoc :: Proxy a -> Proxy b -> Proxy c -> (a ++ b) ++ c :~: a ++ (b ++ c) +lemAppAssoc _ _ _ = unsafeCoerce Refl + + +type IxX :: [Maybe Nat] -> Type +data IxX sh where + IZX :: IxX '[] + (::@) :: Int -> IxX sh -> IxX (Just n : sh) + (::?) :: Int -> IxX sh -> IxX (Nothing : sh) +deriving instance Show (IxX sh) + +-- | The part of a shape that is statically known. +type StaticShapeX :: [Maybe Nat] -> Type +data StaticShapeX sh where + SZX :: StaticShapeX '[] + (:$@) :: SNat n -> StaticShapeX sh -> StaticShapeX (Just n : sh) + (:$?) :: () -> StaticShapeX sh -> StaticShapeX (Nothing : sh) +deriving instance Show (StaticShapeX sh) + +-- | Evidence for the static part of a shape. +type KnownShapeX :: [Maybe Nat] -> Constraint +class KnownShapeX sh where + knownShapeX :: StaticShapeX sh +instance KnownShapeX '[] where + knownShapeX = SZX +instance (KnownNat n, KnownShapeX sh) => KnownShapeX (Just n : sh) where + knownShapeX = knownNat :$@ knownShapeX +instance KnownShapeX sh => KnownShapeX (Nothing : sh) where + knownShapeX = () :$? knownShapeX + +type family Rank sh where + Rank '[] = Z + Rank (_ : sh) = S (Rank sh) + +type XArray :: [Maybe Nat] -> Type -> Type +data XArray sh a = XArray (U.Array (GNat (Rank sh)) a) + +zeroIdx :: StaticShapeX sh -> IxX sh +zeroIdx SZX = IZX +zeroIdx (_ :$@ ssh) = 0 ::@ zeroIdx ssh +zeroIdx (_ :$? ssh) = 0 ::? zeroIdx ssh + +zeroIdx' :: IxX sh -> IxX sh +zeroIdx' IZX = IZX +zeroIdx' (_ ::@ sh) = 0 ::@ zeroIdx' sh +zeroIdx' (_ ::? sh) = 0 ::? zeroIdx' sh + +ixAppend :: IxX sh -> IxX sh' -> IxX (sh ++ sh') +ixAppend IZX idx' = idx' +ixAppend (i ::@ idx) idx' = i ::@ ixAppend idx idx' +ixAppend (i ::? idx) idx' = i ::? ixAppend idx idx' + +ixDrop :: IxX (sh ++ sh') -> IxX sh -> IxX sh' +ixDrop sh IZX = sh +ixDrop (_ ::@ sh) (_ ::@ idx) = ixDrop sh idx +ixDrop (_ ::? sh) (_ ::? idx) = ixDrop sh idx + +ssxAppend :: StaticShapeX sh -> StaticShapeX sh' -> StaticShapeX (sh ++ sh') +ssxAppend SZX idx' = idx' +ssxAppend (n :$@ idx) idx' = n :$@ ssxAppend idx idx' +ssxAppend (() :$? idx) idx' = () :$? ssxAppend idx idx' + +shapeSize :: IxX sh -> Int +shapeSize IZX = 1 +shapeSize (n ::@ sh) = n * shapeSize sh +shapeSize (n ::? sh) = n * shapeSize sh + +fromLinearIdx :: IxX sh -> Int -> IxX sh +fromLinearIdx = \sh i -> case go sh i of + (idx, 0) -> idx + _ -> error $ "fromLinearIdx: out of range (" ++ show i ++ + " in array of shape " ++ show sh ++ ")" + where + -- returns (index in subarray, remaining index in enclosing array) + go :: IxX sh -> Int -> (IxX sh, Int) + go IZX i = (IZX, i) + go (n ::@ sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` n + in (locali ::@ idx, upi) + go (n ::? sh) i = + let (idx, i') = go sh i + (upi, locali) = i' `quotRem` n + in (locali ::? idx, upi) + +toLinearIdx :: IxX sh -> IxX sh -> Int +toLinearIdx = \sh i -> fst (go sh i) + where + -- returns (index in subarray, size of subarray) + go :: IxX sh -> IxX sh -> (Int, Int) + go IZX IZX = (0, 1) + go (n ::@ sh) (i ::@ ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, n * sz) + go (n ::? sh) (i ::? ix) = + let (lidx, sz) = go sh ix + in (sz * i + lidx, n * sz) + +enumShape :: IxX sh -> [IxX sh] +enumShape = \sh -> go 0 sh id [] + where + go :: Int -> IxX sh -> (IxX sh -> a) -> [a] -> [a] + go _ IZX _ = id + go i (n ::@ sh) f + | i < n = go (i + 1) (n ::@ sh) f . go 0 sh (f . (i ::@)) + | otherwise = id + go i (n ::? sh) f + | i < n = go (i + 1) (n ::? sh) f . go 0 sh (f . (i ::?)) + | otherwise = id + +shapeLshape :: IxX sh -> U.ShapeL +shapeLshape IZX = [] +shapeLshape (n ::@ sh) = n : shapeLshape sh +shapeLshape (n ::? sh) = n : shapeLshape sh + +ssxLength :: StaticShapeX sh -> Int +ssxLength SZX = 0 +ssxLength (_ :$@ ssh) = 1 + ssxLength ssh +ssxLength (_ :$? ssh) = 1 + ssxLength ssh + +ssxIotaFrom :: Int -> StaticShapeX sh -> [Int] +ssxIotaFrom _ SZX = [] +ssxIotaFrom i (_ :$@ ssh) = i : ssxIotaFrom (i+1) ssh +ssxIotaFrom i (_ :$? ssh) = i : ssxIotaFrom (i+1) ssh + +lemRankApp :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank sh1) GHC.+ GNat (Rank sh2) +lemRankApp _ _ = unsafeCoerce Refl -- TODO improve this + +lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 + -> GNat (Rank (sh1 ++ sh2)) :~: GNat (Rank (sh2 ++ sh1)) +lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this + +lemKnownNatRank :: IxX sh -> Dict KnownNat (Rank sh) +lemKnownNatRank IZX = Dict +lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict + +lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankSSX SZX = Dict +lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict + +lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh +lemKnownShapeX SZX = Dict +lemKnownShapeX (n :$@ ssh) | Dict <- lemKnownShapeX ssh, Dict <- snatKnown n = Dict +lemKnownShapeX (() :$? ssh) | Dict <- lemKnownShapeX ssh = Dict + +lemAppKnownShapeX :: StaticShapeX sh1 -> StaticShapeX sh2 -> Dict KnownShapeX (sh1 ++ sh2) +lemAppKnownShapeX SZX ssh' = lemKnownShapeX ssh' +lemAppKnownShapeX (n :$@ ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + , Dict <- snatKnown n + = Dict +lemAppKnownShapeX (() :$? ssh) ssh' + | Dict <- lemAppKnownShapeX ssh ssh' + = Dict + +shape :: forall sh a. KnownShapeX sh => XArray sh a -> IxX sh +shape (XArray arr) = go (knownShapeX @sh) (U.shapeL arr) + where + go :: StaticShapeX sh' -> [Int] -> IxX sh' + go SZX [] = IZX + go (n :$@ ssh) (_ : l) = fromIntegral (unSNat n) ::@ go ssh l + go (() :$? ssh) (n : l) = n ::? go ssh l + go _ _ = error "Invalid shapeL" + +fromVector :: forall sh a. U.Unbox a => IxX sh -> VU.Vector a -> XArray sh a +fromVector sh v + | Dict <- lemKnownNatRank sh + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.fromVector (shapeLshape sh) v) + +toVector :: U.Unbox a => XArray sh a -> VU.Vector a +toVector (XArray arr) = U.toVector arr + +scalar :: U.Unbox a => a -> XArray '[] a +scalar = XArray . U.scalar + +unScalar :: U.Unbox a => XArray '[] a -> a +unScalar (XArray a) = U.unScalar a + +generate :: U.Unbox a => IxX sh -> (IxX sh -> a) -> XArray sh a +generate sh f = fromVector sh $ VU.generate (shapeSize sh) (f . fromLinearIdx sh) + +-- generateM :: (Monad m, U.Unbox a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) +-- generateM sh f | Dict <- lemKnownNatRank sh = +-- XArray . U.fromVector (shapeLshape sh) +-- <$> VU.generateM (shapeSize sh) (f . fromLinearIdx sh) + +indexPartial :: U.Unbox a => XArray (sh ++ sh') a -> IxX sh -> XArray sh' a +indexPartial (XArray arr) IZX = XArray arr +indexPartial (XArray arr) (i ::@ idx) = indexPartial (XArray (U.index arr i)) idx +indexPartial (XArray arr) (i ::? idx) = indexPartial (XArray (U.index arr i)) idx + +index :: forall sh a. U.Unbox a => XArray sh a -> IxX sh -> a +index xarr i + | Refl <- lemAppNil @sh + = let XArray arr' = indexPartial xarr i :: XArray '[] a + in U.unScalar arr' + +append :: forall sh a. (KnownShapeX sh, U.Unbox a) => XArray sh a -> XArray sh a -> XArray sh a +append (XArray a) (XArray b) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.append a b) + +rerank :: forall sh sh1 sh2 a b. + (U.Unbox a, U.Unbox b) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh2 b) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b +rerank ssh ssh1 ssh2 f (XArray arr) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a -> unXArray (f (XArray a))) + arr) + where + unXArray (XArray a) = a + +rerank2 :: forall sh sh1 sh2 a b c. + (U.Unbox a, U.Unbox b, U.Unbox c) + => StaticShapeX sh -> StaticShapeX sh1 -> StaticShapeX sh2 + -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) + -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c +rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) + | Dict <- lemKnownNatRankSSX ssh + , Dict <- gknownNat (Proxy @(Rank sh)) + , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- gknownNat (Proxy @(Rank sh2)) + , Refl <- lemRankApp ssh ssh1 + , Refl <- lemRankApp ssh ssh2 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- gknownNat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + = XArray (U.rerank2 @(GNat (Rank sh)) @(GNat (Rank sh1)) @(GNat (Rank sh2)) + (\a b -> unXArray (f (XArray a) (XArray b))) + arr1 arr2) + where + unXArray (XArray a) = a + +-- | The list argument gives indices into the original dimension list. +transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a +transpose perm (XArray arr) + | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + , Dict <- gknownNat (Proxy @(Rank sh)) + = XArray (U.transpose perm arr) + +transpose2 :: forall sh1 sh2 a. + StaticShapeX sh1 -> StaticShapeX sh2 + -> XArray (sh1 ++ sh2) a -> XArray (sh2 ++ sh1) a +transpose2 ssh1 ssh2 (XArray arr) + | Refl <- lemRankApp ssh1 ssh2 + , Refl <- lemRankApp ssh2 ssh1 + , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- gknownNat (Proxy @(Rank (sh1 ++ sh2))) + , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Dict <- gknownNat (Proxy @(Rank (sh2 ++ sh1))) + , Refl <- lemRankAppComm ssh1 ssh2 + , let n1 = ssxLength ssh1 + = XArray (U.transpose (ssxIotaFrom n1 ssh2 ++ ssxIotaFrom 0 ssh1) arr) + +sumFull :: (U.Unbox a, Num a) => XArray sh a -> a +sumFull (XArray arr) = U.sumA arr + +sumInner :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh a +sumInner ssh ssh' + | Refl <- lemAppNil @sh + = rerank ssh ssh' SZX (scalar . sumFull) + +sumOuter :: forall sh sh' a. (U.Unbox a, Num a) + => StaticShapeX sh -> StaticShapeX sh' -> XArray (sh ++ sh') a -> XArray sh' a +sumOuter ssh ssh' + | Refl <- lemAppNil @sh + = sumInner ssh' ssh . transpose2 ssh ssh' diff --git a/src/Data/Array/Nested.hs b/src/Data/Array/Nested.hs new file mode 100644 index 0000000..983a636 --- /dev/null +++ b/src/Data/Array/Nested.hs @@ -0,0 +1,40 @@ +{-# LANGUAGE ExplicitNamespaces #-} +module Data.Array.Nested ( + -- * Ranked arrays + Ranked, + IxR(..), + rshape, rindex, rindexPartial, rgenerate, rsumOuter1, + -- ** Lifting orthotope operations to 'Ranked' arrays + rlift, + + -- * Shaped arrays + Shaped, + IxS(..), + KnownShape(..), SShape(..), + sshape, sindex, sindexPartial, sgenerate, ssumOuter1, + -- ** Lifting orthotope operations to 'Shaped' arrays + slift, + + -- * Mixed arrays + Mixed, + IxX(..), + KnownShapeX(..), StaticShapeX(..), + mgenerate, + + -- * Array elements + Elt(mshape, mindex, mindexPartial, mlift), + Primitive(..), + + -- * Natural numbers + module Data.Nat, + + -- * Further utilities / re-exports + type (++), + VU.Unbox, +) where + +import qualified Data.Vector.Unboxed as VU + +import Data.Array.Mixed +import Data.Array.Nested.Internal +import Data.Nat diff --git a/src/Data/Array/Nested/Internal.hs b/src/Data/Array/Nested/Internal.hs new file mode 100644 index 0000000..1139c57 --- /dev/null +++ b/src/Data/Array/Nested/Internal.hs @@ -0,0 +1,623 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +{-| +TODO: +* This module needs better structure with an Internal module and less public + exports etc. + +* We should be more consistent in whether functions take a 'StaticShapeX' + argument or a 'KnownShapeX' constraint. + +-} + +module Data.Array.Nested.Internal where + +import Control.Monad (forM_) +import Control.Monad.ST +import Data.Coerce (coerce, Coercible) +import Data.Kind +import Data.Proxy +import Data.Type.Equality +import qualified Data.Vector.Unboxed as VU +import qualified Data.Vector.Unboxed.Mutable as VUM + +import Data.Array.Mixed (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) +import qualified Data.Array.Mixed as X +import Data.Nat + + +type family Replicate n a where + Replicate Z a = '[] + Replicate (S n) a = a : Replicate n a + +type family MapJust l where + MapJust '[] = '[] + MapJust (x : xs) = Just x : MapJust xs + +lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) +lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) + where + go :: SNat m -> StaticShapeX (Replicate m Nothing) + go SZ = SZX + go (SS n) = () :$? go n + +lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n +lemRankReplicate _ = go (knownNat @n) + where + go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m + go SZ = Refl + go (SS n) | Refl <- go n = Refl + +lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a + -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a +lemReplicatePlusApp _ _ _ = go (knownNat @n) + where + go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a + go SZ = Refl + go (SS n) | Refl <- go n = Refl + + +-- | Wrapper type used as a tag to attach instances on. The instances on arrays +-- of @'Primitive' a@ are more polymorphic than the direct instances for arrays +-- of scalars; this means that if @orthotope@ supports an element type @T@ that +-- this library does not (directly), it may just work if you use an array of +-- @'Primitive' T@ instead. +newtype Primitive a = Primitive a + + +-- | Mixed arrays: some dimensions are size-typed, some are not. Distributes +-- over product-typed elements using a data family so that the full array is +-- always in struct-of-arrays format. +-- +-- Built on top of 'XArray' which is built on top of @orthotope@, meaning that +-- dimension permutations (e.g. 'transpose') are typically free. +-- +-- Many of the methods for working on 'Mixed' arrays come from the 'Elt' type +-- class. +type Mixed :: [Maybe Nat] -> Type -> Type +data family Mixed sh a + +newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) + +newtype instance Mixed sh Int = M_Int (XArray sh Int) +newtype instance Mixed sh Double = M_Double (XArray sh Double) +newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) +-- etc. + +data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b) +-- etc. + +newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) + + +-- | Internal helper data family mirrorring 'Mixed' that consists of mutable +-- vectors instead of 'XArray's. +type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type +data family MixedVecs s sh a + +newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) + +newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int) +newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) +newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this +-- etc. + +data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) +-- etc. + +data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) + + +-- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or +-- 'Shaped' array. Note the polymorphic instance for 'Elt' of @'Primitive' +-- a@; see the documentation for 'Primitive' for more details. +class Elt a where + -- ====== PUBLIC METHODS ====== -- + + mshape :: KnownShapeX sh => Mixed sh a -> IxX sh + mindex :: Mixed sh a -> IxX sh -> a + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a + + mlift :: forall sh1 sh2. KnownShapeX sh2 + => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 a -> Mixed sh2 a + + -- ====== PRIVATE METHODS ====== -- + -- Remember I said that this module needed better management of exports? + + -- | Create an empty array. The given shape must have size zero; this may or may not be checked. + memptyArray :: IxX sh -> Mixed sh a + + -- | Return the size of the individual (SoA) arrays in this value. If @a@ + -- does not contain tuples, this coincides with the total number of scalars + -- in the given value; if @a@ contains tuples, then it is some multiple of + -- this number of scalars. + mvecsNumElts :: a -> Int + + -- | Create uninitialised vectors for this array type, given the shape of + -- this vector and an example for the contents. The shape must not have size + -- zero; an error may be thrown otherwise. + mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () + + -- | Given the shape of this array, an index and a value, write the value at + -- that index in the vectors. + mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () + + -- | Given the shape of this array, finalise the vectors into 'XArray's. + mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) + + +-- Arrays of scalars are basically just arrays of scalars. +instance VU.Unbox a => Elt (Primitive a) where + mshape (M_Primitive a) = X.shape a + mindex (M_Primitive a) i = Primitive (X.index a i) + mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) + + mlift :: forall sh1 sh2. + (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) + -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) + mlift f (M_Primitive a) + | Refl <- X.lemAppNil @sh1 + , Refl <- X.lemAppNil @sh2 + = M_Primitive (f Proxy a) + + memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) + mvecsNumElts _ = 1 + mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh) + mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x + + -- TODO: this use of toVector is suboptimal + mvecsWritePartial + :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a) + => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () + mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do + let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) + VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) + + mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v + +-- What a blessing that orthotope's Array has "representational" role on the value type! +deriving via Primitive Int instance Elt Int +deriving via Primitive Double instance Elt Double +deriving via Primitive () instance Elt () + +-- Arrays of pairs are pairs of arrays. +instance (Elt a, Elt b) => Elt (a, b) where + mshape (M_Tup2 a _) = mshape a + mindex (M_Tup2 a b) i = (mindex a i, mindex b i) + mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) + mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) + + memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) + mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y + mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y + mvecsWrite sh i (x, y) (MV_Tup2 a b) = do + mvecsWrite sh i x a + mvecsWrite sh i y b + mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do + mvecsWritePartial sh i x a + mvecsWritePartial sh i y b + mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b + +-- Arrays of arrays are just arrays, but with more dimensions. +instance (Elt a, KnownShapeX sh') => Elt (Mixed sh' a) where + mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh + mshape (M_Nest arr) + | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') + = ixAppPrefix (knownShapeX @sh) (mshape arr) + where + ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 + ixAppPrefix SZX _ = IZX + ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx + ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx + + mindex (M_Nest arr) i = mindexPartial arr i + + mindexPartial :: forall sh1 sh2. + Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) + mindexPartial (M_Nest arr) i + | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) + + mlift :: forall sh1 sh2. KnownShapeX sh2 + => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) + -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) + mlift f (M_Nest arr) + | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) + = M_Nest (mlift f' arr) + where + f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b + f' _ + | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3) + , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) + , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3)) + = f (Proxy @(sh' ++ sh3)) + + memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) + + mvecsNumElts arr = + let n = X.shapeSize (mshape arr) + in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh'))) + + mvecsUnsafeNew sh example + | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" + | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) + (mindex example (X.zeroIdx (knownShapeX @sh'))) + where + sh' = mshape example + + mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs + + mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 + => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) + -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) + -> ST s () + mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) + | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) + , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') + = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs + + mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs + + +-- Public method. Turns out this doesn't have to be in the type class! +-- | Create an array given a size and a function that computes the element at a +-- given index. +mgenerate :: forall sh a. (KnownShapeX sh, Elt a) => IxX sh -> (IxX sh -> a) -> Mixed sh a +mgenerate sh f + -- TODO: Do we need this checkBounds check elsewhere as well? + | not (checkBounds sh (knownShapeX @sh)) = + error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) + -- We need to be very careful here to ensure that neither 'sh' nor + -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. + | X.shapeSize sh == 0 = memptyArray sh + | otherwise = + let firstidx = X.zeroIdx' sh + firstelem = f (X.zeroIdx' sh) + in if mvecsNumElts firstelem == 0 + then memptyArray sh + else runST $ do + vecs <- mvecsUnsafeNew sh firstelem + mvecsWrite sh firstidx firstelem vecs + -- TODO: This is likely fine if @a@ is big, but if @a@ is a + -- scalar this feels inefficient. Should improve this. + forM_ (tail (X.enumShape sh)) $ \idx -> + mvecsWrite sh idx (f idx) vecs + mvecsFreeze sh vecs + where + checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool + checkBounds IZX SZX = True + checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' + checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' + + +-- | A rank-typed array: the number of dimensions of the array (its /rank/) is +-- represented on the type level as a 'Nat'. +-- +-- Valid elements of a ranked arrays are described by the 'Elt' type class. +-- Because 'Ranked' itself is also an instance of 'Elt', nested arrays are +-- supported (and are represented as a single, flattened, struct-of-arrays +-- array internally). +-- +-- Note that this 'Nat' is not a "GHC.TypeLits" natural, because we want a +-- type-level natural that supports induction. +-- +-- 'Ranked' is a newtype around a 'Mixed' of 'Nothing's. +type Ranked :: Nat -> Type -> Type +newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) + +-- | A shape-typed array: the full shape of the array (the sizes of its +-- dimensions) is represented on the type level as a list of 'Nat's. +-- +-- Like for 'Ranked', the valid elements are described by the 'Elt' type class, +-- and 'Shaped' itself is again an instance of 'Elt' as well. +-- +-- 'Shaped' is a newtype around a 'Mixed' of 'Just's. +type Shaped :: [Nat] -> Type -> Type +newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) + +-- just unwrap the newtype and defer to the general instance for nested arrays +newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) +newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) + +newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) +newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) + + +-- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; +-- these instances allow them to also be used as elements of arrays, thus +-- making them first-class in the API. +instance (KnownNat n, Elt a) => Elt (Ranked n a) where + mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr + mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) + + mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a) + mindexPartial (M_Ranked arr) i + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ + mindexPartial arr i + + mlift :: forall sh1 sh2. KnownShapeX sh2 + => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) + mlift f (M_Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ + mlift f arr + + memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) + memptyArray i + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ + memptyArray i + + mvecsNumElts (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = mvecsNumElts arr + + mvecsUnsafeNew idx (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = MV_Ranked <$> mvecsUnsafeNew idx arr + + mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () + mvecsWrite sh idx (Ranked arr) vecs + | Dict <- lemKnownReplicate (Proxy @n) + = mvecsWrite sh idx arr + (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' + => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a) + -> MixedVecs s (sh ++ sh') (Ranked n a) + -> ST s () + mvecsWritePartial sh idx arr vecs + | Dict <- lemKnownReplicate (Proxy @n) + = mvecsWritePartial sh idx + (coerce @(Mixed sh' (Ranked n a)) + @(Mixed sh' (Mixed (Replicate n Nothing) a)) + arr) + (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) + @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) + vecs) + + mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) + mvecsFreeze sh vecs + | Dict <- lemKnownReplicate (Proxy @n) + = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) + @(Mixed sh (Ranked n a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh (Ranked n a)) + @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) + vecs) + + +-- | The shape of a shape-typed array given as a list of 'SNat' values. +data SShape sh where + ShNil :: SShape '[] + ShCons :: SNat n -> SShape sh -> SShape (n : sh) +deriving instance Show (SShape sh) + +-- | A statically-known shape of a shape-typed array. +class KnownShape sh where knownShape :: SShape sh +instance KnownShape '[] where knownShape = ShNil +instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape + +lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) +lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) + where + go :: SShape sh' -> StaticShapeX (MapJust sh') + go ShNil = SZX + go (ShCons n sh) = n :$@ go sh + +lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 + -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 +lemMapJustPlusApp _ _ = go (knownShape @sh1) + where + go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 + go ShNil = Refl + go (ShCons _ sh) | Refl <- go sh = Refl + +instance (KnownShape sh, Elt a) => Elt (Shaped sh a) where + mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr + mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) + + mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a) + mindexPartial (M_Shaped arr) i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mindexPartial arr i + + mlift :: forall sh1 sh2. KnownShapeX sh2 + => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) + -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) + mlift f (M_Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ + mlift f arr + + memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a) + memptyArray i + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ + memptyArray i + + mvecsNumElts (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = mvecsNumElts arr + + mvecsUnsafeNew idx (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = MV_Shaped <$> mvecsUnsafeNew idx arr + + mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () + mvecsWrite sh idx (Shaped arr) vecs + | Dict <- lemKnownMapJust (Proxy @sh) + = mvecsWrite sh idx arr + (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 + => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a) + -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) + -> ST s () + mvecsWritePartial sh idx arr vecs + | Dict <- lemKnownMapJust (Proxy @sh) + = mvecsWritePartial sh idx + (coerce @(Mixed sh2 (Shaped sh a)) + @(Mixed sh2 (Mixed (MapJust sh) a)) + arr) + (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) + @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) + vecs) + + mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) + mvecsFreeze sh vecs + | Dict <- lemKnownMapJust (Proxy @sh) + = coerce @(Mixed sh' (Mixed (MapJust sh) a)) + @(Mixed sh' (Shaped sh a)) + <$> mvecsFreeze sh + (coerce @(MixedVecs s sh' (Shaped sh a)) + @(MixedVecs s sh' (Mixed (MapJust sh) a)) + vecs) + + +-- Utility function to satisfy the type checker sometimes +rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a +rewriteMixed Refl x = x + + +-- ====== API OF RANKED ARRAYS ====== -- + +-- | An index into a rank-typed array. +type IxR :: Nat -> Type +data IxR n where + IZR :: IxR Z + (:::) :: Int -> IxR n -> IxR (S n) + +ixCvtXR :: IxX sh -> IxR (X.Rank sh) +ixCvtXR IZX = IZR +ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx +ixCvtXR (n ::? idx) = n ::: ixCvtXR idx + +ixCvtRX :: IxR n -> IxX (Replicate n Nothing) +ixCvtRX IZR = IZX +ixCvtRX (n ::: idx) = n ::? ixCvtRX idx + + +rshape :: forall n a. (KnownNat n, Elt a) => Ranked n a -> IxR n +rshape (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemRankReplicate (Proxy @n) + = ixCvtXR (mshape arr) + +rindex :: Elt a => Ranked n a -> IxR n -> a +rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) + +rindexPartial :: forall n m a. (KnownNat n, Elt a) => Ranked (n + m) a -> IxR n -> Ranked m a +rindexPartial (Ranked arr) idx = + Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) + (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) + (ixCvtRX idx)) + +rgenerate :: forall n a. (KnownNat n, Elt a) => IxR n -> (IxR n -> a) -> Ranked n a +rgenerate sh f + | Dict <- lemKnownReplicate (Proxy @n) + , Refl <- lemRankReplicate (Proxy @n) + = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) + +rlift :: forall n1 n2 a. (KnownNat n2, Elt a) + => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) + -> Ranked n1 a -> Ranked n2 a +rlift f (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n2) + = Ranked (mlift f arr) + +rsumOuter1 :: forall n a. + (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) + => Ranked (S n) a -> Ranked n a +rsumOuter1 (Ranked arr) + | Dict <- lemKnownReplicate (Proxy @n) + = Ranked + . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a) + . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing)) + . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) + $ arr + + +-- ====== API OF SHAPED ARRAYS ====== -- + +-- | An index into a shape-typed array. +-- +-- For convenience, this contains regular 'Int's instead of bounded integers +-- (traditionally called \"@Fin@\"). Note that because the shape of a +-- shape-typed array is known statically, you can also retrieve the array shape +-- from a 'KnownShape' dictionary. +type IxS :: [Nat] -> Type +data IxS sh where + IZS :: IxS '[] + (::$) :: Int -> IxS sh -> IxS (n : sh) + +cvtSShapeIxS :: SShape sh -> IxS sh +cvtSShapeIxS ShNil = IZS +cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh + +ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh +ixCvtXS ShNil IZX = IZS +ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx + +ixCvtSX :: IxS sh -> IxX (MapJust sh) +ixCvtSX IZS = IZX +ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh + + +sshape :: forall sh a. (KnownShape sh, Elt a) => Shaped sh a -> IxS sh +sshape _ = cvtSShapeIxS (knownShape @sh) + +sindex :: Elt a => Shaped sh a -> IxS sh -> a +sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) + +sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, Elt a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a +sindexPartial (Shaped arr) idx = + Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) + (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) + (ixCvtSX idx)) + +sgenerate :: forall sh a. (KnownShape sh, Elt a) => IxS sh -> (IxS sh -> a) -> Shaped sh a +sgenerate sh f + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) + +slift :: forall sh1 sh2 a. (KnownShape sh2, Elt a) + => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) + -> Shaped sh1 a -> Shaped sh2 a +slift f (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh2) + = Shaped (mlift f arr) + +ssumOuter1 :: forall sh n a. + (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) + => Shaped (n : sh) a -> Shaped sh a +ssumOuter1 (Shaped arr) + | Dict <- lemKnownMapJust (Proxy @sh) + = Shaped + . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) + . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) + . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) + $ arr diff --git a/src/Data/Nat.hs b/src/Data/Nat.hs new file mode 100644 index 0000000..5dacc8a --- /dev/null +++ b/src/Data/Nat.hs @@ -0,0 +1,70 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} +module Data.Nat where + +import Data.Proxy +import Numeric.Natural +import qualified GHC.TypeLits as G + + +-- | Evidence for the constraint @c a@. +data Dict c a where + Dict :: c a => Dict c a + +-- | A peano natural number. Intended to be used at the type level. +data Nat = Z | S Nat + deriving (Show) + +-- | Singleton for a 'Nat'. +data SNat n where + SZ :: SNat Z + SS :: SNat n -> SNat (S n) +deriving instance Show (SNat n) + +-- | A singleton 'SNat' corresponding to @n@. +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + +-- | Convert a 'Nat' to a normal number. +unNat :: Nat -> Natural +unNat Z = 0 +unNat (S n) = 1 + unNat n + +-- | Convert an 'SNat' to a normal number. +unSNat :: SNat n -> Natural +unSNat SZ = 0 +unSNat (SS n) = 1 + unSNat n + +-- | A 'KnownNat' dictionary is just a singleton natural, so we can create +-- evidence of 'KnownNat' given an 'SNat'. +snatKnown :: SNat n -> Dict KnownNat n +snatKnown SZ = Dict +snatKnown (SS n) | Dict <- snatKnown n = Dict + +-- | Add two 'Nat's +type family n + m where + Z + m = m + S n + m = S (n + m) + +-- | Convert a 'Nat' to a "GHC.TypeLits" 'G.Nat'. +type family GNat n where + GNat Z = 0 + GNat (S n) = 1 G.+ GNat n + +-- | If an inductive 'Nat' is known, then the corresponding "GHC.TypeLits" +-- 'G.Nat' is also known. +gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n) +gknownNat (Proxy @n) = go (knownNat @n) + where + go :: SNat m -> Dict G.KnownNat (GNat m) + go SZ = Dict + go (SS n) | Dict <- go n = Dict diff --git a/src/Fancy.hs b/src/Fancy.hs deleted file mode 100644 index 7461c1f..0000000 --- a/src/Fancy.hs +++ /dev/null @@ -1,598 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE InstanceSigs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} - -{-| -TODO: -* This module needs better structure with an Internal module and less public - exports etc. - -* We should be more consistent in whether functions take a 'StaticShapeX' - argument or a 'KnownShapeX' constraint. - --} - -module Fancy where - -import Control.Monad (forM_) -import Control.Monad.ST -import Data.Coerce (coerce, Coercible) -import Data.Kind -import Data.Proxy -import Data.Type.Equality -import qualified Data.Vector.Unboxed as VU -import qualified Data.Vector.Unboxed.Mutable as VUM - -import Array (XArray, IxX(..), KnownShapeX(..), StaticShapeX(..), type (++)) -import qualified Array as X -import Nats - - -type family Replicate n a where - Replicate Z a = '[] - Replicate (S n) a = a : Replicate n a - -type family MapJust l where - MapJust '[] = '[] - MapJust (x : xs) = Just x : MapJust xs - -lemKnownReplicate :: forall n. KnownNat n => Proxy n -> Dict KnownShapeX (Replicate n Nothing) -lemKnownReplicate _ = X.lemKnownShapeX (go (knownNat @n)) - where - go :: SNat m -> StaticShapeX (Replicate m Nothing) - go SZ = SZX - go (SS n) = () :$? go n - -lemRankReplicate :: forall n. KnownNat n => Proxy n -> X.Rank (Replicate n (Nothing @Nat)) :~: n -lemRankReplicate _ = go (knownNat @n) - where - go :: SNat m -> X.Rank (Replicate m (Nothing @Nat)) :~: m - go SZ = Refl - go (SS n) | Refl <- go n = Refl - -lemReplicatePlusApp :: forall n m a. KnownNat n => Proxy n -> Proxy m -> Proxy a - -> Replicate (n + m) a :~: Replicate n a ++ Replicate m a -lemReplicatePlusApp _ _ _ = go (knownNat @n) - where - go :: SNat n' -> Replicate (n' + m) a :~: Replicate n' a ++ Replicate m a - go SZ = Refl - go (SS n) | Refl <- go n = Refl - - --- | Wrapper type used as a tag to attach instances on. The instances on arrays --- of @'Primitive' a@ are more polymorphic than the direct instances for arrays --- of scalars; this means that if @orthotope@ supports an element type @T@ that --- this library does not (directly), it may just work if you use an array of --- @'Primitive' T@ instead. -newtype Primitive a = Primitive a - - --- | Mixed arrays: some dimensions are size-typed, some are not. Distributes --- over product-typed elements using a dat afamily so that the full array is --- always in struct-of-arrays format. --- --- Built on top of 'XArray' which is built on top of @orthotope@, meaning that --- dimension permutations (e.g. 'transpose') are typically free. -type Mixed :: [Maybe Nat] -> Type -> Type -data family Mixed sh a - -newtype instance Mixed sh (Primitive a) = M_Primitive (XArray sh a) - -newtype instance Mixed sh Int = M_Int (XArray sh Int) -newtype instance Mixed sh Double = M_Double (XArray sh Double) -newtype instance Mixed sh () = M_Nil (XArray sh ()) -- no content, orthotope optimises this (via Vector) --- etc. - -data instance Mixed sh (a, b) = M_Tup2 (Mixed sh a) (Mixed sh b) --- etc. - -newtype instance Mixed sh1 (Mixed sh2 a) = M_Nest (Mixed (sh1 ++ sh2) a) - - --- | Internal helper data family mirrorring 'Mixed' that consists of mutable --- vectors instead of 'XArray's. -type MixedVecs :: Type -> [Maybe Nat] -> Type -> Type -data family MixedVecs s sh a - -newtype instance MixedVecs s sh (Primitive a) = MV_Primitive (VU.MVector s a) - -newtype instance MixedVecs s sh Int = MV_Int (VU.MVector s Int) -newtype instance MixedVecs s sh Double = MV_Double (VU.MVector s Double) -newtype instance MixedVecs s sh () = MV_Nil (VU.MVector s ()) -- no content, MVector optimises this --- etc. - -data instance MixedVecs s sh (a, b) = MV_Tup2 (MixedVecs s sh a) (MixedVecs s sh b) --- etc. - -data instance MixedVecs s sh1 (Mixed sh2 a) = MV_Nest (IxX sh2) (MixedVecs s (sh1 ++ sh2) a) - - --- | Allowable scalar types in a mixed array, and by extension in a 'Ranked' or --- 'Shaped' array. Note the polymorphic instance for 'GMixed' of @'Primitive' --- a@; see the documentation for 'Primitive' for more details. -class GMixed a where - -- ====== PUBLIC METHODS ====== -- - - mshape :: KnownShapeX sh => Mixed sh a -> IxX sh - mindex :: Mixed sh a -> IxX sh -> a - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') a -> IxX sh -> Mixed sh' a - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 a -> Mixed sh2 a - - -- ====== PRIVATE METHODS ====== -- - -- Remember I said that this module needed better management of exports? - - -- | Create an empty array. The given shape must have size zero; this may or may not be checked. - memptyArray :: IxX sh -> Mixed sh a - - -- | Return the size of the individual (SoA) arrays in this value. If @a@ - -- does not contain tuples, this coincides with the total number of scalars - -- in the given value; if @a@ contains tuples, then it is some multiple of - -- this number of scalars. - mvecsNumElts :: a -> Int - - -- | Create uninitialised vectors for this array type, given the shape of - -- this vector and an example for the contents. The shape must not have size - -- zero; an error may be thrown otherwise. - mvecsUnsafeNew :: IxX sh -> a -> ST s (MixedVecs s sh a) - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWrite :: IxX sh -> IxX sh -> a -> MixedVecs s sh a -> ST s () - - -- | Given the shape of this array, an index and a value, write the value at - -- that index in the vectors. - mvecsWritePartial :: KnownShapeX sh' => IxX (sh ++ sh') -> IxX sh -> Mixed sh' a -> MixedVecs s (sh ++ sh') a -> ST s () - - -- | Given the shape of this array, finalise the vectors into 'XArray's. - mvecsFreeze :: IxX sh -> MixedVecs s sh a -> ST s (Mixed sh a) - - --- Arrays of scalars are basically just arrays of scalars. -instance VU.Unbox a => GMixed (Primitive a) where - mshape (M_Primitive a) = X.shape a - mindex (M_Primitive a) i = Primitive (X.index a i) - mindexPartial (M_Primitive a) i = M_Primitive (X.indexPartial a i) - - mlift :: forall sh1 sh2. - (Proxy '[] -> XArray (sh1 ++ '[]) a -> XArray (sh2 ++ '[]) a) - -> Mixed sh1 (Primitive a) -> Mixed sh2 (Primitive a) - mlift f (M_Primitive a) - | Refl <- X.lemAppNil @sh1 - , Refl <- X.lemAppNil @sh2 - = M_Primitive (f Proxy a) - - memptyArray sh = M_Primitive (X.generate sh (error "memptyArray Int: shape was not empty")) - mvecsNumElts _ = 1 - mvecsUnsafeNew sh _ = MV_Primitive <$> VUM.unsafeNew (X.shapeSize sh) - mvecsWrite sh i (Primitive x) (MV_Primitive v) = VUM.write v (X.toLinearIdx sh i) x - - -- TODO: this use of toVector is suboptimal - mvecsWritePartial - :: forall sh' sh s. (KnownShapeX sh', VU.Unbox a) - => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Primitive a) -> MixedVecs s (sh ++ sh') (Primitive a) -> ST s () - mvecsWritePartial sh i (M_Primitive arr) (MV_Primitive v) = do - let offset = X.toLinearIdx sh (X.ixAppend i (X.zeroIdx' (X.shape arr))) - VU.copy (VUM.slice offset (X.shapeSize (X.shape arr)) v) (X.toVector arr) - - mvecsFreeze sh (MV_Primitive v) = M_Primitive . X.fromVector sh <$> VU.freeze v - --- What a blessing that orthotope's Array has "representational" role on the value type! -deriving via Primitive Int instance GMixed Int -deriving via Primitive Double instance GMixed Double -deriving via Primitive () instance GMixed () - --- Arrays of pairs are pairs of arrays. -instance (GMixed a, GMixed b) => GMixed (a, b) where - mshape (M_Tup2 a _) = mshape a - mindex (M_Tup2 a b) i = (mindex a i, mindex b i) - mindexPartial (M_Tup2 a b) i = M_Tup2 (mindexPartial a i) (mindexPartial b i) - mlift f (M_Tup2 a b) = M_Tup2 (mlift f a) (mlift f b) - - memptyArray sh = M_Tup2 (memptyArray sh) (memptyArray sh) - mvecsNumElts (x, y) = mvecsNumElts x * mvecsNumElts y - mvecsUnsafeNew sh (x, y) = MV_Tup2 <$> mvecsUnsafeNew sh x <*> mvecsUnsafeNew sh y - mvecsWrite sh i (x, y) (MV_Tup2 a b) = do - mvecsWrite sh i x a - mvecsWrite sh i y b - mvecsWritePartial sh i (M_Tup2 x y) (MV_Tup2 a b) = do - mvecsWritePartial sh i x a - mvecsWritePartial sh i y b - mvecsFreeze sh (MV_Tup2 a b) = M_Tup2 <$> mvecsFreeze sh a <*> mvecsFreeze sh b - --- Arrays of arrays are just arrays, but with more dimensions. -instance (GMixed a, KnownShapeX sh') => GMixed (Mixed sh' a) where - mshape :: forall sh. KnownShapeX sh => Mixed sh (Mixed sh' a) -> IxX sh - mshape (M_Nest arr) - | Dict <- X.lemAppKnownShapeX (knownShapeX @sh) (knownShapeX @sh') - = ixAppPrefix (knownShapeX @sh) (mshape arr) - where - ixAppPrefix :: StaticShapeX sh1 -> IxX (sh1 ++ sh') -> IxX sh1 - ixAppPrefix SZX _ = IZX - ixAppPrefix (_ :$@ ssh) (i ::@ idx) = i ::@ ixAppPrefix ssh idx - ixAppPrefix (_ :$? ssh) (i ::? idx) = i ::? ixAppPrefix ssh idx - - mindex (M_Nest arr) i = mindexPartial arr i - - mindexPartial :: forall sh1 sh2. - Mixed (sh1 ++ sh2) (Mixed sh' a) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) - mindexPartial (M_Nest arr) i - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = M_Nest (mindexPartial @a @sh1 @(sh2 ++ sh') arr i) - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray (sh1 ++ sh3) b -> XArray (sh2 ++ sh3) b) - -> Mixed sh1 (Mixed sh' a) -> Mixed sh2 (Mixed sh' a) - mlift f (M_Nest arr) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - = M_Nest (mlift f' arr) - where - f' :: forall sh3 b. (KnownShapeX sh3, VU.Unbox b) => Proxy sh3 -> XArray ((sh1 ++ sh') ++ sh3) b -> XArray ((sh2 ++ sh') ++ sh3) b - f' _ - | Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh') (Proxy @sh3) - , Refl <- X.lemAppAssoc (Proxy @sh2) (Proxy @sh') (Proxy @sh3) - , Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh') (knownShapeX @sh3)) - = f (Proxy @(sh' ++ sh3)) - - memptyArray sh = M_Nest (memptyArray (X.ixAppend sh (X.zeroIdx (knownShapeX @sh')))) - - mvecsNumElts arr = - let n = X.shapeSize (mshape arr) - in if n == 0 then 0 else n * mvecsNumElts (mindex arr (X.zeroIdx (knownShapeX @sh'))) - - mvecsUnsafeNew sh example - | X.shapeSize sh' == 0 = error "mvecsUnsafeNew: empty example" - | otherwise = MV_Nest sh' <$> mvecsUnsafeNew (X.ixAppend sh (mshape example)) - (mindex example (X.zeroIdx (knownShapeX @sh'))) - where - sh' = mshape example - - mvecsWrite sh idx val (MV_Nest sh' vecs) = mvecsWritePartial (X.ixAppend sh sh') idx val vecs - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Mixed sh' a) - -> MixedVecs s (sh1 ++ sh2) (Mixed sh' a) - -> ST s () - mvecsWritePartial sh12 idx (M_Nest arr) (MV_Nest sh' vecs) - | Dict <- X.lemKnownShapeX (X.ssxAppend (knownShapeX @sh2) (knownShapeX @sh')) - , Refl <- X.lemAppAssoc (Proxy @sh1) (Proxy @sh2) (Proxy @sh') - = mvecsWritePartial @a @(sh2 ++ sh') @sh1 (X.ixAppend sh12 sh') idx arr vecs - - mvecsFreeze sh (MV_Nest sh' vecs) = M_Nest <$> mvecsFreeze (X.ixAppend sh sh') vecs - - --- Public method. Turns out this doesn't have to be in the type class! --- | Create an array given a size and a function that computes the element at a --- given index. -mgenerate :: forall sh a. (KnownShapeX sh, GMixed a) => IxX sh -> (IxX sh -> a) -> Mixed sh a -mgenerate sh f - -- TODO: Do we need this checkBounds check elsewhere as well? - | not (checkBounds sh (knownShapeX @sh)) = - error $ "mgenerate: Shape " ++ show sh ++ " not valid for shape type " ++ show (knownShapeX @sh) - -- We need to be very careful here to ensure that neither 'sh' nor - -- 'firstelem' that we pass to 'mvecsUnsafeNew' are empty. - | X.shapeSize sh == 0 = memptyArray sh - | otherwise = - let firstidx = X.zeroIdx' sh - firstelem = f (X.zeroIdx' sh) - in if mvecsNumElts firstelem == 0 - then memptyArray sh - else runST $ do - vecs <- mvecsUnsafeNew sh firstelem - mvecsWrite sh firstidx firstelem vecs - -- TODO: This is likely fine if @a@ is big, but if @a@ is a - -- scalar this feels inefficient. Should improve this. - forM_ (tail (X.enumShape sh)) $ \idx -> - mvecsWrite sh idx (f idx) vecs - mvecsFreeze sh vecs - where - checkBounds :: IxX sh' -> StaticShapeX sh' -> Bool - checkBounds IZX SZX = True - checkBounds (n ::@ sh') (n' :$@ ssh') = n == fromIntegral (unSNat n') && checkBounds sh' ssh' - checkBounds (_ ::? sh') (() :$? ssh') = checkBounds sh' ssh' - - --- | Newtype around a 'Mixed' of 'Nothing's. This works like a rank-typed array --- as in @orthotope@. -type Ranked :: Nat -> Type -> Type -newtype Ranked n a = Ranked (Mixed (Replicate n Nothing) a) - --- | Newtype around a 'Mixed' of 'Just's. This works like a shape-typed array --- as in @orthotope@. -type Shaped :: [Nat] -> Type -> Type -newtype Shaped sh a = Shaped (Mixed (MapJust sh) a) - --- just unwrap the newtype and defer to the general instance for nested arrays -newtype instance Mixed sh (Ranked n a) = M_Ranked (Mixed sh (Mixed (Replicate n Nothing) a)) -newtype instance Mixed sh (Shaped sh' a) = M_Shaped (Mixed sh (Mixed (MapJust sh' ) a)) - -newtype instance MixedVecs s sh (Ranked n a) = MV_Ranked (MixedVecs s sh (Mixed (Replicate n Nothing) a)) -newtype instance MixedVecs s sh (Shaped sh' a) = MV_Shaped (MixedVecs s sh (Mixed (MapJust sh' ) a)) - - --- 'Ranked' and 'Shaped' can already be used at the top level of an array nest; --- these instances allow them to also be used as elements of arrays, thus --- making them first-class in the API. -instance (KnownNat n, GMixed a) => GMixed (Ranked n a) where - mshape (M_Ranked arr) | Dict <- lemKnownReplicate (Proxy @n) = mshape arr - mindex (M_Ranked arr) i | Dict <- lemKnownReplicate (Proxy @n) = Ranked (mindex arr i) - - mindexPartial :: forall sh sh'. Mixed (sh ++ sh') (Ranked n a) -> IxX sh -> Mixed sh' (Ranked n a) - mindexPartial (M_Ranked arr) i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh' (Mixed (Replicate n Nothing) a)) @(Mixed sh' (Ranked n a)) $ - mindexPartial arr i - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Ranked n a) -> Mixed sh2 (Ranked n a) - mlift f (M_Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh2 (Mixed (Replicate n Nothing) a)) @(Mixed sh2 (Ranked n a)) $ - mlift f arr - - memptyArray :: forall sh. IxX sh -> Mixed sh (Ranked n a) - memptyArray i - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) @(Mixed sh (Ranked n a)) $ - memptyArray i - - mvecsNumElts (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsNumElts arr - - mvecsUnsafeNew idx (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = MV_Ranked <$> mvecsUnsafeNew idx arr - - mvecsWrite :: forall sh s. IxX sh -> IxX sh -> Ranked n a -> MixedVecs s sh (Ranked n a) -> ST s () - mvecsWrite sh idx (Ranked arr) vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh (Ranked n a)) @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsWritePartial :: forall sh sh' s. KnownShapeX sh' - => IxX (sh ++ sh') -> IxX sh -> Mixed sh' (Ranked n a) - -> MixedVecs s (sh ++ sh') (Ranked n a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownReplicate (Proxy @n) - = mvecsWritePartial sh idx - (coerce @(Mixed sh' (Ranked n a)) - @(Mixed sh' (Mixed (Replicate n Nothing) a)) - arr) - (coerce @(MixedVecs s (sh ++ sh') (Ranked n a)) - @(MixedVecs s (sh ++ sh') (Mixed (Replicate n Nothing) a)) - vecs) - - mvecsFreeze :: forall sh s. IxX sh -> MixedVecs s sh (Ranked n a) -> ST s (Mixed sh (Ranked n a)) - mvecsFreeze sh vecs - | Dict <- lemKnownReplicate (Proxy @n) - = coerce @(Mixed sh (Mixed (Replicate n Nothing) a)) - @(Mixed sh (Ranked n a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh (Ranked n a)) - @(MixedVecs s sh (Mixed (Replicate n Nothing) a)) - vecs) - - -data SShape sh where - ShNil :: SShape '[] - ShCons :: SNat n -> SShape sh -> SShape (n : sh) -deriving instance Show (SShape sh) - -class KnownShape sh where knownShape :: SShape sh -instance KnownShape '[] where knownShape = ShNil -instance (KnownNat n, KnownShape sh) => KnownShape (n : sh) where knownShape = ShCons knownNat knownShape - -lemKnownMapJust :: forall sh. KnownShape sh => Proxy sh -> Dict KnownShapeX (MapJust sh) -lemKnownMapJust _ = X.lemKnownShapeX (go (knownShape @sh)) - where - go :: SShape sh' -> StaticShapeX (MapJust sh') - go ShNil = SZX - go (ShCons n sh) = n :$@ go sh - -lemMapJustPlusApp :: forall sh1 sh2. KnownShape sh1 => Proxy sh1 -> Proxy sh2 - -> MapJust (sh1 ++ sh2) :~: MapJust sh1 ++ MapJust sh2 -lemMapJustPlusApp _ _ = go (knownShape @sh1) - where - go :: SShape sh1' -> MapJust (sh1' ++ sh2) :~: MapJust sh1' ++ MapJust sh2 - go ShNil = Refl - go (ShCons _ sh) | Refl <- go sh = Refl - -instance (KnownShape sh, GMixed a) => GMixed (Shaped sh a) where - mshape (M_Shaped arr) | Dict <- lemKnownMapJust (Proxy @sh) = mshape arr - mindex (M_Shaped arr) i | Dict <- lemKnownMapJust (Proxy @sh) = Shaped (mindex arr i) - - mindexPartial :: forall sh1 sh2. Mixed (sh1 ++ sh2) (Shaped sh a) -> IxX sh1 -> Mixed sh2 (Shaped sh a) - mindexPartial (M_Shaped arr) i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mindexPartial arr i - - mlift :: forall sh1 sh2. KnownShapeX sh2 - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (sh1 ++ sh') b -> XArray (sh2 ++ sh') b) - -> Mixed sh1 (Shaped sh a) -> Mixed sh2 (Shaped sh a) - mlift f (M_Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh2 (Mixed (MapJust sh) a)) @(Mixed sh2 (Shaped sh a)) $ - mlift f arr - - memptyArray :: forall sh'. IxX sh' -> Mixed sh' (Shaped sh a) - memptyArray i - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) @(Mixed sh' (Shaped sh a)) $ - memptyArray i - - mvecsNumElts (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsNumElts arr - - mvecsUnsafeNew idx (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = MV_Shaped <$> mvecsUnsafeNew idx arr - - mvecsWrite :: forall sh' s. IxX sh' -> IxX sh' -> Shaped sh a -> MixedVecs s sh' (Shaped sh a) -> ST s () - mvecsWrite sh idx (Shaped arr) vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWrite sh idx arr - (coerce @(MixedVecs s sh' (Shaped sh a)) @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - mvecsWritePartial :: forall sh1 sh2 s. KnownShapeX sh2 - => IxX (sh1 ++ sh2) -> IxX sh1 -> Mixed sh2 (Shaped sh a) - -> MixedVecs s (sh1 ++ sh2) (Shaped sh a) - -> ST s () - mvecsWritePartial sh idx arr vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = mvecsWritePartial sh idx - (coerce @(Mixed sh2 (Shaped sh a)) - @(Mixed sh2 (Mixed (MapJust sh) a)) - arr) - (coerce @(MixedVecs s (sh1 ++ sh2) (Shaped sh a)) - @(MixedVecs s (sh1 ++ sh2) (Mixed (MapJust sh) a)) - vecs) - - mvecsFreeze :: forall sh' s. IxX sh' -> MixedVecs s sh' (Shaped sh a) -> ST s (Mixed sh' (Shaped sh a)) - mvecsFreeze sh vecs - | Dict <- lemKnownMapJust (Proxy @sh) - = coerce @(Mixed sh' (Mixed (MapJust sh) a)) - @(Mixed sh' (Shaped sh a)) - <$> mvecsFreeze sh - (coerce @(MixedVecs s sh' (Shaped sh a)) - @(MixedVecs s sh' (Mixed (MapJust sh) a)) - vecs) - - --- Utility function to satisfy the type checker sometimes -rewriteMixed :: sh1 :~: sh2 -> Mixed sh1 a -> Mixed sh2 a -rewriteMixed Refl x = x - - --- ====== API OF RANKED ARRAYS ====== -- - --- | An index into a rank-typed array. -type IxR :: Nat -> Type -data IxR n where - IZR :: IxR Z - (:::) :: Int -> IxR n -> IxR (S n) - -ixCvtXR :: IxX sh -> IxR (X.Rank sh) -ixCvtXR IZX = IZR -ixCvtXR (n ::@ idx) = n ::: ixCvtXR idx -ixCvtXR (n ::? idx) = n ::: ixCvtXR idx - -ixCvtRX :: IxR n -> IxX (Replicate n Nothing) -ixCvtRX IZR = IZX -ixCvtRX (n ::: idx) = n ::? ixCvtRX idx - - -rshape :: forall n a. (KnownNat n, GMixed a) => Ranked n a -> IxR n -rshape (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = ixCvtXR (mshape arr) - -rindex :: GMixed a => Ranked n a -> IxR n -> a -rindex (Ranked arr) idx = mindex arr (ixCvtRX idx) - -rindexPartial :: forall n m a. (KnownNat n, GMixed a) => Ranked (n + m) a -> IxR n -> Ranked m a -rindexPartial (Ranked arr) idx = - Ranked (mindexPartial @a @(Replicate n Nothing) @(Replicate m Nothing) - (rewriteMixed (lemReplicatePlusApp (Proxy @n) (Proxy @m) (Proxy @Nothing)) arr) - (ixCvtRX idx)) - -rgenerate :: forall n a. (KnownNat n, GMixed a) => IxR n -> (IxR n -> a) -> Ranked n a -rgenerate sh f - | Dict <- lemKnownReplicate (Proxy @n) - , Refl <- lemRankReplicate (Proxy @n) - = Ranked (mgenerate (ixCvtRX sh) (f . ixCvtXR)) - -rlift :: forall n1 n2 a. (KnownNat n2, GMixed a) - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (Replicate n1 Nothing ++ sh') b -> XArray (Replicate n2 Nothing ++ sh') b) - -> Ranked n1 a -> Ranked n2 a -rlift f (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n2) - = Ranked (mlift f arr) - -rsumOuter1 :: forall n a. - (VU.Unbox a, Num a, KnownNat n, forall sh. Coercible (Mixed sh a) (XArray sh a)) - => Ranked (S n) a -> Ranked n a -rsumOuter1 (Ranked arr) - | Dict <- lemKnownReplicate (Proxy @n) - = Ranked - . coerce @(XArray (Replicate n Nothing) a) @(Mixed (Replicate n Nothing) a) - . X.sumOuter (() :$? SZX) (knownShapeX @(Replicate n Nothing)) - . coerce @(Mixed (Replicate (S n) Nothing) a) @(XArray (Replicate (S n) Nothing) a) - $ arr - - --- ====== API OF SHAPED ARRAYS ====== -- - --- | An index into a shape-typed array. -type IxS :: [Nat] -> Type -data IxS sh where - IZS :: IxS '[] - (::$) :: Int -> IxS sh -> IxS (n : sh) - -cvtSShapeIxS :: SShape sh -> IxS sh -cvtSShapeIxS ShNil = IZS -cvtSShapeIxS (ShCons n sh) = fromIntegral (unSNat n) ::$ cvtSShapeIxS sh - -ixCvtXS :: SShape sh -> IxX (MapJust sh) -> IxS sh -ixCvtXS ShNil IZX = IZS -ixCvtXS (ShCons _ sh) (n ::@ idx) = n ::$ ixCvtXS sh idx - -ixCvtSX :: IxS sh -> IxX (MapJust sh) -ixCvtSX IZS = IZX -ixCvtSX (n ::$ sh) = n ::@ ixCvtSX sh - - -sshape :: forall sh a. (KnownShape sh, GMixed a) => Shaped sh a -> IxS sh -sshape _ = cvtSShapeIxS (knownShape @sh) - -sindex :: GMixed a => Shaped sh a -> IxS sh -> a -sindex (Shaped arr) idx = mindex arr (ixCvtSX idx) - -sindexPartial :: forall sh1 sh2 a. (KnownShape sh1, GMixed a) => Shaped (sh1 ++ sh2) a -> IxS sh1 -> Shaped sh2 a -sindexPartial (Shaped arr) idx = - Shaped (mindexPartial @a @(MapJust sh1) @(MapJust sh2) - (rewriteMixed (lemMapJustPlusApp (Proxy @sh1) (Proxy @sh2)) arr) - (ixCvtSX idx)) - -sgenerate :: forall sh a. (KnownShape sh, GMixed a) => IxS sh -> (IxS sh -> a) -> Shaped sh a -sgenerate sh f - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped (mgenerate (ixCvtSX sh) (f . ixCvtXS (knownShape @sh))) - -slift :: forall sh1 sh2 a. (KnownShape sh2, GMixed a) - => (forall sh' b. (KnownShapeX sh', VU.Unbox b) => Proxy sh' -> XArray (MapJust sh1 ++ sh') b -> XArray (MapJust sh2 ++ sh') b) - -> Shaped sh1 a -> Shaped sh2 a -slift f (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh2) - = Shaped (mlift f arr) - -ssumOuter1 :: forall sh n a. - (VU.Unbox a, Num a, KnownNat n, KnownShape sh, forall sh'. Coercible (Mixed sh' a) (XArray sh' a)) - => Shaped (n : sh) a -> Shaped sh a -ssumOuter1 (Shaped arr) - | Dict <- lemKnownMapJust (Proxy @sh) - = Shaped - . coerce @(XArray (MapJust sh) a) @(Mixed (MapJust sh) a) - . X.sumOuter (knownNat @n :$@ SZX) (knownShapeX @(MapJust sh)) - . coerce @(Mixed (Just n : MapJust sh) a) @(XArray (Just n : MapJust sh) a) - $ arr diff --git a/src/Nats.hs b/src/Nats.hs deleted file mode 100644 index fdc090e..0000000 --- a/src/Nats.hs +++ /dev/null @@ -1,58 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} -module Nats where - -import Data.Proxy -import Numeric.Natural -import qualified GHC.TypeLits as G - - -data Dict c a where - Dict :: c a => Dict c a - -data Nat = Z | S Nat - deriving (Show) - -data SNat n where - SZ :: SNat Z - SS :: SNat n -> SNat (S n) -deriving instance Show (SNat n) - -class KnownNat n where knownNat :: SNat n -instance KnownNat Z where knownNat = SZ -instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat - -unSNat :: SNat n -> Natural -unSNat SZ = 0 -unSNat (SS n) = 1 + unSNat n - -unNat :: Nat -> Natural -unNat Z = 0 -unNat (S n) = 1 + unNat n - -snatKnown :: SNat n -> Dict KnownNat n -snatKnown SZ = Dict -snatKnown (SS n) | Dict <- snatKnown n = Dict - -type family n + m where - Z + m = m - S n + m = S (n + m) - -type family GNat n where - GNat Z = 0 - GNat (S n) = 1 G.+ GNat n - -gknownNat :: KnownNat n => Proxy n -> Dict G.KnownNat (GNat n) -gknownNat (Proxy @n) = go (knownNat @n) - where - go :: SNat m -> Dict G.KnownNat (GNat m) - go SZ = Dict - go (SS n) | Dict <- go n = Dict -- cgit v1.2.3-70-g09d2