summaryrefslogtreecommitdiff
path: root/src/CHAD
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD')
-rw-r--r--src/CHAD/Accum.hs4
-rw-r--r--src/CHAD/Top.hs3
-rw-r--r--src/CHAD/Types.hs44
3 files changed, 31 insertions, 20 deletions
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index b61b5ff..d8a71b5 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -10,9 +10,9 @@ import Data
makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators SNil e = e
-makeAccumulators (t `SCons` envpro) e =
+makeAccumulators (t `SCons` envpro) e | Refl <- lemZeroInfoD2 t =
makeAccumulators envpro $
- EWith ext t (EZero ext t) e
+ EWith ext (d2M t) (EZero ext (d2M t) (ENil ext)) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 2c01178..9e7e7f5 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -53,6 +53,7 @@ d1Identity = \case
STArr _ t | Refl <- d1Identity t -> Refl
STScal _ -> Refl
STAccum{} -> error "Accumulators not allowed in input program"
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
d1eIdentity :: SList STy env -> D1E env :~: env
d1eIdentity SNil = Refl
@@ -72,7 +73,7 @@ reassembleD2E (des `DPush` (_, _, SMerge)) e =
EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
(EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
(ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t)
+reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t)
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 7f49cef..74e7dbd 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -14,14 +14,16 @@ type family D1 t where
D1 (TMaybe a) = TMaybe (D1 a)
D1 (TArr n t) = TArr n (D1 t)
D1 (TScal t) = TScal t
+ D1 (TLEither a b) = TLEither (D1 a) (D1 b)
type family D2 t where
D2 TNil = TNil
D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b))
- D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b))
+ D2 (TEither a b) = TLEither (D2 a) (D2 b)
D2 (TMaybe t) = TMaybe (D2 t)
D2 (TArr n t) = TMaybe (TArr n (D2 t))
D2 (TScal t) = D2s t
+ D2 (TLEither a b) = TLEither (D2 a) (D2 b)
type family D2s t where
D2s TI32 = TNil
@@ -40,7 +42,7 @@ type family D2E env where
type family D2AcE env where
D2AcE '[] = '[]
- D2AcE (t : env) = TAccum t : D2AcE env
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
@@ -50,32 +52,40 @@ d1 (STMaybe t) = STMaybe (d1 t)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
d1 STAccum{} = error "Accumulators not allowed in input program"
+d1 (STLEither a b) = STLEither (d1 a) (d1 b)
d1e :: SList STy env -> SList STy (D1E env)
d1e SNil = SNil
d1e (t `SCons` env) = d1 t `SCons` d1e env
+d2M :: STy t -> SMTy (D2 t)
+d2M STNil = SMTNil
+d2M (STPair a b) = SMTMaybe (SMTPair (d2M a) (d2M b))
+d2M (STEither a b) = SMTLEither (d2M a) (d2M b)
+d2M (STMaybe t) = SMTMaybe (d2M t)
+d2M (STArr n t) = SMTMaybe (SMTArr n (d2M t))
+d2M (STScal t) = case t of
+ STI32 -> SMTNil
+ STI64 -> SMTNil
+ STF32 -> SMTScal STF32
+ STF64 -> SMTScal STF64
+ STBool -> SMTNil
+d2M STAccum{} = error "Accumulators not allowed in input program"
+d2M (STLEither a b) = SMTLEither (d2M a) (d2M b)
+
d2 :: STy t -> STy (D2 t)
-d2 STNil = STNil
-d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b))
-d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b))
-d2 (STMaybe t) = STMaybe (d2 t)
-d2 (STArr n t) = STMaybe (STArr n (d2 t))
-d2 (STScal t) = case t of
- STI32 -> STNil
- STI64 -> STNil
- STF32 -> STScal STF32
- STF64 -> STScal STF64
- STBool -> STNil
-d2 STAccum{} = error "Accumulators not allowed in input program"
+d2 = fromSMTy . d2M
+
+d2eM :: SList STy env -> SList SMTy (D2E env)
+d2eM SNil = SNil
+d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts
d2e :: SList STy env -> SList STy (D2E env)
-d2e SNil = SNil
-d2e (t `SCons` ts) = d2 t `SCons` d2e ts
+d2e = slistMap fromSMTy . d2eM
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
-d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts
+d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts
data CHADConfig = CHADConfig