summaryrefslogtreecommitdiff
path: root/src/ForwardAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-09-28 13:01:08 +0200
committerTom Smeding <tom@tomsmeding.com>2024-09-28 13:01:08 +0200
commit1f13bc80915a26473e0622c4afa65c8276b396ff (patch)
treee4c3b0b266cb73db2b17eb602baa40855bbe347d /src/ForwardAD.hs
parenta87b1a8eb1f659fc15060b6215eb9a6706dfdccd (diff)
Dual-numbers forward AD
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r--src/ForwardAD.hs202
1 files changed, 202 insertions, 0 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
new file mode 100644
index 0000000..0a9e12c
--- /dev/null
+++ b/src/ForwardAD.hs
@@ -0,0 +1,202 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module ForwardAD (
+ dfwd,
+ FD, FDS, FDE, fd,
+) where
+
+import AST
+import Data
+
+
+-- | Dual-numbers transformation
+type family FD t where
+ FD TNil = TNil
+ FD (TPair a b) = TPair (FD a) (FD b)
+ FD (TEither a b) = TEither (FD a) (FD b)
+ FD (TMaybe t) = TMaybe (FD t)
+ FD (TArr n t) = TArr n (FD t)
+ FD (TScal t) = FDS t
+
+type family FDS t where
+ FDS TF32 = TPair (TScal TF32) (TScal TF32)
+ FDS TF64 = TPair (TScal TF64) (TScal TF64)
+ FDS TI32 = TScal TI32
+ FDS TI64 = TScal TI64
+ FDS TBool = TScal TBool
+
+type family FDE env where
+ FDE '[] = '[]
+ FDE (t : ts) = FD t : FDE ts
+
+fd :: STy t -> STy (FD t)
+fd STNil = STNil
+fd (STPair a b) = STPair (fd a) (fd b)
+fd (STEither a b) = STEither (fd a) (fd b)
+fd (STMaybe t) = STMaybe (fd t)
+fd (STArr n t) = STArr n (fd t)
+fd (STScal t) = case t of
+ STF32 -> STPair (STScal STF32) (STScal STF32)
+ STF64 -> STPair (STScal STF64) (STScal STF64)
+ STI32 -> STScal STI32
+ STI64 -> STScal STI64
+ STBool -> STScal STBool
+fd STAccum{} = error "Accum in source program"
+
+fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
+fdPreservesTupIx SZ = Refl
+fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl
+
+convIdx :: Idx env t -> Idx (FDE env) (FD t)
+convIdx IZ = IZ
+convIdx (IS i) = IS (convIdx i)
+
+scalTyCase :: SScalTy t
+ -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r)
+ -> (FD (TScal t) ~ TScal t => r)
+ -> r
+scalTyCase STF32 k1 _ = k1
+scalTyCase STF64 k1 _ = k1
+scalTyCase STI32 _ k2 = k2
+scalTyCase STI64 _ k2 = k2
+scalTyCase STBool _ k2 = k2
+
+-- | Argument does not need to be duplicable.
+dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b)
+dop = \case
+ OAdd t -> scalTyCase t
+ (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
+ (EOp ext (OAdd t))
+ OMul t -> scalTyCase t
+ (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
+ (EOp ext (OMul t))
+ ONeg t -> scalTyCase t
+ (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
+ (EOp ext (ONeg t))
+ OLt t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
+ (EOp ext (OLt t))
+ OLe t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
+ (EOp ext (OLe t))
+ OEq t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
+ (EOp ext (OEq t))
+ ONot -> EOp ext ONot
+ OIf -> EOp ext OIf
+ where
+ add :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
+ add t a b = EOp ext (OAdd t) (EPair ext a b)
+
+ mul :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
+ mul t a b = EOp ext (OMul t) (EPair ext a b)
+
+ neg :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
+ neg t = EOp ext (ONeg t)
+
+ unFloat :: FD a ~ TPair a a
+ => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b))
+ -> Ex env (FD a) -> Ex env (FD b)
+ unFloat f e =
+ ELet ext e $
+ let var = EVar ext (typeOf e) IZ
+ in f (EFst ext var, ESnd ext var)
+
+ binFloat :: (a ~ TPair s s, FD s ~ TPair s s)
+ => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b))
+ -> Ex env (FD a) -> Ex env (FD b)
+ binFloat f e =
+ ELet ext e $
+ let var = EVar ext (typeOf e) IZ
+ in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
+ (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
+
+dfwd :: Ex env t -> Ex (FDE env) (FD t)
+dfwd = \case
+ EVar _ t i -> EVar ext (fd t) (convIdx i)
+ ELet _ a b -> ELet ext (dfwd a) (dfwd b)
+ EPair _ a b -> EPair ext (dfwd a) (dfwd b)
+ EFst _ e -> EFst ext (dfwd e)
+ ESnd _ e -> ESnd ext (dfwd e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext (fd t) (dfwd e)
+ EInr _ t e -> EInr ext (fd t) (dfwd e)
+ ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b)
+ ENothing _ t -> ENothing ext (fd t)
+ EJust _ e -> EJust ext (dfwd e)
+ EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b)
+ EConstArr _ n t x -> scalTyCase t
+ (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
+ (EConstArr ext n t x))
+ (EConstArr ext n t x)
+ EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b)
+ EBuild _ n a b
+ | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b)
+ EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b)
+ ESum1Inner _ e ->
+ let STArr n (STScal t) = typeOf e
+ pairty = (STPair (STScal t) (STScal t))
+ in scalTyCase t
+ (ELet ext (dfwd e) $
+ ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
+ (EVar ext (STArr n pairty) IZ)))
+ (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
+ (EVar ext (STArr n pairty) IZ))))
+ (ESum1Inner ext (dfwd e))
+ EUnit _ e -> EUnit ext (dfwd e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b)
+ EConst _ t x -> scalTyCase t
+ (EPair ext (EConst ext t x) (EConst ext t 0.0))
+ (EConst ext t x)
+ EIdx0 _ e -> EIdx0 ext (dfwd e)
+ EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b)
+ EIdx _ n a b
+ | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b)
+ EShape _ e
+ | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e)
+ EOp _ op e -> dop op (dfwd e)
+ EError t s -> EError (fd t) s
+
+ EWith{} -> err_accum
+ EAccum{} -> err_accum
+ EZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ where
+ err_accum = error "Accumulator operations unsupported in the source program"
+ err_monoid = error "Monoid operations unsupported in the source program"
+
+emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
+emap f arr =
+ let STArr n t = typeOf arr
+ in ELet ext arr $
+ EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
+ ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) f
+
+ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
+ezip a b =
+ let STArr n t1 = typeOf a
+ STArr _ t2 = typeOf b
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
+ EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (EIdx ext n (EVar ext (STArr n t2) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))