diff options
Diffstat (limited to 'ops/Data/Array/Strided/Arith')
| -rw-r--r-- | ops/Data/Array/Strided/Arith/Internal.hs | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index 6aa111a..a099326 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -27,7 +27,7 @@ import Foreign.C.Types import Foreign.Ptr import Foreign.Storable import GHC.TypeLits -import GHC.TypeNats qualified as TypeNats +import GHC.TypeNats qualified as TN import Language.Haskell.TH import System.IO (hFlush, stdout) import System.IO.Unsafe @@ -42,7 +42,7 @@ import Data.Array.Strided.Array -- TODO: move this to a utilities module fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat +fromSNat' = fromEnum . TN.fromSNat data Dict c where Dict :: c => Dict c @@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) = unrepSize = product [n | (n, True) <- zip sh replDims] - in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims) simplifyArray :: Array n a @@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k , let unrepSize = product [n | (n, True) <- zip sh replDims] - = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> k @lenshF (Array shF strides1F offset1 vec1) (Array shF strides2F offset2 vec2) @@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off VS.unsafeWith vec' $ \pv -> let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a)) in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv') - TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict @@ -490,7 +490,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1')) pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2')) - TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict |
