diff options
| author | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-05 18:26:31 +0200 | 
|---|---|---|
| committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-05 18:26:31 +0200 | 
| commit | b6c1d3a9d0651aa25ea5f03d514a214a3347f7a4 (patch) | |
| tree | 49764a3f3b78bb2848cdc871a1217f7ae1a04120 /src | |
| parent | ebe8d8219e12fc9ac7ca58b367bc91e640ed0556 (diff) | |
Split product lets before chad
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST/Bindings.hs | 2 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 125 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 4 | ||||
| -rw-r--r-- | src/CHAD/Top.hs | 7 | 
4 files changed, 135 insertions, 3 deletions
| diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs index 2e63b42..3d99afe 100644 --- a/src/AST/Bindings.hs +++ b/src/AST/Bindings.hs @@ -45,7 +45,7 @@ weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'  weakenOver SNil w = w  weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) -sinkWithBindings :: Bindings f env binds -> env' :> Append binds env' +sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env'  sinkWithBindings BTop = WId  sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs new file mode 100644 index 0000000..1de417c --- /dev/null +++ b/src/AST/SplitLets.hs @@ -0,0 +1,125 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module AST.SplitLets (splitLets) where + +import Data.Type.Equality + +import AST +import AST.Bindings +import Lemmas + + +splitLets :: Ex env t -> Ex env t +splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) + +splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t +splitLets' = \sub -> \case +  EVar _ t i -> sub t i WId +  ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) +  ECase x e a b -> +    let STEither t1 t2 = typeOf e +    in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) +  EMaybe x a b e -> +    let STMaybe t1 = typeOf e +    in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) +  EFold1Inner x cm a b c -> +    let STArr _ t1 = typeOf c +    in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + +  EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) +  EFst x e -> EFst x (splitLets' sub e) +  ESnd x e -> ESnd x (splitLets' sub e) +  ENil x -> ENil x +  EInl x t e -> EInl x t (splitLets' sub e) +  EInr x t e -> EInr x t (splitLets' sub e) +  ENothing x t -> ENothing x t +  EJust x e -> EJust x (splitLets' sub e) +  EConstArr x n t a -> EConstArr x n t a +  EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) +  ESum1Inner x e -> ESum1Inner x (splitLets' sub e) +  EUnit x e -> EUnit x (splitLets' sub e) +  EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) +  EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) +  EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) +  EConst x t v -> EConst x t v +  EIdx0 x e -> EIdx0 x (splitLets' sub e) +  EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) +  EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es) +  EShape x e -> EShape x (splitLets' sub e) +  EOp x op e -> EOp x op (splitLets' sub e) +  ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) +  EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) +  EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) +  EZero x t -> EZero x t +  EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) +  EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) +  EError x t s -> EError x t s +  where +    sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) +          -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t +    sinkF _ t IZ w = EVar ext t (w @> IZ) +    sinkF f t (IS i) w = f t i (w .> WSink) + +    split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) +           -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t +    split1 sub (tbind :: STy bind) body = +      let (ptrs, bs) = split (EVar ext tbind IZ) tbind +      in letBinds bs $ +           splitLets' (\cases _ IZ w -> subPointers ptrs w +                              t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w))) +                      body + +    split2 :: forall bind1 bind2 env' env t. +              (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) +           -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t +    split2 sub tbind1 tbind2 body = +      let (ptrs1, bs1) = split (EVar ext tbind1 (IS IZ)) tbind1 +          (ptrs2, bs2) = split (EVar ext tbind2 IZ) tbind2 +      in letBinds bs1 $ +         letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ +           splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) +                              _ (IS IZ) w -> subPointers ptrs1 (w .> wSinks (bindingsBinds bs2)) +                              t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) +                      body + +type family Split t where +  Split TNil = '[] +  Split (TPair a b) = Append (Split b) (Split a) +  Split t = '[t] + +data Pointers env t where +  Point :: STy t -> Idx env t -> Pointers env t +  PNil :: Pointers env TNil +  PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b) +  PWeak :: env' :> env -> Pointers env' t -> Pointers env t + +subPointers :: Pointers env t -> env :> env' -> Ex env' t +subPointers (Point t i) w = EVar ext t (w @> i) +subPointers PNil _ = ENil ext +subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w) +subPointers (PWeak w' p) w = subPointers p (w .> w') + +split :: forall env t. Ex env t -> STy t +      -> (Pointers (Append (Split t) env) t, Bindings Ex env (Split t)) +split i = \case +  STNil -> (PNil, BTop) +  STPair (a :: STy a) (b :: STy b) +    | Refl <- lemAppendAssoc @(Split b) @(Split a) @env -> +        let (p1, bs1) = split (EFst ext i) a +            (p2, bs2) = split (ESnd ext (sinkWithBindings bs1 `weakenExpr` i)) b +        in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) +  t@STEither{} -> other t +  t@STMaybe{} -> other t +  t@STArr{} -> other t +  t@STScal{} -> other t +  t@STAccum{} -> other t +  where +    other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t]) +    other t = (Point t IZ, BPush BTop (t, i)) diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index bd2c244..d882e28 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -126,3 +126,7 @@ wCopies bs w =  wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env  wRaiseAbove SNil _ = WClosed  wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) + +wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2 +wPops SNil w = w +wPops (_ `SCons` bs) w = wPops bs (WPop w) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index ced7550..ea7449d 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -12,6 +12,7 @@ module CHAD.Top where  import Analysis.Identity  import AST +import AST.SplitLets  import AST.Weaken.Auto  import CHAD  import CHAD.Accum @@ -87,7 +88,7 @@ chad config env (term :: Ex env t)                                                  &. #tl (d1e env))                                                 (#d :++: #acenv :++: #tl)                                                 (#acenv :++: #d :++: #tl)) $ -                            freezeRet descr (drev descr VarMap.empty (identityAnalysis env term))) $ +                            freezeRet descr (drev descr VarMap.empty term')) $                EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))                          (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))                                                          (ESnd ext (EFst ext (EVar ext tvar IZ))))) @@ -95,7 +96,9 @@ chad config env (term :: Ex env t)    | False <- chcArgArrayAccum config    , Refl <- mergeEnvNoAccum env    , Refl <- mergeEnvOnlyMerge env -  = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (identityAnalysis env term)) +  = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term') +  where +    term' = identityAnalysis env (splitLets term)  chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env)))  chad' config env term | 
