diff options
| author | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2025-12-28 22:04:19 +0100 |
|---|---|---|
| committer | Mikolaj Konarski <mikolaj.konarski@funktory.com> | 2026-01-12 18:43:13 +0100 |
| commit | a321a334a09467e4563e0b432f7cedd7839647ff (patch) | |
| tree | 7983d9a2376276e52f0ed348db3d834137e6d943 /ops/Data/Array/Strided/Arith | |
| parent | 490af8b23ee335b4acc81c50335adbbab03e402a (diff) | |
Improve the implementation of the other fromSNat'
Diffstat (limited to 'ops/Data/Array/Strided/Arith')
| -rw-r--r-- | ops/Data/Array/Strided/Arith/Internal.hs | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/ops/Data/Array/Strided/Arith/Internal.hs b/ops/Data/Array/Strided/Arith/Internal.hs index 6aa111a..a099326 100644 --- a/ops/Data/Array/Strided/Arith/Internal.hs +++ b/ops/Data/Array/Strided/Arith/Internal.hs @@ -27,7 +27,7 @@ import Foreign.C.Types import Foreign.Ptr import Foreign.Storable import GHC.TypeLits -import GHC.TypeNats qualified as TypeNats +import GHC.TypeNats qualified as TN import Language.Haskell.TH import System.IO (hFlush, stdout) import System.IO.Unsafe @@ -42,7 +42,7 @@ import Data.Array.Strided.Array -- TODO: move this to a utilities module fromSNat' :: SNat n -> Int -fromSNat' = fromIntegral . fromSNat +fromSNat' = fromEnum . TN.fromSNat data Dict c where Dict :: c => Dict c @@ -179,7 +179,7 @@ unreplicateStrides (Array sh strides offset vec) = unrepSize = product [n | (n, True) <- zip sh replDims] - in TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + in TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> Unreplicated (Array @lenshF shF stridesF offset vec) unrepSize (reinsertZeros replDims) simplifyArray :: Array n a @@ -258,7 +258,7 @@ simplifyArray2 arr1@(Array sh _ _ _) arr2@(Array sh2 _ _ _) k , let unrepSize = product [n | (n, True) <- zip sh replDims] - = TypeNats.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> + = TN.withSomeSNat (fromIntegral (length shF)) $ \(SNat :: SNat lenshF) -> k @lenshF (Array shF strides1F offset1 vec1) (Array shF strides2F offset2 vec2) @@ -386,7 +386,7 @@ vectorRedInnerOp sn@SNat valconv ptrconv fscale fred array@(Array sh strides off VS.unsafeWith vec' $ \pv -> let pv' = pv `plusPtr` (offset' * sizeOf (undefined :: a)) in fred (fromIntegral ndims') (ptrconv poutv) psh pstrides (ptrconv pv') - TypeNats.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (ndims' - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict @@ -490,7 +490,7 @@ vectorDotprodInnerOp sn@SNat valconv ptrconv fmul fscale fred fdotinner fdotinner (fromIntegral @Int @Int64 inrank) psh (ptrconv poutv) pstrides1 (ptrconv pvec1 `plusPtr` (sizeOf (undefined :: a) * offset1')) pstrides2 (ptrconv pvec2 `plusPtr` (sizeOf (undefined :: a) * offset2')) - TypeNats.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do + TN.withSomeSNat (fromIntegral (inrank - 1)) $ \(SNat :: SNat n'm1) -> do (Dict :: Dict (1 <= n')) <- case cmpNat (natSing @1) (natSing @n') of LTI -> pure Dict EQI -> pure Dict |
