diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index e77dbe7..dda434c 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -545,6 +545,8 @@ d1op (OLt t) e = EOp ext (OLt t) e d1op (OLe t) e = EOp ext (OLe t) e d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e d1op OIf e = EOp ext OIf e d1op ORound64 e = EOp ext ORound64 e d1op OToFl64 e = EOp ext OToFl64 e @@ -564,6 +566,8 @@ d2op op = case op of OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) OIf -> Linear $ \_ -> ENil ext ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 OToFl64 -> Linear $ \_ -> ENil ext @@ -1078,15 +1082,19 @@ drev des = \case | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil , STArr n eltty <- typeOf e - , Refl <- indexTupD1Id n -> + , Refl <- indexTupD1Id n + , let tIxN = tTup (sreplicate n tIx) -> Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ)) - (weakenExpr (WSink .> WSink) ei1)) + `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) + `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) - (EVar ext (d2 eltty) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) $ + ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ))))) + (EVar ext (d2 eltty) (IS (IS IZ))) + (EZero eltty)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, |