summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-10 10:04:27 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-10 10:04:27 +0100
commit42d59947356ab51e5a4070b930f02f4909208d35 (patch)
tree3c8afab888e61c4e3157a257f0a40ae2fd4eb9c1
parent33e0ed21603cbd85d6aba6548811db27480647db (diff)
Complete GMM implementation
-rw-r--r--bench/Bench/GMM.hs14
-rw-r--r--src/AST.hs2
-rw-r--r--src/AST/Pretty.hs1
-rw-r--r--src/AST/Types.hs7
-rw-r--r--src/CHAD.hs7
-rw-r--r--src/Example.hs1
-rw-r--r--src/ForwardAD/DualNumbers.hs6
-rw-r--r--src/Interpreter.hs7
-rw-r--r--src/Language.hs7
-rw-r--r--test/Main.hs1
10 files changed, 45 insertions, 8 deletions
diff --git a/bench/Bench/GMM.hs b/bench/Bench/GMM.hs
index ebbbe1e..9b84d23 100644
--- a/bench/Bench/GMM.hs
+++ b/bench/Bench/GMM.hs
@@ -3,8 +3,6 @@
{-# LANGUAGE TypeApplications #-}
module Bench.GMM where
-import AST
-import Data
import Language
@@ -31,7 +29,7 @@ type TMat = TArr (S (S Z))
-- <https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435651>
-- <https://github.com/microsoft/ADBench>
-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”.
--- Master thesis at Utrecht University.
+-- Master thesis at Utrecht University. (Appendix B.1)
-- <https://studenttheses.uu.nl/bitstream/handle/20.500.12932/38958/report.pdf?sequence=1&isAllowed=y>
-- <https://tomsmeding.com/f/master.pdf>
objective :: Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
@@ -93,7 +91,15 @@ objective = fromNamed $
normsq v = inline normsq' (SNil .$ v)
qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $
- _
+ let_ #n (snd_ (shape #q)) $
+ build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :->
+ let_ #i (snd_ (fst_ #idx)) $
+ let_ #j (snd_ #idx) $
+ if_ (#i .== #j)
+ (exp (#q ! pair nil #i))
+ (if_ (#i .> #j)
+ (toFloat_ $ #i * (#i - 1) `idiv` 2 + 1 + #j)
+ 0.0)
qmat q l = inline qmat' (SNil .$ q .$ l)
in - #k1
+ idx0 (sum1i (build1 #N $ #i :->
diff --git a/src/AST.hs b/src/AST.hs
index 263b806..e7dde90 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -172,6 +172,7 @@ data SOp a t where
ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
deriving instance Show (SOp a t)
opt2 :: SOp a t -> STy t
@@ -191,6 +192,7 @@ opt2 = \case
ORecip t -> STScal t
OExp t -> STScal t
OLog t -> STScal t
+ OIDiv t -> STScal t
typeOf :: Expr x env t -> STy t
typeOf = \case
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 63742ad..51d89dc 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -278,6 +278,7 @@ operator OToFl64 = (Prefix, "toFl64")
operator ORecip{} = (Prefix, "recip")
operator OExp{} = (Prefix, "exp")
operator OLog{} = (Prefix, "log")
+operator OIDiv{} = (Infix, "`div`")
ppTy :: Int -> STy t -> String
ppTy d ty = ppTys d ty ""
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index 5688277..adcc760 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -100,3 +100,10 @@ type family ScalIsFloating t where
ScalIsFloating TF32 = True
ScalIsFloating TF64 = True
ScalIsFloating TBool = False
+
+type family ScalIsIntegral t where
+ ScalIsIntegral TI32 = True
+ ScalIsIntegral TI64 = True
+ ScalIsIntegral TF32 = False
+ ScalIsIntegral TF64 = False
+ ScalIsIntegral TBool = False
diff --git a/src/CHAD.hs b/src/CHAD.hs
index fb6f5e3..59d61a7 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -262,6 +262,7 @@ d1op OToFl64 e = EOp ext OToFl64 e
d1op (ORecip t) e = EOp ext (ORecip t) e
d1op (OExp t) e = EOp ext (OExp t) e
d1op (OLog t) e = EOp ext (OLog t) e
+d1op (OIDiv t) e = EOp ext (OIDiv t) e
-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
@@ -286,6 +287,7 @@ d2op op = case op of
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
@@ -312,6 +314,11 @@ d2op op = case op of
floatingD2 STF32 k = k
floatingD2 STF64 k = k
+ integralD2 :: ScalIsIntegral a ~ True
+ => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r
+ integralD2 STI32 k = k
+ integralD2 STI64 k = k
+
sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
diff --git a/src/Example.hs b/src/Example.hs
index 697c4d9..b208963 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -12,7 +12,6 @@ import AST
import AST.Pretty
import CHAD
import CHAD.Top
-import Data
import ForwardAD
import Interpreter
import Language
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 9ed04bb..0a08926 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
@@ -32,7 +33,7 @@ convIdx IZ = IZ
convIdx (IS i) = IS (convIdx i)
scalTyCase :: SScalTy t
- -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
+ -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
-> (DN (TScal t) ~ TScal t => r)
-> r
scalTyCase STF32 k1 _ = k1
@@ -82,6 +83,9 @@ dop = \case
OLog t -> floatingDual t $ unFloat (\(x, dx) ->
EPair ext (EOp ext (OLog t) x)
(mul t (recip' t x) dx))
+ OIDiv t -> scalTyCase t
+ (case t of {})
+ (EOp ext (OIDiv t))
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 576b0d9..37d4a83 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -173,6 +173,7 @@ interpretOp op arg = case op of
ORecip st -> floatingIsFractional st $ recip arg
OExp st -> floatingIsFractional st $ exp arg
OLog st -> floatingIsFractional st $ log arg
+ OIDiv st -> integralIsIntegral st $ uncurry div arg
where
styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r
styIsEq STI32 = id
@@ -526,10 +527,14 @@ numericIsNum STI64 = id
numericIsNum STF32 = id
numericIsNum STF64 = id
-floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True) => r) -> r
+floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r
floatingIsFractional STF32 = id
floatingIsFractional STF64 = id
+integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r
+integralIsIntegral STI32 = id
+integralIsIntegral STI64 = id
+
unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
-> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
unTupRepIdx nil _ SZ _ = nil
diff --git a/src/Language.hs b/src/Language.hs
index 7aceee7..a7737e0 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -8,12 +8,16 @@
module Language (
fromNamed,
NExpr,
+ Ex,
module Language,
+ module AST.Types,
+ module Data,
Lookup,
) where
import Array
import AST
+import AST.Types
import CHAD.Types
import Data
import Language.AST
@@ -191,3 +195,6 @@ round_ = oper ORound64
toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64)
toFloat_ = oper OToFl64
+
+idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t)
+idiv = oper2 (OIDiv knownScalTy)
diff --git a/test/Main.hs b/test/Main.hs
index 3a598c0..2573a32 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -21,7 +21,6 @@ import AST
import AST.Pretty
import CHAD.Top
import CHAD.Types
-import Data
import qualified Example
import ForwardAD
import Interpreter