aboutsummaryrefslogtreecommitdiff
path: root/ops/Data/Array/Strided/Arith
diff options
context:
space:
mode:
authorMikolaj Konarski <mikolaj.konarski@funktory.com>2025-12-28 22:04:19 +0100
committerMikolaj Konarski <mikolaj.konarski@funktory.com>2026-01-12 18:43:13 +0100
commita321a334a09467e4563e0b432f7cedd7839647ff (patch)
tree7983d9a2376276e52f0ed348db3d834137e6d943 /ops/Data/Array/Strided/Arith
parent490af8b23ee335b4acc81c50335adbbab03e402a (diff)
Improve the implementation of the other fromSNat'
Diffstat (limited to 'ops/Data/Array/Strided/Arith')
-rw-r--r--ops/Data/Array/Strided/Arith/Internal.hs12
1 files changed, 6 insertions, 6 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs
index 6aa111a..a099326 100644
--- a/ops/Data/Array/Strided/Arith/Internal.hs
+++ b/ops/Data/Array/Strided/Arith/Internal.hs
@@ -27,7 +27,7 @@ import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import GHC.TypeLits
-import GHC.TypeNats qualified as TypeNats
+import GHC.TypeNats qualified as TN
import Language.Haskell.TH
import System.IO (hFlush, stdout)
import System.IO.Unsafe
@@ -42,7 +42,7 @@ import Data.Array.Strided.Array
-- TODO: move this to a utilities module
fromSNat' :: SNat n -> Int
-fromSNat' = fromIntegral . fromSNat
+fromSNat' = fromEnum . TN.fromSNat
data Dict c where
Dict :: c => Dict c
@@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) =
unrepSize = product [n | (n, True) <- zip sh replDims]
- in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims)
simplifyArray :: Array n a
@@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k
, let unrepSize = product [n | (n, True) <- zip sh replDims]
- = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
+ = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) ->
k @lenshF
(Array shF strides1F offset1 vec1)
(Array shF strides2F offset2 vec2)
@@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off
VS.unsafeWith vec' $ \pv ->
let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a))
in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv')
- TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict
@@ -490,7 +490,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner
fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv)
pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1'))
pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2'))
- TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
+ TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do
(Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of
LTI -> pure Dict
EQI -> pure Dict