summaryrefslogtreecommitdiff
path: root/src/ForwardAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD')
-rw-r--r--src/ForwardAD/DualNumbers.hs48
-rw-r--r--src/ForwardAD/DualNumbers/Types.hs46
2 files changed, 55 insertions, 39 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 8e84378..3587378 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
@@ -19,46 +20,9 @@ module ForwardAD.DualNumbers (
import AST
import Data
+import ForwardAD.DualNumbers.Types
--- | Dual-numbers transformation
-type family DN t where
- DN TNil = TNil
- DN (TPair a b) = TPair (DN a) (DN b)
- DN (TEither a b) = TEither (DN a) (DN b)
- DN (TMaybe t) = TMaybe (DN t)
- DN (TArr n t) = TArr n (DN t)
- DN (TScal t) = DNS t
-
-type family DNS t where
- DNS TF32 = TPair (TScal TF32) (TScal TF32)
- DNS TF64 = TPair (TScal TF64) (TScal TF64)
- DNS TI32 = TScal TI32
- DNS TI64 = TScal TI64
- DNS TBool = TScal TBool
-
-type family DNE env where
- DNE '[] = '[]
- DNE (t : ts) = DN t : DNE ts
-
-dn :: STy t -> STy (DN t)
-dn STNil = STNil
-dn (STPair a b) = STPair (dn a) (dn b)
-dn (STEither a b) = STEither (dn a) (dn b)
-dn (STMaybe t) = STMaybe (dn t)
-dn (STArr n t) = STArr n (dn t)
-dn (STScal t) = case t of
- STF32 -> STPair (STScal STF32) (STScal STF32)
- STF64 -> STPair (STScal STF64) (STScal STF64)
- STI32 -> STScal STI32
- STI64 -> STScal STI64
- STBool -> STScal STBool
-dn STAccum{} = error "Accum in source program"
-
-dne :: SList STy env -> SList STy (DNE env)
-dne SNil = SNil
-dne (t `SCons` env) = dn t `SCons` dne env
-
dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
dnPreservesTupIx SZ = Refl
dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl
@@ -177,8 +141,14 @@ dfwdDN = \case
, Refl <- dnPreservesTupIx n
-> EIdx ext (dfwdDN a) (dfwdDN b)
EShape _ e
- | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e)
+ | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
+ -> EShape ext (dfwdDN e)
EOp _ op e -> dop op (dfwdDN e)
+ ECustom _ s t _ du _ e1 e2 ->
+ -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code.
+ ELet ext (_ e1) $
+ ELet ext (weakenExpr WSink (dfwdDN e2)) $
+ weakenExpr (WCopy (WCopy WClosed)) du
EError t s -> EError (dn t) s
EWith{} -> err_accum
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
new file mode 100644
index 0000000..fba92d0
--- /dev/null
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -0,0 +1,46 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module ForwardAD.DualNumbers.Types where
+
+import AST.Types
+import Data
+
+
+-- | Dual-numbers transformation
+type family DN t where
+ DN TNil = TNil
+ DN (TPair a b) = TPair (DN a) (DN b)
+ DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TMaybe t) = TMaybe (DN t)
+ DN (TArr n t) = TArr n (DN t)
+ DN (TScal t) = DNS t
+
+type family DNS t where
+ DNS TF32 = TPair (TScal TF32) (TScal TF32)
+ DNS TF64 = TPair (TScal TF64) (TScal TF64)
+ DNS TI32 = TScal TI32
+ DNS TI64 = TScal TI64
+ DNS TBool = TScal TBool
+
+type family DNE env where
+ DNE '[] = '[]
+ DNE (t : ts) = DN t : DNE ts
+
+dn :: STy t -> STy (DN t)
+dn STNil = STNil
+dn (STPair a b) = STPair (dn a) (dn b)
+dn (STEither a b) = STEither (dn a) (dn b)
+dn (STMaybe t) = STMaybe (dn t)
+dn (STArr n t) = STArr n (dn t)
+dn (STScal t) = case t of
+ STF32 -> STPair (STScal STF32) (STScal STF32)
+ STF64 -> STPair (STScal STF64) (STScal STF64)
+ STI32 -> STScal STI32
+ STI64 -> STScal STI64
+ STBool -> STScal STBool
+dn STAccum{} = error "Accum in source program"
+
+dne :: SList STy env -> SList STy (DNE env)
+dne SNil = SNil
+dne (t `SCons` env) = dn t `SCons` dne env