summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs22
1 files changed, 15 insertions, 7 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index e77dbe7..dda434c 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -545,6 +545,8 @@ d1op (OLt t) e = EOp ext (OLt t) e
d1op (OLe t) e = EOp ext (OLe t) e
d1op (OEq t) e = EOp ext (OEq t) e
d1op ONot e = EOp ext ONot e
+d1op OAnd e = EOp ext OAnd e
+d1op OOr e = EOp ext OOr e
d1op OIf e = EOp ext OIf e
d1op ORound64 e = EOp ext ORound64 e
d1op OToFl64 e = EOp ext OToFl64 e
@@ -564,6 +566,8 @@ d2op op = case op of
OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
ONot -> Linear $ \_ -> ENil ext
+ OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
OIf -> Linear $ \_ -> ENil ext
ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
OToFl64 -> Linear $ \_ -> ENil ext
@@ -1078,15 +1082,19 @@ drev des = \case
| Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
, STArr n eltty <- typeOf e
- , Refl <- indexTupD1Id n ->
+ , Refl <- indexTupD1Id n
+ , let tIxN = tTup (sreplicate n tIx) ->
Ret (binds `BPush` (STArr n (d1 eltty), e1)
- `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ)))
- (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ))
- (weakenExpr (WSink .> WSink) ei1))
+ `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ))
+ `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
sub
- (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ))
- (EVar ext (d2 eltty) (IS IZ))) $
- weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
+ (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) $
+ ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ)))))
+ (EVar ext (d2 eltty) (IS (IS IZ)))
+ (EZero eltty)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2)
EShape _ e
-- Allowed to ignore e2 here because the output of EShape is discrete,