diff options
Diffstat (limited to 'src/AST')
-rw-r--r-- | src/AST/Count.hs | 1 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 17 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 6 | ||||
-rw-r--r-- | src/AST/Weaken/Auto.hs | 7 |
4 files changed, 22 insertions, 9 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 71b38b1..d365218 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -125,6 +125,7 @@ occCountGeneral onehot unpush alter many = go WId EIdx _ a b -> re a <> re b EShape _ e -> re e EOp _ _ e -> re e + ECustom _ _ _ _ _ _ a b -> re a <> re b EWith a b -> re a <> re1 b EAccum _ a b e -> re a <> re b <> re e EZero _ -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 76424fe..bac267d 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -14,7 +14,9 @@ import Data.Functor.Const import AST import AST.Count +import CHAD.Types import Data +import ForwardAD.DualNumbers.Types type SVal = SList (Const String) @@ -175,6 +177,21 @@ ppExpr' d val = \case (Prefix, s) -> s return $ showParen (d > 10) $ showString (ops ++ " ") . e' + ECustom _ t1 t2 a b c e1 e2 -> do + pn1 <- genNameIfUsedIn t1 (IS IZ) a + pn2 <- genNameIfUsedIn t2 IZ a + fn1 <- genNameIfUsedIn t1 (IS IZ) b + fn2 <- genNameIfUsedIn (dn t2) IZ b + rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c + rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c + rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c + a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a + b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b + c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2' + EWith e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index ecd7bc9..dbb37f7 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -42,7 +42,7 @@ data env :> env' where WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env') WPop :: (t : env) :> env' -> env :> env' WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 - WClosed :: SList (Const ()) env -> '[] :> env + WClosed :: '[] :> env WIdx :: Idx env t -> (t : env) :> env WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env' -> Append pre (t : env) :> t : Append pre env' @@ -62,7 +62,7 @@ WCopy _ @> IZ = IZ WCopy w @> IS i = IS (w @> i) WPop w @> i = w @> IS i WThen w1 w2 @> i = w2 @> w1 @> i -WClosed _ @> i = case i of {} +WClosed @> i = case i of {} WIdx j @> IZ = j WIdx _ @> IS i = i WPick SNil w @> i = WCopy w @> i @@ -115,5 +115,5 @@ wCopies bs w = in WStack bs' bs' WId w wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env -wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) +wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 8555516..6752c24 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -118,11 +118,6 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -linLayoutEnv :: SSegments segments -> LinLayout withPre segments env -> SList (Const ()) env -linLayoutEnv _ LinEnd = SNil -linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) -linLayoutEnv segs (LinAppPreW name1 _ _ lin) = sappend (segmentLookup segs name1) (linLayoutEnv segs lin) - lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env lineariseLayout (LSeg name :: Layout _ _ seg) | Refl <- lemAppendNil @seg @@ -171,7 +166,7 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and -- wCopies the name2 on top of that. wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) -sortLinLayouts segs LinEnd lin2@LinApp{} = WClosed (linLayoutEnv segs lin2) +sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" autoWeak :: forall segments env1 env2. |