summaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs1
-rw-r--r--src/AST/Pretty.hs17
-rw-r--r--src/AST/Weaken.hs6
-rw-r--r--src/AST/Weaken/Auto.hs7
4 files changed, 22 insertions, 9 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 71b38b1..d365218 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -125,6 +125,7 @@ occCountGeneral onehot unpush alter many = go WId
EIdx _ a b -> re a <> re b
EShape _ e -> re e
EOp _ _ e -> re e
+ ECustom _ _ _ _ _ _ a b -> re a <> re b
EWith a b -> re a <> re1 b
EAccum _ a b e -> re a <> re b <> re e
EZero _ -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 76424fe..bac267d 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -14,7 +14,9 @@ import Data.Functor.Const
import AST
import AST.Count
+import CHAD.Types
import Data
+import ForwardAD.DualNumbers.Types
type SVal = SList (Const String)
@@ -175,6 +177,21 @@ ppExpr' d val = \case
(Prefix, s) -> s
return $ showParen (d > 10) $ showString (ops ++ " ") . e'
+ ECustom _ t1 t2 a b c e1 e2 -> do
+ pn1 <- genNameIfUsedIn t1 (IS IZ) a
+ pn2 <- genNameIfUsedIn t2 IZ a
+ fn1 <- genNameIfUsedIn t1 (IS IZ) b
+ fn2 <- genNameIfUsedIn (dn t2) IZ b
+ rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c
+ rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c
+ rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c
+ a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a
+ b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b
+ c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2'
+
EWith e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index ecd7bc9..dbb37f7 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -42,7 +42,7 @@ data env :> env' where
WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env')
WPop :: (t : env) :> env' -> env :> env'
WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
- WClosed :: SList (Const ()) env -> '[] :> env
+ WClosed :: '[] :> env
WIdx :: Idx env t -> (t : env) :> env
WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env'
-> Append pre (t : env) :> t : Append pre env'
@@ -62,7 +62,7 @@ WCopy _ @> IZ = IZ
WCopy w @> IS i = IS (w @> i)
WPop w @> i = w @> IS i
WThen w1 w2 @> i = w2 @> w1 @> i
-WClosed _ @> i = case i of {}
+WClosed @> i = case i of {}
WIdx j @> IZ = j
WIdx _ @> IS i = i
WPick SNil w @> i = WCopy w @> i
@@ -115,5 +115,5 @@ wCopies bs w =
in WStack bs' bs' WId w
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
-wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
+wRaiseAbove SNil _ = WClosed
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
index 8555516..6752c24 100644
--- a/src/AST/Weaken/Auto.hs
+++ b/src/AST/Weaken/Auto.hs
@@ -118,11 +118,6 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout
| Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2
= LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2)
-linLayoutEnv :: SSegments segments -> LinLayout withPre segments env -> SList (Const ()) env
-linLayoutEnv _ LinEnd = SNil
-linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin)
-linLayoutEnv segs (LinAppPreW name1 _ _ lin) = sappend (segmentLookup segs name1) (linLayoutEnv segs lin)
-
lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env
lineariseLayout (LSeg name :: Layout _ _ seg)
| Refl <- lemAppendNil @seg
@@ -171,7 +166,7 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail
-- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and
-- wCopies the name2 on top of that.
wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w)
-sortLinLayouts segs LinEnd lin2@LinApp{} = WClosed (linLayoutEnv segs lin2)
+sortLinLayouts _ LinEnd LinApp{} = WClosed
sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target"
autoWeak :: forall segments env1 env2.