aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Count.hs75
-rw-r--r--src/AST/Pretty.hs29
-rw-r--r--src/AST/SplitLets.hs6
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/AST/Weaken.hs2
-rw-r--r--src/AST/Weaken/Auto.hs44
6 files changed, 113 insertions, 45 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 296c021..a53822d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of
SsNone ->
occCountX SsFull a $ \env1 mka ->
occCountX SsNone b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
use (EBuild ext n (mka env') $
use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
@@ -535,31 +534,49 @@ occCountX initialS topexpr k = case topexpr of
SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX s' b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
EBuild ext n (mka env') $
elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+ EMap _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $
+ ENil ext
+ SsArr' s' ->
+ occCountX s' a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ EMap ext (mka (OccPush env' () s1)) (mkb env')
+
EFold1Inner _ commut a b c ->
- occCountX SsFull a $ \env1''' mka ->
- withSome (scaleMany (Some env1''')) $ \env1'' ->
- occEnvPop' env1'' $ \env1' s2 ->
- occEnvPop' env1' $ \env1 s1 ->
- let s0 = case s of
+ occCountX SsFull a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1' ->
+ let s1 = case s1' of
+ SsNone -> Some SsNone
+ SsPair' s1'a s1'b -> Some s1'a <> Some s1'b
+ s0 = case s of
SsNone -> Some SsNone
SsArr' s' -> Some s' in
- withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
+ withSome (s1 <> s0) $ \sElt ->
occCountX sElt b $ \env2 mkb ->
- occCountX (SsArr sElt) c $ \env3 mkc ->
- withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsArr sElt) s $
EFold1Inner ext commut
(projectSmallerSubstruc SsFull sElt $
- mka (OccPush (OccPush env' () sElt) () sElt))
+ mka (OccPush env' () (SsPair sElt sElt)))
(mkb env') (mkc env')
ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
@@ -608,6 +625,27 @@ occCountX initialS topexpr k = case topexpr of
k env $ \env' ->
EReshape ext n (mkesh env') (mke env')
+ EZip _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $ mka env'
+ SsArr' (SsPair' s1 s2) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EZip ext (mka env') (mkb env')
+
EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
-- If nothing is necessary, we can execute a fold and then proceed to ignore it
@@ -628,7 +666,7 @@ occCountX initialS topexpr k = case topexpr of
elet (mapExt (\_ -> ext) e3) $
EPair ext
(EShape ext (evar IZ))
- (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1)))
(mapExt (\_ -> ext) (weakenExpr WSink e2))
(evar IZ))
in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
@@ -638,15 +676,14 @@ occCountX initialS topexpr k = case topexpr of
-- If at least some of the additional stores are required, we need to keep this a mapAccum
SsPair' _ (SsArr' sB) ->
-- TODO: propagate usage of primals
- occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
- occEnvPop' env1_2' $ \env1_1' _ ->
+ occCountX (SsPair SsFull sB) e1 $ \env1_1' mka ->
occEnvPop' env1_1' $ \env1' _ ->
occCountX SsFull e2 $ \env2 mkb ->
occCountX SsFull e3 $ \env3 mkc ->
withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
- EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ EFold1InnerD1 ext cm (mka (OccPush env' () SsFull))
(mkb env') (mkc env')
EFold1InnerD2 _ cm ef ebog ed ->
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 68fc629..ecdaa88 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -204,15 +204,22 @@ ppExpr' d val expr = case expr of
<> hardline <> e')
(ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
+ EMap _ a b -> do
+ let STArr _ t1 = typeOf b
+ name <- genNameIfUsedIn' "i" t1 IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 0) $
+ ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
+
EFold1Inner _ cm a b c -> do
- name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
- name2 <- genNameIfUsedIn (typeOf a) IZ a
- a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
let opname = "fold1i" ++ ppCommut cm
return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
ESum1Inner _ e -> do
e' <- ppExpr' 11 val e
@@ -238,17 +245,21 @@ ppExpr' d val expr = case expr of
EReshape _ n esh e -> do
esh' <- ppExpr' 11 val esh
e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e'
+ return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e']
+
+ EZip _ e1 e2 -> do
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
EFold1InnerD1 _ cm a b c -> do
- name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
- name2 <- genNameIfUsedIn (typeOf b) IZ a
- a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
let opname = "fold1iD1" ++ ppCommut cm
return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
EFold1InnerD2 _ cm ef ebog ed -> do
let STArr _ tB = typeOf ebog
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index d276e44..267dd87 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -34,10 +34,10 @@ splitLets' = \sub -> \case
in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
- in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
EFold1InnerD1 x cm a b c ->
let STArr _ t1 = typeOf c
- in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
EFold1InnerD2 x cm a b c ->
let STArr _ tB = typeOf b
STArr _ t2 = typeOf c
@@ -56,12 +56,14 @@ splitLets' = \sub -> \case
ELInr x t e -> ELInr x t (splitLets' sub e)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
+ EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b)
ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
EUnit x e -> EUnit x (splitLets' sub e)
EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
+ EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (splitLets' sub e)
EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index a22b73f..1712ba5 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -38,6 +38,7 @@ unMonoid = \case
ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
EUnit _ e -> EUnit ext (unMonoid e)
@@ -45,6 +46,7 @@ unMonoid = \case
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 3a97fd1..f0820b8 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -129,7 +129,7 @@ wCopies bs w =
let bs' = slistMap (\_ -> Const ()) bs
in WStack bs' bs' WId w
-wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
+wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env
wRaiseAbove SNil _ = WClosed
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
index c6efe37..7370df1 100644
--- a/src/AST/Weaken/Auto.hs
+++ b/src/AST/Weaken/Auto.hs
@@ -11,6 +11,7 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
@@ -23,6 +24,7 @@ module AST.Weaken.Auto (
) where
import Data.Functor.Const
+import Data.Kind (Constraint)
import GHC.OverloadedLabels
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
@@ -39,18 +41,21 @@ type family Lookup name list where
-- | The @withPre@ type parameter indicates whether there can be 'LPreW'
--- occurrences within this layout.
-data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where
- LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments)
+-- occurrences within this layout. 'names' is the list of names that this
+-- layout /produces/. That is: for LPreW, it contains the target name. The
+-- 'names' list of a source layout must be a subset of the names list of the
+-- target layout (which cannot contain LPreW); this is checked with SubLayout.
+data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where
+ LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments)
-- | Pre-weaken with a weakening
LPreW :: forall name1 name2 segments.
SegmentName name1 -> SegmentName name2
-> Lookup name1 segments :> Lookup name2 segments
- -> Layout True segments (Lookup name1 segments)
- (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2)
+ -> Layout True segments '[name2] (Lookup name1 segments)
+ (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2)
infixr :++:
-instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where
+instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where
fromLabel = LSeg (symbolSing @name)
newtype SegmentName name = SegmentName (SSymbol name)
@@ -60,6 +65,18 @@ instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') wh
fromLabel = SegmentName symbolSing
+type family SubLayout names1 names2 where
+ SubLayout '[] _ = () :: Constraint
+ SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2
+type family SubLayout' n ok names1 names2 where
+ SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.")
+ SubLayout' _ True names1 names2 = SubLayout names1 names2
+type family Contains n names where
+ Contains _ '[] = False
+ Contains n (n : _) = True
+ Contains n (_ : names) = Contains n names
+
+
data SSegments (segments :: [(Symbol, [t])]) where
SSegNil :: SSegments '[]
SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
@@ -74,7 +91,7 @@ auto1 :: SList (Const ()) '[t]
auto1 = Const () `SCons` SNil
infixr &.
-(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2)
+(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2)
(&.) = ssegmentsAppend
where
ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
@@ -118,12 +135,12 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout
| Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2
= LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2)
-lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env
-lineariseLayout (LSeg name :: Layout _ _ seg)
+lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env
+lineariseLayout (LSeg name :: Layout _ _ _ seg)
| Refl <- lemAppendNil @seg
= LinApp name LinEnd
lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2
-lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg)
+lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg)
| Refl <- lemAppendNil @seg
= LinAppPreW name1 name2 w LinEnd
@@ -151,8 +168,7 @@ pullDown segs name@SSymbol linlayout kNotFound k =
k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name)
.> wCopies (segmentLookup segs n') w)
-sortLinLayouts :: forall segments env1 env2.
- SSegments segments
+sortLinLayouts :: SSegments segments
-> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2
sortLinLayouts _ LinEnd LinEnd = WId
sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)
@@ -169,8 +185,8 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail
sortLinLayouts _ LinEnd LinApp{} = WClosed
sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target"
-autoWeak :: forall segments env1 env2.
- SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2
+autoWeak :: SubLayout names1 names2
+ => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2
autoWeak segs ly1 ly2 =
preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 ->
sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak