diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 21:40:50 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-05-15 21:41:51 +0200 | 
| commit | 6f82510ddde60b76e7c81eae2da7d947312179e5 (patch) | |
| tree | 823bdd206b6826b279146f3acd1a9e851b31592e /src/Data/Array/Strided | |
| parent | 3c8f13c8310de646b15c6f2745cfe190db7610db (diff) | |
Move Data.Array.Arith to Data.Array.Strided.Orthotope
Diffstat (limited to 'src/Data/Array/Strided')
| -rw-r--r-- | src/Data/Array/Strided/Orthotope.hs | 43 | 
1 files changed, 43 insertions, 0 deletions
| diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs new file mode 100644 index 0000000..5c38d14 --- /dev/null +++ b/src/Data/Array/Strided/Orthotope.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Array.Strided.Orthotope ( +  module Data.Array.Strided.Orthotope, +  module Data.Array.Strided.Arith, +) where + +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS + +import Data.Array.Strided qualified as AS +import Data.Array.Strided.Arith + +-- for liftVEltwise1 +import Data.Array.Strided.Arith.Internal (stridesDense) +import Data.Vector.Storable qualified as VS +import Foreign.Storable +import GHC.TypeLits + + +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec + +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) + +liftO1 :: (AS.Array n a -> AS.Array n' b) +       -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO + +liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) +       -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c +liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) +              => SNat n +              -> (VS.Vector a -> VS.Vector b) +              -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) +  | Just (blockOff, blockSz) <- stridesDense sh offset strides = +      let vec' = f (VS.slice blockOff blockSz vec) +      in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) +  | otherwise = RS.fromVector sh (f (RS.toVector arr)) | 
