blob: fba92d0075cf755e276dc81a886a9a5ec30496a4 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module ForwardAD.DualNumbers.Types where
import AST.Types
import Data
-- | Dual-numbers transformation on types.
--
-- Maps a type to its forward-mode dual-number counterpart. The structural
-- type formers (pairs, sums, maybes, arrays) are transformed componentwise,
-- and the array index @n@ is kept as-is. Only scalar leaves actually change;
-- that part is delegated to 'DNS'.
--
-- There is deliberately no equation for accumulator types: the term-level
-- 'dn' raises an error on 'STAccum' instead.
type family DN t where
  DN (TScal s) = DNS s
  DN TNil = TNil
  DN (TPair x y) = TPair (DN x) (DN y)
  DN (TEither l r) = TEither (DN l) (DN r)
  DN (TMaybe e) = TMaybe (DN e)
  DN (TArr n e) = TArr n (DN e)
-- | Dual-numbers transformation on scalar types.
--
-- Integral and boolean scalars are left unchanged. Floating-point scalars
-- become a pair of two values of the same type (presumably primal and
-- tangent; confirm against the dual-numbers term transformation).
type family DNS s where
  DNS TI32 = TScal TI32
  DNS TI64 = TScal TI64
  DNS TBool = TScal TBool
  DNS TF32 = TPair (TScal TF32) (TScal TF32)
  DNS TF64 = TPair (TScal TF64) (TScal TF64)
-- | Applies 'DN' to every element of a type-level list of types.
type family DNE env where
  DNE '[] = '[]
  DNE (x ': xs) = DN x ': DNE xs
-- | Term-level counterpart of 'DN'. It computes the type singleton for
-- @DN t@ from the singleton for @t@.
--
-- Calls 'error' on 'STAccum', for which 'DN' has no equation.
dn :: STy t -> STy (DN t)
-- Floating-point scalars are doubled into a pair of the same scalar type.
dn (STScal STF32) = STPair (STScal STF32) (STScal STF32)
dn (STScal STF64) = STPair (STScal STF64) (STScal STF64)
-- Integral and boolean scalars are returned unchanged.
dn (STScal STI32) = STScal STI32
dn (STScal STI64) = STScal STI64
dn (STScal STBool) = STScal STBool
-- Structural types recurse into their components.
dn STNil = STNil
dn (STPair x y) = STPair (dn x) (dn y)
dn (STEither l r) = STEither (dn l) (dn r)
dn (STMaybe e) = STMaybe (dn e)
dn (STArr n e) = STArr n (dn e)
dn STAccum{} = error "Accum in source program"
-- | Term-level counterpart of 'DNE'. It applies 'dn' to each type singleton
-- in the list.
dne :: SList STy env -> SList STy (DNE env)
dne SNil = SNil
dne (SCons ty rest) = SCons (dn ty) (dne rest)
|