diff options
Diffstat (limited to 'src/Data/Array')
-rw-r--r-- | src/Data/Array/Mixed.hs | 42 |
1 files changed, 21 insertions, 21 deletions
diff --git a/src/Data/Array/Mixed.hs b/src/Data/Array/Mixed.hs index 21293dc..7f25d84 100644 --- a/src/Data/Array/Mixed.hs +++ b/src/Data/Array/Mixed.hs @@ -173,15 +173,15 @@ lemRankAppComm :: StaticShapeX sh1 -> StaticShapeX sh2 -> FromINat (Rank (sh1 ++ sh2)) :~: FromINat (Rank (sh2 ++ sh1)) lemRankAppComm _ _ = unsafeCoerce Refl -- TODO improve this -lemKnownNatRank :: IxX sh -> Dict KnownINat (Rank sh) -lemKnownNatRank IZX = Dict -lemKnownNatRank (_ ::@ sh) | Dict <- lemKnownNatRank sh = Dict -lemKnownNatRank (_ ::? sh) | Dict <- lemKnownNatRank sh = Dict +lemKnownINatRank :: IxX sh -> Dict KnownINat (Rank sh) +lemKnownINatRank IZX = Dict +lemKnownINatRank (_ ::@ sh) | Dict <- lemKnownINatRank sh = Dict +lemKnownINatRank (_ ::? sh) | Dict <- lemKnownINatRank sh = Dict -lemKnownNatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh) -lemKnownNatRankSSX SZX = Dict -lemKnownNatRankSSX (_ :$@ ssh) | Dict <- lemKnownNatRankSSX ssh = Dict -lemKnownNatRankSSX (_ :$? ssh) | Dict <- lemKnownNatRankSSX ssh = Dict +lemKnownINatRankSSX :: StaticShapeX sh -> Dict KnownINat (Rank sh) +lemKnownINatRankSSX SZX = Dict +lemKnownINatRankSSX (_ :$@ ssh) | Dict <- lemKnownINatRankSSX ssh = Dict +lemKnownINatRankSSX (_ :$? ssh) | Dict <- lemKnownINatRankSSX ssh = Dict lemKnownShapeX :: StaticShapeX sh -> Dict KnownShapeX sh lemKnownShapeX SZX = Dict @@ -208,7 +208,7 @@ shape (XArray arr) = go (knownShapeX @sh) (S.shapeL arr) fromVector :: forall sh a. Storable a => IxX sh -> VS.Vector a -> XArray sh a fromVector sh v - | Dict <- lemKnownNatRank sh + | Dict <- lemKnownINatRank sh , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.fromVector (shapeLshape sh) v) @@ -225,7 +225,7 @@ generate :: Storable a => IxX sh -> (IxX sh -> a) -> XArray sh a generate sh f = fromVector sh $ VS.generate (shapeSize sh) (f . fromLinearIdx sh) -- generateM :: (Monad m, Storable a) => IxX sh -> (IxX sh -> m a) -> m (XArray sh a) --- generateM sh f | Dict <- lemKnownNatRank sh = +-- generateM sh f | Dict <- lemKnownINatRank sh = -- XArray . S.fromVector (shapeLshape sh) -- <$> VS.generateM (shapeSize sh) (f . fromLinearIdx sh) @@ -242,7 +242,7 @@ index xarr i append :: forall sh a. (KnownShapeX sh, Storable a) => XArray sh a -> XArray sh a -> XArray sh a append (XArray a) (XArray b) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + | Dict <- lemKnownINatRankSSX (knownShapeX @sh) , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.append a b) @@ -252,14 +252,14 @@ rerank :: forall sh sh1 sh2 a b. -> (XArray sh1 a -> XArray sh2 b) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh2) b rerank ssh ssh1 ssh2 f (XArray arr) - | Dict <- lemKnownNatRankSSX ssh + | Dict <- lemKnownINatRankSSX ssh , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- lemKnownINatRankSSX ssh2 , Dict <- knownNatFromINat (Proxy @(Rank sh2)) , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the - , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough + , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough = XArray (S.rerank @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) (\a -> unXArray (f (XArray a))) arr) @@ -279,13 +279,13 @@ rerank2 :: forall sh sh1 sh2 a b c. -> (XArray sh1 a -> XArray sh1 b -> XArray sh2 c) -> XArray (sh ++ sh1) a -> XArray (sh ++ sh1) b -> XArray (sh ++ sh2) c rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) - | Dict <- lemKnownNatRankSSX ssh + | Dict <- lemKnownINatRankSSX ssh , Dict <- knownNatFromINat (Proxy @(Rank sh)) - , Dict <- lemKnownNatRankSSX ssh2 + , Dict <- lemKnownINatRankSSX ssh2 , Dict <- knownNatFromINat (Proxy @(Rank sh2)) , Refl <- lemRankApp ssh ssh1 , Refl <- lemRankApp ssh ssh2 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the + , Dict <- lemKnownINatRankSSX (ssxAppend ssh ssh2) -- these two should be redundant but the , Dict <- knownNatFromINat (Proxy @(Rank (sh ++ sh2))) -- solver is not clever enough = XArray (S.rerank2 @(FromINat (Rank sh)) @(FromINat (Rank sh1)) @(FromINat (Rank sh2)) (\a b -> unXArray (f (XArray a) (XArray b))) @@ -296,7 +296,7 @@ rerank2 ssh ssh1 ssh2 f (XArray arr1) (XArray arr2) -- | The list argument gives indices into the original dimension list. transpose :: forall sh a. KnownShapeX sh => [Int] -> XArray sh a -> XArray sh a transpose perm (XArray arr) - | Dict <- lemKnownNatRankSSX (knownShapeX @sh) + | Dict <- lemKnownINatRankSSX (knownShapeX @sh) , Dict <- knownNatFromINat (Proxy @(Rank sh)) = XArray (S.transpose perm arr) @@ -306,9 +306,9 @@ transpose2 :: forall sh1 sh2 a. transpose2 ssh1 ssh2 (XArray arr) | Refl <- lemRankApp ssh1 ssh2 , Refl <- lemRankApp ssh2 ssh1 - , Dict <- lemKnownNatRankSSX (ssxAppend ssh1 ssh2) + , Dict <- lemKnownINatRankSSX (ssxAppend ssh1 ssh2) , Dict <- knownNatFromINat (Proxy @(Rank (sh1 ++ sh2))) - , Dict <- lemKnownNatRankSSX (ssxAppend ssh2 ssh1) + , Dict <- lemKnownINatRankSSX (ssxAppend ssh2 ssh1) , Dict <- knownNatFromINat (Proxy @(Rank (sh2 ++ sh1))) , Refl <- lemRankAppComm ssh1 ssh2 , let n1 = ssxLength ssh1 |