diff options
Diffstat (limited to 'src/Data/Array/Arith.hs')
-rw-r--r-- | src/Data/Array/Arith.hs | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/src/Data/Array/Arith.hs b/src/Data/Array/Arith.hs new file mode 100644 index 0000000..1eae737 --- /dev/null +++ b/src/Data/Array/Arith.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Array.Arith ( + module Data.Array.Arith, + module Data.Array.Strided.Arith, +) where + +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS + +import Data.Array.Strided qualified as AS +import Data.Array.Strided.Arith + +-- for liftVEltwise1 +import Data.Array.Strided.Arith.Internal (stridesDense) +import Data.Vector.Storable qualified as VS +import Foreign.Storable +import GHC.TypeLits + + +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec + +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) + +liftO1 :: (AS.Array n a -> AS.Array n' b) + -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO + +liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) + -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c +liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) |