diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 |
commit | 83692cf41f76272423445c9cbbad65561ee3b50c (patch) | |
tree | 49f56f498a68722a7302b4ce0b41402a9b9da9ef /src/ForwardAD/DualNumbers.hs | |
parent | 58d4d0b47f5e609e21132f48b727de37d06b6777 (diff) |
WIP custom derivatives
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 48 |
1 files changed, 9 insertions, 39 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 8e84378..3587378 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -3,6 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -19,46 +20,9 @@ module ForwardAD.DualNumbers ( import AST import Data +import ForwardAD.DualNumbers.Types --- | Dual-numbers transformation -type family DN t where - DN TNil = TNil - DN (TPair a b) = TPair (DN a) (DN b) - DN (TEither a b) = TEither (DN a) (DN b) - DN (TMaybe t) = TMaybe (DN t) - DN (TArr n t) = TArr n (DN t) - DN (TScal t) = DNS t - -type family DNS t where - DNS TF32 = TPair (TScal TF32) (TScal TF32) - DNS TF64 = TPair (TScal TF64) (TScal TF64) - DNS TI32 = TScal TI32 - DNS TI64 = TScal TI64 - DNS TBool = TScal TBool - -type family DNE env where - DNE '[] = '[] - DNE (t : ts) = DN t : DNE ts - -dn :: STy t -> STy (DN t) -dn STNil = STNil -dn (STPair a b) = STPair (dn a) (dn b) -dn (STEither a b) = STEither (dn a) (dn b) -dn (STMaybe t) = STMaybe (dn t) -dn (STArr n t) = STArr n (dn t) -dn (STScal t) = case t of - STF32 -> STPair (STScal STF32) (STScal STF32) - STF64 -> STPair (STScal STF64) (STScal STF64) - STI32 -> STScal STI32 - STI64 -> STScal STI64 - STBool -> STScal STBool -dn STAccum{} = error "Accum in source program" - -dne :: SList STy env -> SList STy (DNE env) -dne SNil = SNil -dne (t `SCons` env) = dn t `SCons` dne env - dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) dnPreservesTupIx SZ = Refl dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl @@ -177,8 +141,14 @@ dfwdDN = \case , Refl <- dnPreservesTupIx n -> EIdx ext (dfwdDN a) (dfwdDN b) EShape _ e - | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) + | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) + -> EShape ext (dfwdDN e) EOp _ op e -> dop op (dfwdDN e) + ECustom _ s t _ du _ e1 e2 -> + -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code. + ELet ext (_ e1) $ + ELet ext (weakenExpr WSink (dfwdDN e2)) $ + weakenExpr (WCopy (WCopy WClosed)) du EError t s -> EError (dn t) s EWith{} -> err_accum |