From 83692cf41f76272423445c9cbbad65561ee3b50c Mon Sep 17 00:00:00 2001
From: Tom Smeding <tom@tomsmeding.com>
Date: Fri, 8 Nov 2024 12:37:51 +0100
Subject: WIP custom derivatives

---
 src/ForwardAD/DualNumbers/Types.hs | 46 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 src/ForwardAD/DualNumbers/Types.hs

(limited to 'src/ForwardAD/DualNumbers')

diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs
new file mode 100644
index 0000000..fba92d0
--- /dev/null
+++ b/src/ForwardAD/DualNumbers/Types.hs
@@ -0,0 +1,46 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module ForwardAD.DualNumbers.Types where
+
+import AST.Types
+import Data
+
+
+-- | Dual-numbers transformation
+type family DN t where
+  DN TNil = TNil
+  DN (TPair a b) = TPair (DN a) (DN b)
+  DN (TEither a b) = TEither (DN a) (DN b)
+  DN (TMaybe t) = TMaybe (DN t)
+  DN (TArr n t) = TArr n (DN t)
+  DN (TScal t) = DNS t
+
+type family DNS t where
+  DNS TF32 = TPair (TScal TF32) (TScal TF32)
+  DNS TF64 = TPair (TScal TF64) (TScal TF64)
+  DNS TI32 = TScal TI32
+  DNS TI64 = TScal TI64
+  DNS TBool = TScal TBool
+
+type family DNE env where
+  DNE '[] = '[]
+  DNE (t : ts) = DN t : DNE ts
+
+dn :: STy t -> STy (DN t)
+dn STNil = STNil
+dn (STPair a b) = STPair (dn a) (dn b)
+dn (STEither a b) = STEither (dn a) (dn b)
+dn (STMaybe t) = STMaybe (dn t)
+dn (STArr n t) = STArr n (dn t)
+dn (STScal t) = case t of
+  STF32 -> STPair (STScal STF32) (STScal STF32)
+  STF64 -> STPair (STScal STF64) (STScal STF64)
+  STI32 -> STScal STI32
+  STI64 -> STScal STI64
+  STBool -> STScal STBool
+dn STAccum{} = error "Accum in source program"
+
+dne :: SList STy env -> SList STy (DNE env)
+dne SNil = SNil
+dne (t `SCons` env) = dn t `SCons` dne env
-- 
cgit v1.2.3-70-g09d2