aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Data/Array/Mixed/Internal/Arith.hs158
-rw-r--r--src/Data/Array/Mixed/Internal/Arith/Foreign.hs30
-rw-r--r--src/Data/Array/Mixed/Permutation.hs2
-rw-r--r--src/Data/Array/Mixed/Shape.hs63
-rw-r--r--src/Data/Array/Nested/Internal/Shape.hs61
5 files changed, 226 insertions, 88 deletions
diff --git a/src/Data/Array/Mixed/Internal/Arith.hs b/src/Data/Array/Mixed/Internal/Arith.hs
index 11ee3fe..6253ae0 100644
--- a/src/Data/Array/Mixed/Internal/Arith.hs
+++ b/src/Data/Array/Mixed/Internal/Arith.hs
@@ -67,7 +67,7 @@ liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides
VS.unsafeWith (VS.singleton (fromIntegral blockSz)) $ \psh ->
VS.unsafeWith (VS.singleton 1) $ \pstrides ->
VS.unsafeWith (VS.slice blockOff blockSz vec) $ \pv ->
- cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
+ cf_strided 1 (ptrconv2 poutv) psh pstrides (ptrconv1 pv)
RS.A . RG.A sh . OI.T strides (offset - blockOff) <$> VS.unsafeFreeze outv
| otherwise = unsafePerformIO $ do
outv <- VSM.unsafeNew (product sh)
@@ -79,33 +79,52 @@ liftOpEltwise1 sn@SNat ptrconv1 ptrconv2 cf_strided (RS.A (RG.A sh (OI.T strides
RS.fromVector sh <$> VS.unsafeFreeze outv
-- TODO: test all the cases of this thing with various input strides
-liftVEltwise2 :: (Storable a, Storable b, Storable c)
+liftVEltwise2 :: Storable a
=> SNat n
- -> (Either a (VS.Vector a) -> Either b (VS.Vector b) -> VS.Vector c)
- -> RS.Array n a -> RS.Array n b -> RS.Array n c
-liftVEltwise2 SNat f
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (a -> a -> a)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ sv
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ()) -- ^ vs
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ()) -- ^ vv
+ -> RS.Array n a -> RS.Array n a -> RS.Array n a
+liftVEltwise2 sn@SNat valconv ptrconv f_ss f_sv f_vs f_vv
arr1@(RS.A (RG.A sh1 (OI.T strides1 offset1 vec1)))
arr2@(RS.A (RG.A sh2 (OI.T strides2 offset2 vec2)))
| sh1 /= sh2 = error $ "liftVEltwise2: shapes unequal: " ++ show sh1 ++ " vs " ++ show sh2
| product sh1 == 0 = RS.A (RG.A sh1 (OI.T (0 <$ strides1) 0 VS.empty))
| otherwise = case (stridesDense sh1 offset1 strides1, stridesDense sh2 offset2 strides2) of
(Just (_, 1), Just (_, 1)) -> -- both are a (potentially replicated) scalar; just apply f to the scalars
- let vec' = f (Left (vec1 VS.! offset1)) (Left (vec2 VS.! offset2))
+ let vec' = VS.singleton (f_ss (vec1 VS.! offset1) (vec2 VS.! offset2))
in RS.A (RG.A sh1 (OI.T strides1 0 vec'))
+
(Just (_, 1), Just (blockOff, blockSz)) -> -- scalar * dense
- RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff)
- (f (Left (vec1 VS.! offset1)) (Right (VS.slice blockOff blockSz vec2)))))
+ let arr2' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec2)
+ RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2'
+ in RS.A (RG.A sh1 (OI.T strides2 (offset2 - blockOff) resvec))
+
+ (Just (_, 1), Nothing) -> -- scalar * array
+ wrapBinarySV sn valconv ptrconv f_sv (vec1 VS.! offset1) arr2
+
(Just (blockOff, blockSz), Just (_, 1)) -> -- dense * scalar
- RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff)
- (f (Right (VS.slice blockOff blockSz vec1)) (Left (vec2 VS.! offset2)))))
+ let arr1' = RS.fromVector [blockSz] (VS.slice blockOff blockSz vec1)
+ RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVS sn valconv ptrconv f_vs arr1' (vec2 VS.! offset2)
+ in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff) resvec))
+
+ (Nothing, Just (_, 1)) -> -- array * scalar
+ wrapBinaryVS sn valconv ptrconv f_vs arr1 (vec2 VS.! offset2)
+
(Just (blockOff1, blockSz1), Just (blockOff2, blockSz2))
- | blockSz1 == blockSz2 -- not sure if this check is necessary, might be implied by the below
+ | blockSz1 == blockSz2 -- not sure if this check is necessary, might be implied by the strides check
, strides1 == strides2
-> -- dense * dense but the strides match
- RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1)
- (f (Right (VS.slice blockOff1 blockSz1 vec1)) (Right (VS.slice blockOff2 blockSz2 vec2)))))
+ let arr1' = RS.fromVector [blockSz1] (VS.slice blockOff1 blockSz1 vec1)
+ arr2' = RS.fromVector [blockSz1] (VS.slice blockOff2 blockSz2 vec2)
+ RS.A (RG.A _ (OI.T _ _ resvec)) = wrapBinaryVV sn ptrconv f_vv arr1' arr2'
+ in RS.A (RG.A sh1 (OI.T strides1 (offset1 - blockOff1) resvec))
+
(_, _) -> -- fallback case
- RS.fromVector sh1 (f (Right (RS.toVector arr1)) (Right (RS.toVector arr2)))
+ wrapBinaryVV sn ptrconv f_vv arr1 arr2
-- | Given shape vector, offset and stride vector, check whether this virtual
-- vector uses a dense subarray of its backing array. If so, the first index
@@ -141,6 +160,57 @@ stridesDense sh offsetNeg stridesNeg =
in second ((-s) :) (flipReverseds sh' off' str')
flipReverseds _ _ _ = error "flipReverseds: invalid arguments"
+{-# NOINLINE wrapBinarySV #-}
+wrapBinarySV :: Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> b -> Ptr Int64 -> Ptr b -> IO ())
+ -> a -> RS.Array n a
+ -> RS.Array n a
+wrapBinarySV sn@SNat valconv ptrconv cf_strided x (RS.A (RG.A sh (OI.T strides offset vec))) =
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides)) $ \pstrides ->
+ VS.unsafeWith (VS.slice offset (VS.length vec - offset) vec) $ \pv ->
+ cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) (valconv x) pstrides (ptrconv pv)
+ RS.fromVector sh <$> VS.unsafeFreeze outv
+
+wrapBinaryVS :: Storable a
+ => SNat n
+ -> (a -> b)
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> b -> IO ())
+ -> RS.Array n a -> a
+ -> RS.Array n a
+wrapBinaryVS sn valconv ptrconv cf_strided arr y =
+ wrapBinarySV sn valconv ptrconv
+ (\rank psh poutv y' pstrides pv -> cf_strided rank psh poutv pstrides pv y') y arr
+
+-- | This function assumes that the two shapes are equal.
+{-# NOINLINE wrapBinaryVV #-}
+wrapBinaryVV :: Storable a
+ => SNat n
+ -> (Ptr a -> Ptr b)
+ -> (Int64 -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> Ptr Int64 -> Ptr b -> IO ())
+ -> RS.Array n a -> RS.Array n a
+ -> RS.Array n a
+wrapBinaryVV sn@SNat ptrconv cf_strided
+ (RS.A (RG.A sh (OI.T strides1 offset1 vec1)))
+ (RS.A (RG.A _ (OI.T strides2 offset2 vec2))) =
+ unsafePerformIO $ do
+ outv <- VSM.unsafeNew (product sh)
+ VSM.unsafeWith outv $ \poutv ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral sh)) $ \psh ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides1)) $ \pstrides1 ->
+ VS.unsafeWith (VS.fromListN (fromSNat' sn) (map fromIntegral strides2)) $ \pstrides2 ->
+ VS.unsafeWith (VS.slice offset1 (VS.length vec1 - offset1) vec1) $ \pv1 ->
+ VS.unsafeWith (VS.slice offset2 (VS.length vec2 - offset2) vec2) $ \pv2 ->
+ cf_strided (fromIntegral (fromSNat' sn)) psh (ptrconv poutv) pstrides1 (ptrconv pv1) pstrides2 (ptrconv pv2)
+ RS.fromVector sh <$> VS.unsafeFreeze outv
+
{-# NOINLINE vectorOp1 #-}
vectorOp1 :: forall a b. Storable a
=> (Ptr a -> Ptr b)
@@ -423,13 +493,13 @@ $(fmap concat . forM typesList $ \arithtype -> do
fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (aboName arithop ++ "Vector" ++ nameBase (atType arithtype))
cnamebase = "c_binary_" ++ atCName arithtype
- c_ss = varE (aboNumOp arithop)
- c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
- c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_ss_str = varE (aboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (aboEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ ,do body <- [| \sn -> liftVEltwise2 sn $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
return $ FunD name [Clause [] (NormalB body) []]])
$(fmap concat . forM floatTypesList $ \arithtype -> do
@@ -437,13 +507,13 @@ $(fmap concat . forM floatTypesList $ \arithtype -> do
fmap concat . forM [minBound..maxBound] $ \arithop -> do
let name = mkName (afboName arithop ++ "Vector" ++ nameBase (atType arithtype))
cnamebase = "c_fbinary_" ++ atCName arithtype
- c_ss = varE (afboNumOp arithop)
- c_sv = varE (mkName (cnamebase ++ "_sv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vs = varE (mkName (cnamebase ++ "_vs")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
- c_vv = varE (mkName (cnamebase ++ "_vv")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_ss_str = varE (afboNumOp arithop)
+ c_sv_str = varE (mkName (cnamebase ++ "_sv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vs_str = varE (mkName (cnamebase ++ "_vs_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
+ c_vv_str = varE (mkName (cnamebase ++ "_vv_strided")) `appE` litE (integerL (fromIntegral (afboEnum arithop)))
sequence [SigD name <$>
[t| forall n. SNat n -> RS.Array n $ttyp -> RS.Array n $ttyp -> RS.Array n $ttyp |]
- ,do body <- [| \sn -> liftVEltwise2 sn (vectorOp2 id id $c_ss $c_sv $c_vs $c_vv) |]
+ ,do body <- [| \sn -> liftVEltwise2 sn $c_ss_str $c_sv_str $c_vs_str $c_vv_str |]
return $ FunD name [Clause [] (NormalB body) []]])
$(fmap concat . forM typesList $ \arithtype -> do
@@ -526,17 +596,17 @@ intWidBranch1 f32 f64 sn
intWidBranch2 :: forall i n. (FiniteBits i, Storable i, Integral i)
=> (i -> i -> i) -- ss
-- int32
- -> (Int64 -> Ptr Int32 -> Int32 -> Ptr Int32 -> IO ()) -- sv
- -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Int32 -> IO ()) -- vs
- -> (Int64 -> Ptr Int32 -> Ptr Int32 -> Ptr Int32 -> IO ()) -- vv
+ -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Int32 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> Ptr Int64 -> Ptr Int32 -> IO ()) -- vv
-- int64
- -> (Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> IO ()) -- sv
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
- -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- sv
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Int64 -> IO ()) -- vs
+ -> (Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr Int64 -> IO ()) -- vv
-> (SNat n -> RS.Array n i -> RS.Array n i -> RS.Array n i)
intWidBranch2 ss sv32 vs32 vv32 sv64 vs64 vv64 sn
- | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn (vectorOp2 @i @Int32 fromIntegral castPtr ss sv32 vs32 vv32)
- | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn (vectorOp2 @i @Int64 fromIntegral castPtr ss sv64 vs64 vv64)
+ | finiteBitSize (undefined :: i) == 32 = liftVEltwise2 sn fromIntegral castPtr ss sv32 vs32 vv32
+ | finiteBitSize (undefined :: i) == 64 = liftVEltwise2 sn fromIntegral castPtr ss sv64 vs64 vv64
| otherwise = error "Unsupported Int width"
intWidBranchRed1 :: forall i n. (FiniteBits i, Storable i, Integral i)
@@ -667,14 +737,14 @@ instance NumElt Double where
instance NumElt Int where
numEltAdd = intWidBranch2 @Int (+)
- (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
- (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
numEltSub = intWidBranch2 @Int (-)
- (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
- (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
numEltMul = intWidBranch2 @Int (*)
- (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
- (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
numEltNeg = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
numEltAbs = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
numEltSignum = intWidBranch1 @Int (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
@@ -693,14 +763,14 @@ instance NumElt Int where
instance NumElt CInt where
numEltAdd = intWidBranch2 @CInt (+)
- (c_binary_i32_sv (aboEnum BO_ADD)) (flipOp (c_binary_i32_sv (aboEnum BO_ADD))) (c_binary_i32_vv (aboEnum BO_ADD))
- (c_binary_i64_sv (aboEnum BO_ADD)) (flipOp (c_binary_i64_sv (aboEnum BO_ADD))) (c_binary_i64_vv (aboEnum BO_ADD))
+ (c_binary_i32_sv_strided (aboEnum BO_ADD)) (c_binary_i32_vs_strided (aboEnum BO_ADD)) (c_binary_i32_vv_strided (aboEnum BO_ADD))
+ (c_binary_i64_sv_strided (aboEnum BO_ADD)) (c_binary_i64_vs_strided (aboEnum BO_ADD)) (c_binary_i64_vv_strided (aboEnum BO_ADD))
numEltSub = intWidBranch2 @CInt (-)
- (c_binary_i32_sv (aboEnum BO_SUB)) (flipOp (c_binary_i32_sv (aboEnum BO_SUB))) (c_binary_i32_vv (aboEnum BO_SUB))
- (c_binary_i64_sv (aboEnum BO_SUB)) (flipOp (c_binary_i64_sv (aboEnum BO_SUB))) (c_binary_i64_vv (aboEnum BO_SUB))
+ (c_binary_i32_sv_strided (aboEnum BO_SUB)) (c_binary_i32_vs_strided (aboEnum BO_SUB)) (c_binary_i32_vv_strided (aboEnum BO_SUB))
+ (c_binary_i64_sv_strided (aboEnum BO_SUB)) (c_binary_i64_vs_strided (aboEnum BO_SUB)) (c_binary_i64_vv_strided (aboEnum BO_SUB))
numEltMul = intWidBranch2 @CInt (*)
- (c_binary_i32_sv (aboEnum BO_MUL)) (flipOp (c_binary_i32_sv (aboEnum BO_MUL))) (c_binary_i32_vv (aboEnum BO_MUL))
- (c_binary_i64_sv (aboEnum BO_MUL)) (flipOp (c_binary_i64_sv (aboEnum BO_MUL))) (c_binary_i64_vv (aboEnum BO_MUL))
+ (c_binary_i32_sv_strided (aboEnum BO_MUL)) (c_binary_i32_vs_strided (aboEnum BO_MUL)) (c_binary_i32_vv_strided (aboEnum BO_MUL))
+ (c_binary_i64_sv_strided (aboEnum BO_MUL)) (c_binary_i64_vs_strided (aboEnum BO_MUL)) (c_binary_i64_vv_strided (aboEnum BO_MUL))
numEltNeg = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_NEG)) (c_unary_i64_strided (auoEnum UO_NEG))
numEltAbs = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_ABS)) (c_unary_i64_strided (auoEnum UO_ABS))
numEltSignum = intWidBranch1 @CInt (c_unary_i32_strided (auoEnum UO_SIGNUM)) (c_unary_i64_strided (auoEnum UO_SIGNUM))
diff --git a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
index a60b717..fa89766 100644
--- a/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
+++ b/src/Data/Array/Mixed/Internal/Arith/Foreign.hs
@@ -12,24 +12,24 @@ import Data.Array.Mixed.Internal.Arith.Lists
$(do
let importsScal ttyp tyn =
- [("binary_" ++ tyn ++ "_vv", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,("binary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,("binary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
- ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
- ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
- ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
- ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
- ,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
- ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
- ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ [("binary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("binary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("binary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+ ,("unary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("reduce1_" ++ tyn, [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("reducefull_" ++ tyn, [t| CInt -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO $ttyp |])
+ ,("extremum_min_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("extremum_max_" ++ tyn, [t| Ptr Int64 -> Int64 -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("dotprod_" ++ tyn, [t| Int64 -> Ptr $ttyp -> Ptr $ttyp -> IO $ttyp |])
+ ,("dotprod_" ++ tyn ++ "_strided", [t| Int64 -> Int64 -> Int64 -> Ptr $ttyp -> Int64 -> Int64 -> Ptr $ttyp -> IO $ttyp |])
+ ,("dotprodinner_" ++ tyn, [t| Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
]
let importsFloat ttyp tyn =
- [("fbinary_" ++ tyn ++ "_vv", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> Ptr $ttyp -> IO () |])
- ,("fbinary_" ++ tyn ++ "_sv", [t| CInt -> Int64 -> Ptr $ttyp -> $ttyp -> Ptr $ttyp -> IO () |])
- ,("fbinary_" ++ tyn ++ "_vs", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr $ttyp -> $ttyp -> IO () |])
- ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ [("fbinary_" ++ tyn ++ "_vv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("fbinary_" ++ tyn ++ "_sv_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> Ptr Int64 -> Ptr $ttyp -> IO () |])
+ ,("fbinary_" ++ tyn ++ "_vs_strided", [t| CInt -> Int64 -> Ptr Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr $ttyp -> $ttyp -> IO () |])
+ ,("funary_" ++ tyn ++ "_strided", [t| CInt -> Int64 -> Ptr $ttyp -> Ptr Int64 -> Ptr Int64 -> Ptr $ttyp -> IO () |])
]
let generate types imports =
diff --git a/src/Data/Array/Mixed/Permutation.hs b/src/Data/Array/Mixed/Permutation.hs
index 015a828..8efcbe8 100644
--- a/src/Data/Array/Mixed/Permutation.hs
+++ b/src/Data/Array/Mixed/Permutation.hs
@@ -241,7 +241,7 @@ permInverse = \perm k ->
provePermInverse :: Perm is -> Perm is' -> StaticShX sh
-> Maybe (Permute is' (Permute is sh) :~: sh)
provePermInverse perm perminv ssh =
- ssxGeq (ssxPermute perminv (ssxPermute perm ssh)) ssh
+ ssxEqType (ssxPermute perminv (ssxPermute perm ssh)) ssh
type family MapSucc is where
MapSucc '[] = '[]
diff --git a/src/Data/Array/Mixed/Shape.hs b/src/Data/Array/Mixed/Shape.hs
index e5f8b67..16f62fe 100644
--- a/src/Data/Array/Mixed/Shape.hs
+++ b/src/Data/Array/Mixed/Shape.hs
@@ -71,13 +71,28 @@ listxUncons :: ListX sh1 f -> Maybe (UnconsListXRes f sh1)
listxUncons (i ::% shl') = Just (UnconsListXRes shl' i)
listxUncons ZX = Nothing
-listxEq :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
-listxEq ZX ZX = Just Refl
-listxEq (n ::% sh) (m ::% sh')
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listxEqType :: TestEquality f => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqType ZX ZX = Just Refl
+listxEqType (n ::% sh) (m ::% sh')
| Just Refl <- testEquality n m
- , Just Refl <- listxEq sh sh'
+ , Just Refl <- listxEqType sh sh'
= Just Refl
-listxEq _ _ = Nothing
+listxEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listxEqual :: (TestEquality f, forall n. Eq (f n)) => ListX sh f -> ListX sh' f -> Maybe (sh :~: sh')
+listxEqual ZX ZX = Just Refl
+listxEqual (n ::% sh) (m ::% sh')
+ | Just Refl <- testEquality n m
+ , n == m
+ , Just Refl <- listxEqual sh sh'
+ = Just Refl
+listxEqual _ _ = Nothing
listxFmap :: (forall n. f n -> g n) -> ListX sh f -> ListX sh g
listxFmap _ ZX = ZX
@@ -293,11 +308,26 @@ instance NFData i => NFData (ShX sh i) where
shxLength :: ShX sh i -> Int
shxLength (ShX l) = listxLength l
-shxRank :: ShX sh f -> SNat (Rank sh)
+shxRank :: ShX sh i -> SNat (Rank sh)
shxRank (ShX list) = listxRank list
--- | This is more than @geq@: it also checks that the integers (the unknown
--- dimensions) are the same.
+-- | This checks only whether the types are equal; unknown dimensions might
+-- still differ. This corresponds to 'testEquality', except on the penultimate
+-- type parameter.
+shxEqType :: ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
+shxEqType ZSX ZSX = Just Refl
+shxEqType (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
+ | Just Refl <- sameNat n m
+ , Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType (SUnknown _ :$% sh) (SUnknown _ :$% sh')
+ | Just Refl <- shxEqType sh sh'
+ = Just Refl
+shxEqType _ _ = Nothing
+
+-- | This checks whether all dimensions have the same value. This is more than
+-- 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@ in the
+-- @some@ package (except on the penultimate type parameter).
shxEqual :: Eq i => ShX sh i -> ShX sh' i -> Maybe (sh :~: sh')
shxEqual ZSX ZSX = Just Refl
shxEqual (SKnown n@SNat :$% sh) (SKnown m@SNat :$% sh')
@@ -417,23 +447,14 @@ instance Show (StaticShX sh) where
showsPrec _ (StaticShX l) = listxShow (fromSMayNat shows (shows . fromSNat)) l
instance TestEquality StaticShX where
- testEquality (StaticShX l1) (StaticShX l2) = listxEq l1 l2
+ testEquality (StaticShX l1) (StaticShX l2) = listxEqType l1 l2
ssxLength :: StaticShX sh -> Int
ssxLength (StaticShX l) = listxLength l
--- | This suffices as an implementation of @geq@ in the @Data.GADT.Compare@
--- class of the @some@ package.
-ssxGeq :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
-ssxGeq ZKX ZKX = Just Refl
-ssxGeq (SKnown n@SNat :!% sh) (SKnown m@SNat :!% sh')
- | Just Refl <- sameNat n m
- , Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq (SUnknown () :!% sh) (SUnknown () :!% sh')
- | Just Refl <- ssxGeq sh sh'
- = Just Refl
-ssxGeq _ _ = Nothing
+-- | @ssxEqType = 'testEquality'@. Provided for consistency.
+ssxEqType :: StaticShX sh -> StaticShX sh' -> Maybe (sh :~: sh')
+ssxEqType = testEquality
ssxAppend :: StaticShX sh -> StaticShX sh' -> StaticShX (sh ++ sh')
ssxAppend ZKX sh' = sh'
diff --git a/src/Data/Array/Nested/Internal/Shape.hs b/src/Data/Array/Nested/Internal/Shape.hs
index 59f2c9a..5fb2e7f 100644
--- a/src/Data/Array/Nested/Internal/Shape.hs
+++ b/src/Data/Array/Nested/Internal/Shape.hs
@@ -65,6 +65,24 @@ listrUncons :: ListR n1 i -> Maybe (UnconsListRRes i n1)
listrUncons (i ::: sh') = Just (UnconsListRRes sh' i)
listrUncons ZR = Nothing
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+listrEqRank :: ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqRank ZR ZR = Just Refl
+listrEqRank (_ ::: sh) (_ ::: sh')
+ | Just Refl <- listrEqRank sh sh'
+ = Just Refl
+listrEqRank _ _ = Nothing
+
+-- | This compares the lists for value equality.
+listrEqual :: Eq i => ListR n i -> ListR n' i -> Maybe (n :~: n')
+listrEqual ZR ZR = Just Refl
+listrEqual (i ::: sh) (j ::: sh')
+ | Just Refl <- listrEqual sh sh'
+ , i == j
+ = Just Refl
+listrEqual _ _ = Nothing
+
listrShow :: forall n i. (i -> ShowS) -> ListR n i -> ShowS
listrShow f l = showString "[" . go "" l . showString "]"
where
@@ -207,7 +225,7 @@ pattern (:$:)
forall n. (n + 1 ~ n1)
=> i -> ShR n i -> ShR n1 i
pattern i :$: sh <- ShR (listrUncons -> Just (UnconsListRRes (ShR -> sh) i))
- where i :$: (ShR sh) = ShR (i ::: sh)
+ where i :$: ShR sh = ShR (i ::: sh)
infixr 3 :$:
{-# COMPLETE ZSR, (:$:) #-}
@@ -241,6 +259,15 @@ shCvtRX (n :$: (idx :: ShR m Int)) =
castWith (subst2 @ShX @Int (lemReplicateSucc @(Nothing @Nat) @m))
(SUnknown n :$% shCvtRX idx)
+-- | This checks only whether the ranks are equal, not whether the actual
+-- values are.
+shrEqRank :: ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqRank (ShR sh) (ShR sh') = listrEqRank sh sh'
+
+-- | This compares the shapes for value equality.
+shrEqual :: Eq i => ShR n i -> ShR n' i -> Maybe (n :~: n')
+shrEqual (ShR sh) (ShR sh') = listrEqual sh sh'
+
-- | The number of elements in an array described by this shape.
shrSize :: IShR n -> Int
shrSize ZSR = 1
@@ -316,13 +343,28 @@ listsUncons :: ListS sh1 f -> Maybe (UnconsListSRes f sh1)
listsUncons (x ::$ sh') = Just (UnconsListSRes sh' x)
listsUncons ZS = Nothing
-listsEq :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
-listsEq ZS ZS = Just Refl
-listsEq (n ::$ sh) (m ::$ sh')
+-- | This checks only whether the types are equal; if the elements of the list
+-- are not singletons, their values may still differ. This corresponds to
+-- 'testEquality', except on the penultimate type parameter.
+listsEqType :: TestEquality f => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqType ZS ZS = Just Refl
+listsEqType (n ::$ sh) (m ::$ sh')
+ | Just Refl <- testEquality n m
+ , Just Refl <- listsEqType sh sh'
+ = Just Refl
+listsEqType _ _ = Nothing
+
+-- | This checks whether the two lists actually contain equal values. This is
+-- more than 'testEquality', and corresponds to @geq@ from @Data.GADT.Compare@
+-- in the @some@ package (except on the penultimate type parameter).
+listsEqual :: (TestEquality f, forall n. Eq (f n)) => ListS sh f -> ListS sh' f -> Maybe (sh :~: sh')
+listsEqual ZS ZS = Just Refl
+listsEqual (n ::$ sh) (m ::$ sh')
| Just Refl <- testEquality n m
- , Just Refl <- listsEq sh sh'
+ , n == m
+ , Just Refl <- listsEqual sh sh'
= Just Refl
-listsEq _ _ = Nothing
+listsEqual _ _ = Nothing
listsFmap :: (forall n. f n -> g n) -> ListS sh f -> ListS sh g
listsFmap _ ZS = ZS
@@ -484,7 +526,12 @@ instance Show (ShS sh) where
showsPrec _ (ShS l) = listsShow (shows . fromSNat) l
instance TestEquality ShS where
- testEquality (ShS l1) (ShS l2) = listsEq l1 l2
+ testEquality (ShS l1) (ShS l2) = listsEqType l1 l2
+
+-- | @'shsEqual' = 'testEquality'@. (Because 'ShS' is a singleton, types are
+-- equal if and only if values are equal.)
+shsEqual :: ShS sh -> ShS sh' -> Maybe (sh :~: sh')
+shsEqual = testEquality
shsLength :: ShS sh -> Int
shsLength (ShS l) = getSum (listsFold (\_ -> Sum 1) l)