summaryrefslogtreecommitdiff
path: root/src/ForwardAD/DualNumbers
diff options
context:
space:
mode:
Diffstat (limited to 'src/ForwardAD/DualNumbers')
-rw-r--r--src/ForwardAD/DualNumbers/Types.hs46
1 files changed, 46 insertions, 0 deletions
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
new file mode 100644
index 0000000..fba92d0
--- /dev/null
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -0,0 +1,46 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module ForwardAD.DualNumbers.Types where
+
+import AST.Types
+import Data
+
+
+-- | Dual-numbers transformation
+type family DN t where
+ DN TNil = TNil
+ DN (TPair a b) = TPair (DN a) (DN b)
+ DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TMaybe t) = TMaybe (DN t)
+ DN (TArr n t) = TArr n (DN t)
+ DN (TScal t) = DNS t
+
+type family DNS t where
+ DNS TF32 = TPair (TScal TF32) (TScal TF32)
+ DNS TF64 = TPair (TScal TF64) (TScal TF64)
+ DNS TI32 = TScal TI32
+ DNS TI64 = TScal TI64
+ DNS TBool = TScal TBool
+
+type family DNE env where
+ DNE '[] = '[]
+ DNE (t : ts) = DN t : DNE ts
+
+dn :: STy t -> STy (DN t)
+dn STNil = STNil
+dn (STPair a b) = STPair (dn a) (dn b)
+dn (STEither a b) = STEither (dn a) (dn b)
+dn (STMaybe t) = STMaybe (dn t)
+dn (STArr n t) = STArr n (dn t)
+dn (STScal t) = case t of
+ STF32 -> STPair (STScal STF32) (STScal STF32)
+ STF64 -> STPair (STScal STF64) (STScal STF64)
+ STI32 -> STScal STI32
+ STI64 -> STScal STI64
+ STBool -> STScal STBool
+dn STAccum{} = error "Accum in source program"
+
+dne :: SList STy env -> SList STy (DNE env)
+dne SNil = SNil
+dne (t `SCons` env) = dn t `SCons` dne env