diff options
Diffstat (limited to 'src/AST')
| -rw-r--r-- | src/AST/Count.hs | 75 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 29 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 6 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 2 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 2 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 44 |
6 files changed, 113 insertions, 45 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 296c021..a53822d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of SsNone -> occCountX SsFull a $ \env1 mka -> occCountX SsNone b $ \env2'' mkb -> - withSome (scaleMany (Some env2'')) $ \env2' -> - occEnvPop' env2' $ \env2 s2 -> - withSome (Some env1 <> Some env2) $ \env -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> k env $ \env' -> use (EBuild ext n (mka env') $ use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ @@ -535,31 +534,49 @@ occCountX initialS topexpr k = case topexpr of SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX s' b $ \env2'' mkb -> - withSome (scaleMany (Some env2'')) $ \env2' -> - occEnvPop' env2' $ \env2 s2 -> - withSome (Some env1 <> Some env2) $ \env -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> k env $ \env' -> EBuild ext n (mka env') $ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1''' mka -> - withSome (scaleMany (Some env1''')) $ \env1'' -> - occEnvPop' env1'' $ \env1' s2 -> - occEnvPop' env1' $ \env1 s1 -> - let s0 = case s of + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of SsNone -> Some SsNone SsArr' s' -> Some s' in - withSome (Some s1 <> Some s2 <> s0) $ \sElt -> + withSome (s1 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsArr sElt) s $ EFold1Inner ext commut (projectSmallerSubstruc SsFull sElt $ - mka (OccPush (OccPush env' () sElt) () sElt)) + mka (OccPush env' () (SsPair sElt sElt))) (mkb env') (mkc env') ESum1Inner _ e -> handleReduction (ESum1Inner ext) e @@ -608,6 +625,27 @@ occCountX initialS topexpr k = case topexpr of k env $ \env' -> EReshape ext n (mkesh env') (mke env') + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + EFold1InnerD1 _ cm e1 e2 e3 -> case s of -- If nothing is necessary, we can execute a fold and then proceed to ignore it @@ -628,7 +666,7 @@ occCountX initialS topexpr k = case topexpr of elet (mapExt (\_ -> ext) e3) $ EPair ext (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) (mapExt (\_ -> ext) (weakenExpr WSink e2)) (evar IZ)) in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> @@ -638,15 +676,14 @@ occCountX initialS topexpr k = case topexpr of -- If at least some of the additional stores are required, we need to keep this a mapAccum SsPair' _ (SsArr' sB) -> -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> - occEnvPop' env1_2' $ \env1_1' _ -> + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> occEnvPop' env1_1' $ \env1' _ -> occCountX SsFull e2 $ \env2 mkb -> occCountX SsFull e3 $ \env3 mkc -> withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) (mkb env') (mkc env') EFold1InnerD2 _ cm ef ebog ed -> diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 68fc629..ecdaa88 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -204,15 +204,22 @@ ppExpr' d val expr = case expr of <> hardline <> e') (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn' "i" t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] + EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -238,17 +245,21 @@ ppExpr' d val expr = case expr of EReshape _ n esh e -> do esh' <- ppExpr' 11 val esh e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] EFold1InnerD1 _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf b) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1iD1" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] EFold1InnerD2 _ cm ef ebog ed -> do let STArr _ tB = typeOf ebog diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index d276e44..267dd87 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -34,10 +34,10 @@ splitLets' = \sub -> \case in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c - in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD2 x cm a b c -> let STArr _ tB = typeOf b STArr _ t2 = typeOf c @@ -56,12 +56,14 @@ splitLets' = \sub -> \case ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index a22b73f..1712ba5 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -38,6 +38,7 @@ unMonoid = \case ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) EUnit _ e -> EUnit ext (unMonoid e) @@ -45,6 +46,7 @@ unMonoid = \case EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 3a97fd1..f0820b8 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -129,7 +129,7 @@ wCopies bs w = let bs' = slistMap (\_ -> Const ()) bs in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index c6efe37..7370df1 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -11,6 +11,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE AllowAmbiguousTypes #-} @@ -23,6 +24,7 @@ module AST.Weaken.Auto ( ) where import Data.Functor.Const +import Data.Kind (Constraint) import GHC.OverloadedLabels import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) @@ -39,18 +41,21 @@ type family Lookup name list where -- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) +-- occurrences within this layout. 'names' is the list of names that this +-- layout /produces/. That is: for LPreW, it contains the target name. The +-- 'names' list of a source layout must be a subset of the names list of the +-- target layout (which cannot contain LPreW); this is checked with SubLayout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) -- | Pre-weaken with a weakening LPreW :: forall name1 name2 segments. SegmentName name1 -> SegmentName name2 -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments (Lookup name1 segments) - (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2) + -> Layout True segments '[name2] (Lookup name1 segments) + (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) infixr :++: -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where +instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where fromLabel = LSeg (symbolSing @name) newtype SegmentName name = SegmentName (SSymbol name) @@ -60,6 +65,18 @@ instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') wh fromLabel = SegmentName symbolSing +type family SubLayout names1 names2 where + SubLayout '[] _ = () :: Constraint + SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 +type family SubLayout' n ok names1 names2 where + SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") + SubLayout' _ True names1 names2 = SubLayout names1 names2 +type family Contains n names where + Contains _ '[] = False + Contains n (n : _) = True + Contains n (_ : names) = Contains n names + + data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) @@ -74,7 +91,7 @@ auto1 :: SList (Const ()) '[t] auto1 = Const () `SCons` SNil infixr &. -(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) (&.) = ssegmentsAppend where ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) @@ -118,12 +135,12 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ seg) +lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinApp name LinEnd lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinAppPreW name1 name2 w LinEnd @@ -151,8 +168,7 @@ pullDown segs name@SSymbol linlayout kNotFound k = k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) .> wCopies (segmentLookup segs n') w) -sortLinLayouts :: forall segments env1 env2. - SSegments segments +sortLinLayouts :: SSegments segments -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 sortLinLayouts _ LinEnd LinEnd = WId sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) @@ -169,8 +185,8 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" -autoWeak :: forall segments env1 env2. - SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak :: SubLayout names1 names2 + => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 autoWeak segs ly1 ly2 = preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak |
