diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 20:29:29 +0100 | 
| commit | 4fcdb7118e0084f192753ea6c70394352a27d5ed (patch) | |
| tree | c5e91ae438b6f4c3e5075bf591e5fbe28aa5d96b /src/AST | |
| parent | 83692cf41f76272423445c9cbbad65561ee3b50c (diff) | |
Custom derivatives
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Count.hs | 2 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 27 | 
2 files changed, 16 insertions, 13 deletions
| diff --git a/src/AST/Count.hs b/src/AST/Count.hs index d365218..364773a 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -125,7 +125,7 @@ occCountGeneral onehot unpush alter many = go WId        EIdx _ a b -> re a <> re b        EShape _ e -> re e        EOp _ _ e -> re e -      ECustom _ _ _ _ _ _ a b -> re a <> re b +      ECustom _ _ _ _ _ _ _ a b -> re a <> re b        EWith a b -> re a <> re1 b        EAccum _ a b e -> re a <> re b <> re e        EZero _ -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index bac267d..a2232ee 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -16,7 +16,6 @@ import AST  import AST.Count  import CHAD.Types  import Data -import ForwardAD.DualNumbers.Types  type SVal = SList (Const String) @@ -177,20 +176,24 @@ ppExpr' d val = \case                  (Prefix, s) -> s      return $ showParen (d > 10) $ showString (ops ++ " ") . e' -  ECustom _ t1 t2 a b c e1 e2 -> do -    pn1 <- genNameIfUsedIn t1 (IS IZ) a -    pn2 <- genNameIfUsedIn t2 IZ a -    fn1 <- genNameIfUsedIn t1 (IS IZ) b -    fn2 <- genNameIfUsedIn (dn t2) IZ b -    rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c -    rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c -    rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c +  ECustom _ t1 t2 t3 a b c e1 e2 -> do +    en1 <- genNameIfUsedIn t1 (IS IZ) a +    en2 <- genNameIfUsedIn t2 IZ a +    pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b +    pn2 <- genNameIfUsedIn (d1 t2) IZ b +    dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c +    dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c      a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a -    b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b -    c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c +    b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b +    c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c      e1' <- ppExpr' 11 val e1      e2' <- ppExpr' 11 val e2 -    return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2' +    return $ showParen (d > 10) $ showString "custom " +      . showString ("(" ++ en1 ++ " " ++ en2 ++ ". ") . a' . showString ") " +      . showString ("(" ++ pn1 ++ " " ++ pn2 ++ ". ") . b' . showString ") " +      . showString ("(" ++ dn1 ++ " " ++ dn2 ++ ". ") . c' . showString ") " +      . e1' . showString " " +      . e2'    EWith e1 e2 -> do      e1' <- ppExpr' 11 val e1 | 
