diff options
Diffstat (limited to 'src/Data/Array/Nested/Internal/Mixed.hs')
-rw-r--r-- | src/Data/Array/Nested/Internal/Mixed.hs | 75 |
1 files changed, 38 insertions, 37 deletions
diff --git a/src/Data/Array/Nested/Internal/Mixed.hs b/src/Data/Array/Nested/Internal/Mixed.hs index 80d581e..eb452dd 100644 --- a/src/Data/Array/Nested/Internal/Mixed.hs +++ b/src/Data/Array/Nested/Internal/Mixed.hs @@ -49,6 +49,7 @@ import Data.Array.Mixed.Shape import Data.Array.Mixed.Types import Data.Array.Mixed.Permutation import Data.Array.Mixed.Lemmas +import Data.Array.Strided.Arith -- TODO: -- sumAllPrim :: (PrimElt a, NumElt a) => Mixed sh a -> a @@ -225,52 +226,52 @@ mliftNumElt2 f (toPrimitive -> M_Primitive sh1 (XArray arr1)) (toPrimitive -> M_ | otherwise = error $ "Data.Array.Nested: Shapes unequal in elementwise Num operation: " ++ show sh1 ++ " vs " ++ show sh2 instance (NumElt a, PrimElt a) => Num (Mixed sh a) where - (+) = mliftNumElt2 numEltAdd - (-) = mliftNumElt2 numEltSub - (*) = mliftNumElt2 numEltMul - negate = mliftNumElt1 numEltNeg - abs = mliftNumElt1 numEltAbs - signum = mliftNumElt1 numEltSignum + (+) = mliftNumElt2 (liftO2 . numEltAdd) + (-) = mliftNumElt2 (liftO2 . numEltSub) + (*) = mliftNumElt2 (liftO2 . numEltMul) + negate = mliftNumElt1 (liftO1 . numEltNeg) + abs = mliftNumElt1 (liftO1 . numEltAbs) + signum = mliftNumElt1 (liftO1 . numEltSignum) -- TODO: THIS IS BAD, WE NEED TO REMOVE THIS fromInteger = error "Data.Array.Nested.fromInteger: Cannot implement fromInteger, use mreplicateScal" instance (FloatElt a, PrimElt a) => Fractional (Mixed sh a) where fromRational _ = error "Data.Array.Nested.fromRational: No singletons available, use explicit mreplicate" - recip = mliftNumElt1 floatEltRecip - (/) = mliftNumElt2 floatEltDiv + recip = mliftNumElt1 (liftO1 . floatEltRecip) + (/) = mliftNumElt2 (liftO2 . floatEltDiv) instance (FloatElt a, PrimElt a) => Floating (Mixed sh a) where pi = error "Data.Array.Nested.pi: No singletons available, use explicit mreplicate" - exp = mliftNumElt1 floatEltExp - log = mliftNumElt1 floatEltLog - sqrt = mliftNumElt1 floatEltSqrt - - (**) = mliftNumElt2 floatEltPow - logBase = mliftNumElt2 floatEltLogbase - - sin = mliftNumElt1 floatEltSin - cos = mliftNumElt1 floatEltCos - tan = mliftNumElt1 floatEltTan - asin = mliftNumElt1 floatEltAsin - acos = mliftNumElt1 floatEltAcos - atan = mliftNumElt1 floatEltAtan - sinh = mliftNumElt1 floatEltSinh - cosh = mliftNumElt1 floatEltCosh - tanh = mliftNumElt1 floatEltTanh - asinh = mliftNumElt1 floatEltAsinh - acosh = mliftNumElt1 floatEltAcosh - atanh = mliftNumElt1 floatEltAtanh - log1p = mliftNumElt1 floatEltLog1p - expm1 = mliftNumElt1 floatEltExpm1 - log1pexp = mliftNumElt1 floatEltLog1pexp - log1mexp = mliftNumElt1 floatEltLog1mexp + exp = mliftNumElt1 (liftO1 . floatEltExp) + log = mliftNumElt1 (liftO1 . floatEltLog) + sqrt = mliftNumElt1 (liftO1 . floatEltSqrt) + + (**) = mliftNumElt2 (liftO2 . floatEltPow) + logBase = mliftNumElt2 (liftO2 . floatEltLogbase) + + sin = mliftNumElt1 (liftO1 . floatEltSin) + cos = mliftNumElt1 (liftO1 . floatEltCos) + tan = mliftNumElt1 (liftO1 . floatEltTan) + asin = mliftNumElt1 (liftO1 . floatEltAsin) + acos = mliftNumElt1 (liftO1 . floatEltAcos) + atan = mliftNumElt1 (liftO1 . floatEltAtan) + sinh = mliftNumElt1 (liftO1 . floatEltSinh) + cosh = mliftNumElt1 (liftO1 . floatEltCosh) + tanh = mliftNumElt1 (liftO1 . floatEltTanh) + asinh = mliftNumElt1 (liftO1 . floatEltAsinh) + acosh = mliftNumElt1 (liftO1 . floatEltAcosh) + atanh = mliftNumElt1 (liftO1 . floatEltAtanh) + log1p = mliftNumElt1 (liftO1 . floatEltLog1p) + expm1 = mliftNumElt1 (liftO1 . floatEltExpm1) + log1pexp = mliftNumElt1 (liftO1 . floatEltLog1pexp) + log1mexp = mliftNumElt1 (liftO1 . floatEltLog1mexp) mquotArray, mremArray :: (IntElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -mquotArray = mliftNumElt2 intEltQuot -mremArray = mliftNumElt2 intEltRem +mquotArray = mliftNumElt2 (liftO2 . intEltQuot) +mremArray = mliftNumElt2 (liftO2 . intEltRem) matan2Array :: (FloatElt a, PrimElt a) => Mixed sh a -> Mixed sh a -> Mixed sh a -matan2Array = mliftNumElt2 floatEltAtan2 +matan2Array = mliftNumElt2 (liftO2 . floatEltAtan2) -- | Allowable element types in a mixed array, and by extension in a 'Ranked' or @@ -867,12 +868,12 @@ miota sn = fromPrimitive $ M_Primitive (SKnown sn :$% ZSX) (X.iota sn) -- | Throws if the array is empty. mminIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mminIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) arr) + ixxFromList (ssxFromShape sh) (numEltMinIndex (shxRank sh) (fromO arr)) -- | Throws if the array is empty. mmaxIndexPrim :: (PrimElt a, NumElt a) => Mixed sh a -> IIxX sh mmaxIndexPrim (toPrimitive -> M_Primitive sh (XArray arr)) = - ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) arr) + ixxFromList (ssxFromShape sh) (numEltMaxIndex (shxRank sh) (fromO arr)) mdot1Inner :: forall sh n a. (PrimElt a, NumElt a) => Proxy n -> Mixed (sh ++ '[n]) a -> Mixed (sh ++ '[n]) a -> Mixed sh a @@ -883,7 +884,7 @@ mdot1Inner _ (toPrimitive -> M_Primitive sh1 (XArray a)) (toPrimitive -> M_Primi _ :$% _ | sh1 == sh2 , Refl <- lemRankApp (ssxInit (ssxFromShape sh1)) (ssxLast (ssxFromShape sh1) :!% ZKX) -> - fromPrimitive $ M_Primitive (shxInit sh1) (XArray (numEltDotprodInner (shxRank (shxInit sh1)) a b)) + fromPrimitive $ M_Primitive (shxInit sh1) (XArray (liftO2 (numEltDotprodInner (shxRank (shxInit sh1))) a b)) | otherwise -> error "mdot1Inner: Unequal shapes" ZSX -> error "unreachable" |