summaryrefslogtreecommitdiff
path: root/src/ForwardAD/DualNumbers.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-08 12:37:51 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-08 12:37:51 +0100
commit83692cf41f76272423445c9cbbad65561ee3b50c (patch)
tree49f56f498a68722a7302b4ce0b41402a9b9da9ef /src/ForwardAD/DualNumbers.hs
parent58d4d0b47f5e609e21132f48b727de37d06b6777 (diff)
WIP custom derivatives
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r--src/ForwardAD/DualNumbers.hs48
1 files changed, 9 insertions, 39 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 8e84378..3587378 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -19,46 +20,9 @@ module ForwardAD.DualNumbers (
import AST
import Data
+import ForwardAD.DualNumbers.Types
--- | Dual-numbers transformation
-type family DN t where
- DN TNil = TNil
- DN (TPair a b) = TPair (DN a) (DN b)
- DN (TEither a b) = TEither (DN a) (DN b)
- DN (TMaybe t) = TMaybe (DN t)
- DN (TArr n t) = TArr n (DN t)
- DN (TScal t) = DNS t
-
-type family DNS t where
- DNS TF32 = TPair (TScal TF32) (TScal TF32)
- DNS TF64 = TPair (TScal TF64) (TScal TF64)
- DNS TI32 = TScal TI32
- DNS TI64 = TScal TI64
- DNS TBool = TScal TBool
-
-type family DNE env where
- DNE '[] = '[]
- DNE (t : ts) = DN t : DNE ts
-
-dn :: STy t -> STy (DN t)
-dn STNil = STNil
-dn (STPair a b) = STPair (dn a) (dn b)
-dn (STEither a b) = STEither (dn a) (dn b)
-dn (STMaybe t) = STMaybe (dn t)
-dn (STArr n t) = STArr n (dn t)
-dn (STScal t) = case t of
- STF32 -> STPair (STScal STF32) (STScal STF32)
- STF64 -> STPair (STScal STF64) (STScal STF64)
- STI32 -> STScal STI32
- STI64 -> STScal STI64
- STBool -> STScal STBool
-dn STAccum{} = error "Accum in source program"
-
-dne :: SList STy env -> SList STy (DNE env)
-dne SNil = SNil
-dne (t `SCons` env) = dn t `SCons` dne env
-
dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
dnPreservesTupIx SZ = Refl
dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl
@@ -177,8 +141,14 @@ dfwdDN = \case
, Refl <- dnPreservesTupIx n
-> EIdx ext (dfwdDN a) (dfwdDN b)
EShape _ e
- | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e)
+ | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
+ -> EShape ext (dfwdDN e)
EOp _ op e -> dop op (dfwdDN e)
+ ECustom _ s t _ du _ e1 e2 ->
+ -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code.
+ ELet ext (_ e1) $
+ ELet ext (weakenExpr WSink (dfwdDN e2)) $
+ weakenExpr (WCopy (WCopy WClosed)) du
EError t s -> EError (dn t) s
EWith{} -> err_accum