aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Internal/Mixed.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r--src/Data/Array/Nested/Internal/Mixed.hs75
1 files changed, 38 insertions, 37 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs
index 80d581e..eb452dd 100644
--- a/src/Data/Array/Nested/Internal/Mixed.hs
+++ b/src/Data/Array/Nested/Internal/Mixed.hs
@@ -49,6 +49,7 @@ import Data.Array.Mixed.Shape
import Data.Array.Mixed.Types
import Data.Array.Mixed.Permutation
import Data.Array.Mixed.Lemmas
+import Data.Array.Strided.Arith
-- TODO:
-- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a
@@ -225,52 +226,52 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_
| otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2
instance (NumElt a, PrimElt a) => Num (Mixed sh a) where
- (+) = mliftNumElt2 numEltAdd
- (-) = mliftNumElt2 numEltSub
- (*) = mliftNumElt2 numEltMul
- negate = mliftNumElt1 numEltNeg
- abs = mliftNumElt1 numEltAbs
- signum = mliftNumElt1 numEltSignum
+ (+) = mliftNumElt2 (liftO2 . numEltAdd)
+ (-) = mliftNumElt2 (liftO2 . numEltSub)
+ (*) = mliftNumElt2 (liftO2 . numEltMul)
+ negate = mliftNumElt1 (liftO1 . numEltNeg)
+ abs = mliftNumElt1 (liftO1 . numEltAbs)
+ signum = mliftNumElt1 (liftO1 . numEltSignum)
-- TODO: THIS IS BAD, WE NEED TO REMOVE THIS
fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal"
instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where
fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate"
- recip = mliftNumElt1 floatEltRecip
- (/) = mliftNumElt2 floatEltDiv
+ recip = mliftNumElt1 (liftO1 . floatEltRecip)
+ (/) = mliftNumElt2 (liftO2 . floatEltDiv)
instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where
pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate"
- exp = mliftNumElt1 floatEltExp
- log = mliftNumElt1 floatEltLog
- sqrt = mliftNumElt1 floatEltSqrt
-
- (**) = mliftNumElt2 floatEltPow
- logBase = mliftNumElt2 floatEltLogbase
-
- sin = mliftNumElt1 floatEltSin
- cos = mliftNumElt1 floatEltCos
- tan = mliftNumElt1 floatEltTan
- asin = mliftNumElt1 floatEltAsin
- acos = mliftNumElt1 floatEltAcos
- atan = mliftNumElt1 floatEltAtan
- sinh = mliftNumElt1 floatEltSinh
- cosh = mliftNumElt1 floatEltCosh
- tanh = mliftNumElt1 floatEltTanh
- asinh = mliftNumElt1 floatEltAsinh
- acosh = mliftNumElt1 floatEltAcosh
- atanh = mliftNumElt1 floatEltAtanh
- log1p = mliftNumElt1 floatEltLog1p
- expm1 = mliftNumElt1 floatEltExpm1
- log1pexp = mliftNumElt1 floatEltLog1pexp
- log1mexp = mliftNumElt1 floatEltLog1mexp
+ exp = mliftNumElt1 (liftO1 . floatEltExp)
+ log = mliftNumElt1 (liftO1 . floatEltLog)
+ sqrt = mliftNumElt1 (liftO1 . floatEltSqrt)
+
+ (**) = mliftNumElt2 (liftO2 . floatEltPow)
+ logBase = mliftNumElt2 (liftO2 . floatEltLogbase)
+
+ sin = mliftNumElt1 (liftO1 . floatEltSin)
+ cos = mliftNumElt1 (liftO1 . floatEltCos)
+ tan = mliftNumElt1 (liftO1 . floatEltTan)
+ asin = mliftNumElt1 (liftO1 . floatEltAsin)
+ acos = mliftNumElt1 (liftO1 . floatEltAcos)
+ atan = mliftNumElt1 (liftO1 . floatEltAtan)
+ sinh = mliftNumElt1 (liftO1 . floatEltSinh)
+ cosh = mliftNumElt1 (liftO1 . floatEltCosh)
+ tanh = mliftNumElt1 (liftO1 . floatEltTanh)
+ asinh = mliftNumElt1 (liftO1 . floatEltAsinh)
+ acosh = mliftNumElt1 (liftO1 . floatEltAcosh)
+ atanh = mliftNumElt1 (liftO1 . floatEltAtanh)
+ log1p = mliftNumElt1 (liftO1 . floatEltLog1p)
+ expm1 = mliftNumElt1 (liftO1 . floatEltExpm1)
+ log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp)
+ log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp)
mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-mquotArray = mliftNumElt2 intEltQuot
-mremArray = mliftNumElt2 intEltRem
+mquotArray = mliftNumElt2 (liftO2 . intEltQuot)
+mremArray = mliftNumElt2 (liftO2 . intEltRem)
matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a
-matan2Array = mliftNumElt2 floatEltAtan2
+matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2)
-- | Allowable element types in a mixed array, and by extension in a 'Ranked' or
@@ -867,12 +868,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn)
-- | Throws if the array is empty.
mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr)
+ ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr))
-- | Throws if the array is empty.
mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh
mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) =
- ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr)
+ ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr))
mdot1Inner :: forall sh n a. (PrimElt a, NumElt a)
=> Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a
@@ -883,7 +884,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi
_ :$% _
| sh1 == sh2
, Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) ->
- fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b))
+ fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b))
| otherwise -> error "mdot1Inner: Unequal shapes"
ZSX -> error "unreachable"