diff options
Diffstat (limited to 'ops/Data/Array/Strided/Arith/Internal.hs')
| -rw-r--r-- | ops/Data/Array/Strided/Arith/Internal.hs | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index 7578dd8..2eb0666 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -27,7 +27,7 @@ import Foreign.C.Types import Foreign.Ptr import Foreign.Storable import GHC.TypeLits -import GHC.TypeNats qualified as TypeNats +import GHC.TypeNats qualified as TN import Language.Haskell.TH import System.IO (hFlush, stdout) import System.IO.Unsafe @@ -42,7 +42,7 @@ import Data.Array.Strided.Array -- TODO: move this to a utilities module fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat +fromSNat' = fromEnum . TN.fromSNat data Dict c where Dict :: c => Dict c @@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) = unrepSize = product [n | (n, True) <- zip sh replDims] - in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims) simplifyArray :: Array n a @@ -200,7 +200,7 @@ simplifyArray :: Array n a -> r) -> r simplifyArray array k - | let revDims = map (<0) (arrStrides array) + | let revDims = map (< 0) (arrStrides array) , Unreplicated array' unrepSize rereplicate <- unreplicateStrides (arrayRevDims revDims array) = k array' unrepSize @@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k , let unrepSize = product [n | (n, True) <- zip sh replDims] - = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> k @lenshF (Array shF strides1F offset1 vec1) (Array shF strides2F offset2 vec2) @@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off VS.unsafeWith vec' $ \pv -> let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a)) in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv') - TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict @@ -491,7 +491,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1')) pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2')) - TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict @@ -879,19 +879,19 @@ instance IntElt Int64 where instance IntElt Int where intEltQuot = intWidBranch2 @Int quot - (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) - (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT)) intEltRem = intWidBranch2 @Int rem - (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) - (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM)) instance IntElt CInt where intEltQuot = intWidBranch2 @CInt quot - (c_binary_i32_sv_strided (aiboEnum IB_QUOT)) (c_binary_i32_vs_strided (aiboEnum IB_QUOT)) (c_binary_i32_vv_strided (aiboEnum IB_QUOT)) - (c_binary_i64_sv_strided (aiboEnum IB_QUOT)) (c_binary_i64_vs_strided (aiboEnum IB_QUOT)) (c_binary_i64_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i32_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i32_vv_strided (aiboEnum IB_QUOT)) + (c_ibinary_i64_sv_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vs_strided (aiboEnum IB_QUOT)) (c_ibinary_i64_vv_strided (aiboEnum IB_QUOT)) intEltRem = intWidBranch2 @CInt rem - (c_binary_i32_sv_strided (aiboEnum IB_REM)) (c_binary_i32_vs_strided (aiboEnum IB_REM)) (c_binary_i32_vv_strided (aiboEnum IB_REM)) - (c_binary_i64_sv_strided (aiboEnum IB_REM)) (c_binary_i64_vs_strided (aiboEnum IB_REM)) (c_binary_i64_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i32_sv_strided (aiboEnum IB_REM)) (c_ibinary_i32_vs_strided (aiboEnum IB_REM)) (c_ibinary_i32_vv_strided (aiboEnum IB_REM)) + (c_ibinary_i64_sv_strided (aiboEnum IB_REM)) (c_ibinary_i64_vs_strided (aiboEnum IB_REM)) (c_ibinary_i64_vv_strided (aiboEnum IB_REM)) class NumElt a => FloatElt a where floatEltDiv :: SNat n -> Array n a -> Array n a -> Array n a |
