{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Type-level and singleton-level description of the dual-numbers
-- transformation on types: floating-point scalars are replaced by a pair of
-- scalars of the same precision, and every other type former is mapped
-- structurally.
module ForwardAD.DualNumbers.Types where

import AST.Types
import Data

-- | Dual-numbers transformation on types.
--
-- Pairs, sums, 'TMaybe' and arrays are transformed componentwise (the @n@
-- parameter of 'TArr' is left unchanged); scalars are delegated to 'DNS'.
-- Only the type formers listed here have an equation; the term-level 'dn'
-- rejects 'STAccum' at runtime.
type family DN t where
  DN TNil          = TNil
  DN (TPair a b)   = TPair (DN a) (DN b)
  DN (TEither a b) = TEither (DN a) (DN b)
  DN (TMaybe t)    = TMaybe (DN t)
  DN (TArr n t)    = TArr n (DN t)
  DN (TScal t)     = DNS t

-- | Dual-numbers transformation on scalar types.
--
-- 'TF32' and 'TF64' become a pair of two scalars of the same precision
-- (NOTE(review): presumably value and tangent; the component order is not
-- fixed by this module). 'TI32', 'TI64' and 'TBool' are left unchanged.
type family DNS t where
  DNS TF32  = TPair (TScal TF32) (TScal TF32)
  DNS TF64  = TPair (TScal TF64) (TScal TF64)
  DNS TI32  = TScal TI32
  DNS TI64  = TScal TI64
  DNS TBool = TScal TBool

-- | 'DN' applied to every type of a type-level environment (list of types).
type family DNE env where
  DNE '[]       = '[]
  DNE (t ': ts) = DN t ': DNE ts

-- | Singleton counterpart of 'DN': given the runtime representation of a type,
-- build the representation of its dual-numbers-transformed type.
--
-- Calls 'error' on an 'STAccum' argument, because accumulator types are not
-- expected to appear in a source program.
dn :: STy t -> STy (DN t)
dn STNil          = STNil
dn (STPair l r)   = STPair (dn l) (dn r)
dn (STEither l r) = STEither (dn l) (dn r)
dn (STMaybe elt)  = STMaybe (dn elt)
dn (STArr n elt)  = STArr n (dn elt)
-- Scalar cases: one clause per equation of 'DNS', in the same order.
dn (STScal STF32)  = STPair (STScal STF32) (STScal STF32)
dn (STScal STF64)  = STPair (STScal STF64) (STScal STF64)
dn (STScal STI32)  = STScal STI32
dn (STScal STI64)  = STScal STI64
dn (STScal STBool) = STScal STBool
dn STAccum{}       = error "Accum in source program"

-- | Singleton counterpart of 'DNE': apply 'dn' to each entry of an
-- environment, preserving its order and length.
dne :: SList STy env -> SList STy (DNE env)
dne env = case env of
  SNil        -> SNil
  SCons t rest -> SCons (dn t) (dne rest)