diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:01:24 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-20 13:01:24 +0100 | 
| commit | 55036a5ea4a6e590d0404638b2823c6a4aec3fba (patch) | |
| tree | 484bc377229d3edff36bd9a2a80f999bbcd2e889 /src/Data/Array/Nested/Internal | |
| parent | 5414434df62b2b196354b9748b265093c168601b (diff) | |
Separate arith routines into a library
The point is that this separate library does not depend on orthotope.
Diffstat (limited to 'src/Data/Array/Nested/Internal')
| -rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 71 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Ranked.hs | 2 | ||||
| -rw-r--r-- | src/Data/Array/Nested/Internal/Shaped.hs | 2 | 
3 files changed, 38 insertions, 37 deletions
| diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 80d581e..eb452dd 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -49,6 +49,7 @@ import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Mixed.Permutation  import Data.Array.Mixed.Lemmas +import Data.Array.Strided.Arith  -- TODO:  --   sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a @@ -225,52 +226,52 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_    | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2  instance (NumElt a, PrimElt a) => Num (Mixed sh a) where -  (+) = mliftNumElt2 numEltAdd -  (-) = mliftNumElt2 numEltSub -  (*) = mliftNumElt2 numEltMul -  negate = mliftNumElt1 numEltNeg -  abs = mliftNumElt1 numEltAbs -  signum = mliftNumElt1 numEltSignum +  (+) = mliftNumElt2 (liftO2 . numEltAdd) +  (-) = mliftNumElt2 (liftO2 . numEltSub) +  (*) = mliftNumElt2 (liftO2 . numEltMul) +  negate = mliftNumElt1 (liftO1 . numEltNeg) +  abs = mliftNumElt1 (liftO1 . numEltAbs) +  signum = mliftNumElt1 (liftO1 . numEltSignum)    -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS    fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"  instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where    fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" -  recip = mliftNumElt1 floatEltRecip -  (/) = mliftNumElt2 floatEltDiv +  recip = mliftNumElt1 (liftO1 . floatEltRecip) +  (/) = mliftNumElt2 (liftO2 . floatEltDiv)  instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where    pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" -  exp = mliftNumElt1 floatEltExp -  log = mliftNumElt1 floatEltLog -  sqrt = mliftNumElt1 floatEltSqrt +  exp = mliftNumElt1 (liftO1 . floatEltExp) +  log = mliftNumElt1 (liftO1 . floatEltLog) +  sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) -  (**) = mliftNumElt2 floatEltPow -  logBase = mliftNumElt2 floatEltLogbase +  (**) = mliftNumElt2 (liftO2 . floatEltPow) +  logBase = mliftNumElt2 (liftO2 . floatEltLogbase) -  sin = mliftNumElt1 floatEltSin -  cos = mliftNumElt1 floatEltCos -  tan = mliftNumElt1 floatEltTan -  asin = mliftNumElt1 floatEltAsin -  acos = mliftNumElt1 floatEltAcos -  atan = mliftNumElt1 floatEltAtan -  sinh = mliftNumElt1 floatEltSinh -  cosh = mliftNumElt1 floatEltCosh -  tanh = mliftNumElt1 floatEltTanh -  asinh = mliftNumElt1 floatEltAsinh -  acosh = mliftNumElt1 floatEltAcosh -  atanh = mliftNumElt1 floatEltAtanh -  log1p = mliftNumElt1 floatEltLog1p -  expm1 = mliftNumElt1 floatEltExpm1 -  log1pexp = mliftNumElt1 floatEltLog1pexp -  log1mexp = mliftNumElt1 floatEltLog1mexp +  sin = mliftNumElt1 (liftO1 . floatEltSin) +  cos = mliftNumElt1 (liftO1 . floatEltCos) +  tan = mliftNumElt1 (liftO1 . floatEltTan) +  asin = mliftNumElt1 (liftO1 . floatEltAsin) +  acos = mliftNumElt1 (liftO1 . floatEltAcos) +  atan = mliftNumElt1 (liftO1 . floatEltAtan) +  sinh = mliftNumElt1 (liftO1 . floatEltSinh) +  cosh = mliftNumElt1 (liftO1 . floatEltCosh) +  tanh = mliftNumElt1 (liftO1 . floatEltTanh) +  asinh = mliftNumElt1 (liftO1 . floatEltAsinh) +  acosh = mliftNumElt1 (liftO1 . floatEltAcosh) +  atanh = mliftNumElt1 (liftO1 . floatEltAtanh) +  log1p = mliftNumElt1 (liftO1 . floatEltLog1p) +  expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) +  log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) +  log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp)  mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -mquotArray = mliftNumElt2 intEltQuot -mremArray = mliftNumElt2 intEltRem +mquotArray = mliftNumElt2 (liftO2 . intEltQuot) +mremArray = mliftNumElt2 (liftO2 . intEltRem)  matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -matan2Array = mliftNumElt2 floatEltAtan2 +matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)  -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or @@ -867,12 +868,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)  -- | Throws if the array is empty.  mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh  mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = -  ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr) +  ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))  -- | Throws if the array is empty.  mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh  mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = -  ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr) +  ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))  mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)             => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a @@ -883,7 +884,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi        _ :$% _          | sh1 == sh2          , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> -            fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b)) +            fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))          | otherwise -> error "mdot1Inner: Unequal shapes"        ZSX -> error "unreachable" diff --git a/src/Data/Array/Nested/Internal/Ranked.hs b/src/Data/Array/Nested/Internal/Ranked.hs index 1c6b789..0a165bc 100644 --- a/src/Data/Array/Nested/Internal/Ranked.hs +++ b/src/Data/Array/Nested/Internal/Ranked.hs @@ -41,13 +41,13 @@ import GHC.TypeNats qualified as TN  import Data.Array.Mixed.XArray (XArray(..))  import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Internal.Arith  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation  import Data.Array.Mixed.Shape  import Data.Array.Mixed.Types  import Data.Array.Nested.Internal.Mixed  import Data.Array.Nested.Internal.Shape +import Data.Array.Strided.Arith  -- | A rank-typed array: the number of dimensions of the array (its /rank/) is diff --git a/src/Data/Array/Nested/Internal/Shaped.hs b/src/Data/Array/Nested/Internal/Shaped.hs index 35628db..d7a8ece 100644 --- a/src/Data/Array/Nested/Internal/Shaped.hs +++ b/src/Data/Array/Nested/Internal/Shaped.hs @@ -41,7 +41,6 @@ import GHC.TypeLits  import Data.Array.Mixed.XArray (XArray)  import Data.Array.Mixed.XArray qualified as X -import Data.Array.Mixed.Internal.Arith  import Data.Array.Mixed.Lemmas  import Data.Array.Mixed.Permutation  import Data.Array.Mixed.Shape @@ -49,6 +48,7 @@ import Data.Array.Mixed.Types  import Data.Array.Nested.Internal.Lemmas  import Data.Array.Nested.Internal.Mixed  import Data.Array.Nested.Internal.Shape +import Data.Array.Strided.Arith  -- | A shape-typed array: the full shape of the array (the sizes of its | 
