diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 10:04:27 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 10:04:27 +0100 |
commit | 42d59947356ab51e5a4070b930f02f4909208d35 (patch) | |
tree | 3c8afab888e61c4e3157a257f0a40ae2fd4eb9c1 | |
parent | 33e0ed21603cbd85d6aba6548811db27480647db (diff) |
Complete GMM implementation
-rw-r--r-- | bench/Bench/GMM.hs | 14 | ||||
-rw-r--r-- | src/AST.hs | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 1 | ||||
-rw-r--r-- | src/AST/Types.hs | 7 | ||||
-rw-r--r-- | src/CHAD.hs | 7 | ||||
-rw-r--r-- | src/Example.hs | 1 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 6 | ||||
-rw-r--r-- | src/Interpreter.hs | 7 | ||||
-rw-r--r-- | src/Language.hs | 7 | ||||
-rw-r--r-- | test/Main.hs | 1 |
10 files changed, 45 insertions, 8 deletions
diff --git a/bench/Bench/GMM.hs b/bench/Bench/GMM.hs index ebbbe1e..9b84d23 100644 --- a/bench/Bench/GMM.hs +++ b/bench/Bench/GMM.hs @@ -3,8 +3,6 @@ {-# LANGUAGE TypeApplications #-} module Bench.GMM where -import AST -import Data import Language @@ -31,7 +29,7 @@ type TMat = TArr (S (S Z)) -- <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651> -- <https://github.com/microsoft/ADBench> -- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. --- Master thesis at Utrecht University. +-- Master thesis at Utrecht University. (Appendix B.1) -- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y> -- <https://tomsmeding.com/f/master.pdf> objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R @@ -93,7 +91,15 @@ objective = fromNamed $ normsq v = inline normsq' (SNil .$ v) qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ - _ + let_ #n (snd_ (shape #q)) $ + build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> + let_ #i (snd_ (fst_ #idx)) $ + let_ #j (snd_ #idx) $ + if_ (#i .== #j) + (exp (#q ! pair nil #i)) + (if_ (#i .> #j) + (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j) + 0.0) qmat q l = inline qmat' (SNil .$ q .$ l) in - #k1 + idx0 (sum1i (build1 #N $ #i :-> @@ -172,6 +172,7 @@ data SOp a t where ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) deriving instance Show (SOp a t) opt2 :: SOp a t -> STy t @@ -191,6 +192,7 @@ opt2 = \case ORecip t -> STScal t OExp t -> STScal t OLog t -> STScal t + OIDiv t -> STScal t typeOf :: Expr x env t -> STy t typeOf = \case diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 63742ad..51d89dc 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -278,6 +278,7 @@ operator OToFl64 = (Prefix, "toFl64") operator ORecip{} = (Prefix, "recip") operator OExp{} = (Prefix, "exp") operator OLog{} = (Prefix, "log") +operator OIDiv{} = (Infix, "`div`") ppTy :: Int -> STy t -> String ppTy d ty = ppTys d ty "" diff --git a/src/AST/Types.hs b/src/AST/Types.hs index 5688277..adcc760 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -100,3 +100,10 @@ type family ScalIsFloating t where ScalIsFloating TF32 = True ScalIsFloating TF64 = True ScalIsFloating TBool = False + +type family ScalIsIntegral t where + ScalIsIntegral TI32 = True + ScalIsIntegral TI64 = True + ScalIsIntegral TF32 = False + ScalIsIntegral TF64 = False + ScalIsIntegral TBool = False diff --git a/src/CHAD.hs b/src/CHAD.hs index fb6f5e3..59d61a7 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -262,6 +262,7 @@ d1op OToFl64 e = EOp ext OToFl64 e d1op (ORecip t) e = EOp ext (ORecip t) e d1op (OExp t) e = EOp ext (OExp t) e d1op (OLog t) e = EOp ext (OLog t) e +d1op (OIDiv t) e = EOp ext (OIDiv t) e -- | Both primal and dual must be duplicable expressions data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) @@ -286,6 +287,7 @@ d2op op = case op of ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) + OIDiv t -> integralD2 t $ Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) where d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) @@ -312,6 +314,11 @@ d2op op = case op of floatingD2 STF32 k = k floatingD2 STF64 k = k + integralD2 :: ScalIsIntegral a ~ True + => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r + integralD2 STI32 k = k + integralD2 STI64 k = k + sD1eEnv :: Descr env sto -> SList STy (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) diff --git a/src/Example.hs b/src/Example.hs index 697c4d9..b208963 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -12,7 +12,6 @@ import AST import AST.Pretty import CHAD import CHAD.Top -import Data import ForwardAD import Interpreter import Language diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 9ed04bb..0a08926 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} @@ -32,7 +33,7 @@ convIdx IZ = IZ convIdx (IS i) = IS (convIdx i) scalTyCase :: SScalTy t - -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) + -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) -> (DN (TScal t) ~ TScal t => r) -> r scalTyCase STF32 k1 _ = k1 @@ -82,6 +83,9 @@ dop = \case OLog t -> floatingDual t $ unFloat (\(x, dx) -> EPair ext (EOp ext (OLog t) x) (mul t (recip' t x) dx)) + OIDiv t -> scalTyCase t + (case t of {}) + (EOp ext (OIDiv t)) where add :: ScalIsNumeric t ~ True => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 576b0d9..37d4a83 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -173,6 +173,7 @@ interpretOp op arg = case op of ORecip st -> floatingIsFractional st $ recip arg OExp st -> floatingIsFractional st $ exp arg OLog st -> floatingIsFractional st $ log arg + OIDiv st -> integralIsIntegral st $ uncurry div arg where styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r styIsEq STI32 = id @@ -526,10 +527,14 @@ numericIsNum STI64 = id numericIsNum STF32 = id numericIsNum STF64 = id -floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True) => r) -> r +floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r floatingIsFractional STF32 = id floatingIsFractional STF64 = id +integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r +integralIsIntegral STI32 = id +integralIsIntegral STI64 = id + unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n unTupRepIdx nil _ SZ _ = nil diff --git a/src/Language.hs b/src/Language.hs index 7aceee7..a7737e0 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -8,12 +8,16 @@ module Language ( fromNamed, NExpr, + Ex, module Language, + module AST.Types, + module Data, Lookup, ) where import Array import AST +import AST.Types import CHAD.Types import Data import Language.AST @@ -191,3 +195,6 @@ round_ = oper ORound64 toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) toFloat_ = oper OToFl64 + +idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) +idiv = oper2 (OIDiv knownScalTy) diff --git a/test/Main.hs b/test/Main.hs index 3a598c0..2573a32 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -21,7 +21,6 @@ import AST import AST.Pretty import CHAD.Top import CHAD.Types -import Data import qualified Example import ForwardAD import Interpreter |