diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 |
commit | 83692cf41f76272423445c9cbbad65561ee3b50c (patch) | |
tree | 49f56f498a68722a7302b4ce0b41402a9b9da9ef /src/ForwardAD/DualNumbers | |
parent | 58d4d0b47f5e609e21132f48b727de37d06b6777 (diff) |
WIP custom derivatives
Diffstat (limited to 'src/ForwardAD/DualNumbers')
-rw-r--r-- | src/ForwardAD/DualNumbers/Types.hs | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs new file mode 100644 index 0000000..fba92d0 --- /dev/null +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -0,0 +1,46 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module ForwardAD.DualNumbers.Types where + +import AST.Types +import Data + + +-- | Dual-numbers transformation +type family DN t where + DN TNil = TNil + DN (TPair a b) = TPair (DN a) (DN b) + DN (TEither a b) = TEither (DN a) (DN b) + DN (TMaybe t) = TMaybe (DN t) + DN (TArr n t) = TArr n (DN t) + DN (TScal t) = DNS t + +type family DNS t where + DNS TF32 = TPair (TScal TF32) (TScal TF32) + DNS TF64 = TPair (TScal TF64) (TScal TF64) + DNS TI32 = TScal TI32 + DNS TI64 = TScal TI64 + DNS TBool = TScal TBool + +type family DNE env where + DNE '[] = '[] + DNE (t : ts) = DN t : DNE ts + +dn :: STy t -> STy (DN t) +dn STNil = STNil +dn (STPair a b) = STPair (dn a) (dn b) +dn (STEither a b) = STEither (dn a) (dn b) +dn (STMaybe t) = STMaybe (dn t) +dn (STArr n t) = STArr n (dn t) +dn (STScal t) = case t of + STF32 -> STPair (STScal STF32) (STScal STF32) + STF64 -> STPair (STScal STF64) (STScal STF64) + STI32 -> STScal STI32 + STI64 -> STScal STI64 + STBool -> STScal STBool +dn STAccum{} = error "Accum in source program" + +dne :: SList STy env -> SList STy (DNE env) +dne SNil = SNil +dne (t `SCons` env) = dn t `SCons` dne env |