From 6f82510ddde60b76e7c81eae2da7d947312179e5 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 15 May 2025 21:40:50 +0200 Subject: Move Data.Array.Arith to Data.Array.Strided.Orthotope --- ox-arrays.cabal | 2 +- src/Data/Array/Arith.hs | 43 ------------------------------------- src/Data/Array/Nested/Mixed.hs | 2 +- src/Data/Array/Strided/Orthotope.hs | 43 +++++++++++++++++++++++++++++++++++++ src/Data/Array/XArray.hs | 2 +- 5 files changed, 46 insertions(+), 46 deletions(-) delete mode 100644 src/Data/Array/Arith.hs create mode 100644 src/Data/Array/Strided/Orthotope.hs diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 9cfc7dd..0cc2953 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -47,7 +47,6 @@ library -- put this module on top so ghci considers it the "main" module Data.Array.Nested - Data.Array.Arith Data.Array.Mixed.Lemmas Data.Array.Mixed.Permutation Data.Array.Mixed.Types @@ -59,6 +58,7 @@ library Data.Array.Nested.Mixed.Shape Data.Array.Nested.Ranked.Shape Data.Array.Nested.Shaped.Shape + Data.Array.Strided.Orthotope Data.Array.XArray Data.Bag diff --git a/src/Data/Array/Arith.hs b/src/Data/Array/Arith.hs deleted file mode 100644 index 1eae737..0000000 --- a/src/Data/Array/Arith.hs +++ /dev/null @@ -1,43 +0,0 @@ -{-# LANGUAGE ImportQualifiedPost #-} -module Data.Array.Arith ( - module Data.Array.Arith, - module Data.Array.Strided.Arith, -) where - -import Data.Array.Internal qualified as OI -import Data.Array.Internal.RankedG qualified as RG -import Data.Array.Internal.RankedS qualified as RS - -import Data.Array.Strided qualified as AS -import Data.Array.Strided.Arith - --- for liftVEltwise1 -import Data.Array.Strided.Arith.Internal (stridesDense) -import Data.Vector.Storable qualified as VS -import Foreign.Storable -import GHC.TypeLits - - -fromO :: RS.Array n a -> AS.Array n a -fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec - -toO :: AS.Array n a -> RS.Array n a -toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) - -liftO1 :: (AS.Array n a -> AS.Array n' b) - -> RS.Array n a -> RS.Array n' b -liftO1 f = toO . f . fromO - -liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) - -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c -liftO2 f x y = toO (f (fromO x) (fromO y)) - -liftVEltwise1 :: (Storable a, Storable b) - => SNat n - -> (VS.Vector a -> VS.Vector b) - -> RS.Array n a -> RS.Array n b -liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) - | Just (blockOff, blockSz) <- stridesDense sh offset strides = - let vec' = f (VS.slice blockOff blockSz vec) - in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) - | otherwise = RS.fromVector sh (f (RS.toVector arr)) diff --git a/src/Data/Array/Nested/Mixed.hs b/src/Data/Array/Nested/Mixed.hs index ec19c21..0a7eaba 100644 --- a/src/Data/Array/Nested/Mixed.hs +++ b/src/Data/Array/Nested/Mixed.hs @@ -42,13 +42,13 @@ import GHC.Generics (Generic) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types import Data.Array.XArray (XArray(..)) import Data.Array.XArray qualified as X import Data.Array.Nested.Mixed.Shape +import Data.Array.Strided.Orthotope import Data.Bag diff --git a/src/Data/Array/Strided/Orthotope.hs b/src/Data/Array/Strided/Orthotope.hs new file mode 100644 index 0000000..5c38d14 --- /dev/null +++ b/src/Data/Array/Strided/Orthotope.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE ImportQualifiedPost #-} +module Data.Array.Strided.Orthotope ( + module Data.Array.Strided.Orthotope, + module Data.Array.Strided.Arith, +) where + +import Data.Array.Internal qualified as OI +import Data.Array.Internal.RankedG qualified as RG +import Data.Array.Internal.RankedS qualified as RS + +import Data.Array.Strided qualified as AS +import Data.Array.Strided.Arith + +-- for liftVEltwise1 +import Data.Array.Strided.Arith.Internal (stridesDense) +import Data.Vector.Storable qualified as VS +import Foreign.Storable +import GHC.TypeLits + + +fromO :: RS.Array n a -> AS.Array n a +fromO (RS.A (RG.A sh (OI.T strides offset vec))) = AS.Array sh strides offset vec + +toO :: AS.Array n a -> RS.Array n a +toO (AS.Array sh strides offset vec) = RS.A (RG.A sh (OI.T strides offset vec)) + +liftO1 :: (AS.Array n a -> AS.Array n' b) + -> RS.Array n a -> RS.Array n' b +liftO1 f = toO . f . fromO + +liftO2 :: (AS.Array n a -> AS.Array n1 b -> AS.Array n2 c) + -> RS.Array n a -> RS.Array n1 b -> RS.Array n2 c +liftO2 f x y = toO (f (fromO x) (fromO y)) + +liftVEltwise1 :: (Storable a, Storable b) + => SNat n + -> (VS.Vector a -> VS.Vector b) + -> RS.Array n a -> RS.Array n b +liftVEltwise1 SNat f arr@(RS.A (RG.A sh (OI.T strides offset vec))) + | Just (blockOff, blockSz) <- stridesDense sh offset strides = + let vec' = f (VS.slice blockOff blockSz vec) + in RS.A (RG.A sh (OI.T strides (offset - blockOff) vec')) + | otherwise = RS.fromVector sh (f (RS.toVector arr)) diff --git a/src/Data/Array/XArray.hs b/src/Data/Array/XArray.hs index d8d564e..7f78420 100644 --- a/src/Data/Array/XArray.hs +++ b/src/Data/Array/XArray.hs @@ -31,11 +31,11 @@ import Foreign.Storable (Storable) import GHC.Generics (Generic) import GHC.TypeLits -import Data.Array.Arith import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation import Data.Array.Mixed.Types import Data.Array.Nested.Mixed.Shape +import Data.Array.Strided.Orthotope type XArray :: [Maybe Nat] -> Type -> Type -- cgit v1.2.3-70-g09d2