diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-12 23:08:40 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-12 23:08:40 +0200 | 
| commit | 3d8a6cca424fc5279c43a266900160feb28aa715 (patch) | |
| tree | 5fd8b992eb5f2beec156b10a815aaec1cf492d76 | |
| parent | 36732f84cfade5371248806328791d5066673fb7 (diff) | |
Towards neural
| -rw-r--r-- | src/AST/Weaken.hs | 35 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 78 | ||||
| -rw-r--r-- | src/CHAD.hs | 143 | ||||
| -rw-r--r-- | src/Example.hs | 12 | 
4 files changed, 225 insertions, 43 deletions
| diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 78276ca..0a1e4ce 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -48,6 +48,9 @@ data env :> env' where                                 -> Append pre (t : env) :> t : Append pre env'    WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs          -> Append as (Append bs env) :> Append bs (Append as env) +  WStack :: forall as bs env1 env2. SList (Const ()) as -> SList (Const ()) bs +         -> as :> bs -> env1 :> env2 +         -> Append as env1 :> Append bs env2  deriving instance Show (env :> env')  infix 4 :> @@ -67,18 +70,25 @@ WPick (_ `SCons` _) _ @> IZ = IS IZ  WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i  WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i =    case splitIdx @(Append bs env) as i of -    Left i' -> skipOver bs (stack @env i' as) +    Left i' -> indexSinks bs (indexRaiseAbove @env as i')      Right i' -> case splitIdx @env bs i' of -                  Left j -> stack @(Append as env) j bs -                  Right j -> skipOver bs (skipOver as j) -  where -    skipOver :: SList (Const ()) as' -> Idx bs' t -> Idx (Append as' bs') t -    skipOver SNil j = j -    skipOver (_ `SCons` bs') j = IS (skipOver bs' j) +                  Left j -> indexRaiseAbove @(Append as env) bs j +                  Right j -> indexSinks bs (indexSinks as j) +WStack @as @bs @env1 @env2 as bs wlo whi @> i = +  case splitIdx @env1 as i of +    Left i' -> indexRaiseAbove @env2 bs (wlo @> i') +    Right i' -> indexSinks bs (whi @> i') + +indexSinks :: SList f as -> Idx bs t -> Idx (Append as bs) t +indexSinks SNil j = j +indexSinks (_ `SCons` bs') j = IS (indexSinks bs' j) -    stack :: forall env' as' t. Idx as' t -> SList (Const ()) as' -> Idx (Append as' env') t -    stack IZ (_ `SCons` _) = IZ -    stack (IS j) (_ `SCons` as') = IS (stack @env' j as') +indexRaiseAbove :: forall env as t f. SList f as -> Idx as t -> Idx (Append as env) t +indexRaiseAbove = flip go +  where +    go :: forall as'. Idx as' t -> SList f as' -> Idx (Append as' env) t +    go IZ (_ `SCons` _) = IZ +    go (IS i) (_ `SCons` as) = IS (go i as)  infixr 3 .>  (.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 @@ -100,8 +110,9 @@ wSinksAnd SNil w = w  wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w  wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 -wCopies SNil w = w -wCopies (SCons _ spine) w = WCopy (wCopies spine w) +wCopies bs w = +  let bs' = slistMap (\_ -> Const ()) bs +  in WStack bs' bs' WId w  wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env  wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 444c540..8555516 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -35,15 +35,29 @@ import Lemmas  type family Lookup name list where    Lookup name ('(name, x) : _) = x    Lookup name (_ : list) = Lookup name list +  Lookup name '[] = TypeError (Text "The name '" :<>: Text name :<>: Text "' does not appear in the list.") -data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where -  LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments) -  (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2) +-- | The @withPre@ type parameter indicates whether there can be 'LPreW' +-- occurrences within this layout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where +  LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) +  -- | Pre-weaken with a weakening +  LPreW :: forall name1 name2 segments. +           SegmentName name1 -> SegmentName name2 +        -> Lookup name1 segments :> Lookup name2 segments +        -> Layout True segments (Lookup name1 segments) +  (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2)  infixr :++: -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout segments seg) where -  fromLabel = LSeg @name @segments +instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where +  fromLabel = LSeg (symbolSing @name) + +newtype SegmentName name = SegmentName (SSymbol name) +  deriving (Show) + +instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where +  fromLabel = SegmentName symbolSing  data SSegments (segments :: [(Symbol, [t])]) where @@ -86,29 +100,51 @@ segmentLookup = \segs name -> case go segs name of            case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of              Refl -> go sseg name -data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where -  LinEnd :: LinLayout segments '[] -  LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env) +data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where +  LinEnd :: LinLayout withPre segments '[] +  LinApp :: SSymbol name -> LinLayout withPre segments env +         -> LinLayout withPre segments (Append (Lookup name segments) env) +  LinAppPreW :: SSymbol name1 -> SSymbol name2 +             -> Lookup name1 segments :> Lookup name2 segments +             -> LinLayout True segments env +             -> LinLayout True segments (Append (Lookup name1 segments) env) -linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2) +linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2)  linLayoutAppend LinEnd lin = lin -linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2) +linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2)    | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2    = LinApp name (linLayoutAppend lin1 lin2) +linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) +  | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 +  = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env +linLayoutEnv :: SSegments segments -> LinLayout withPre segments env -> SList (Const ()) env  linLayoutEnv _ LinEnd = SNil  linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) +linLayoutEnv segs (LinAppPreW name1 _ _ lin) = sappend (segmentLookup segs name1) (linLayoutEnv segs lin) -lineariseLayout :: Layout segments env -> LinLayout segments env -lineariseLayout (LSeg @name :: Layout _ seg) +lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ seg)    | Refl <- lemAppendNil @seg -  = LinApp (symbolSing @name) LinEnd +  = LinApp name LinEnd  lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) +  | Refl <- lemAppendNil @seg +  = LinAppPreW name1 name2 w LinEnd + +preWeaken :: SSegments segments -> LinLayout True segments env +          -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r +preWeaken _ LinEnd k = k WId LinEnd +preWeaken segs (LinApp name lin) k = +  preWeaken segs lin $ \w lin' -> +    k (wCopies (segmentLookup segs name) w) (LinApp name lin') +preWeaken segs (LinAppPreW name1 name2 weak lin) k = +  preWeaken segs lin $ \w lin' -> +    k (WStack (segmentLookup segs name1) (segmentLookup segs name2) weak w) (LinApp name2 lin') -pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env +pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env           -> r  -- Name was not found in source -         -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r) +         -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r)           -> r  pullDown segs name@SSymbol linlayout kNotFound k =    case linlayout of @@ -116,13 +152,13 @@ pullDown segs name@SSymbol linlayout kNotFound k =      LinApp n'@SSymbol lin        | Just Refl <- sameSymbol name n' -> k lin WId        | otherwise -> -          pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w -> +          pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ _ env') w ->              k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name)                                    .> wCopies (segmentLookup segs n') w)  sortLinLayouts :: forall segments env1 env2.                    SSegments segments -               -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2 +               -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2  sortLinLayouts _ LinEnd LinEnd = WId  sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)    | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) @@ -139,5 +175,7 @@ sortLinLayouts segs LinEnd lin2@LinApp{} = WClosed (linLayoutEnv segs lin2)  sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target"  autoWeak :: forall segments env1 env2. -            SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2 -autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2) +            SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak segs ly1 ly2 = +  preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> +    sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/CHAD.hs b/src/CHAD.hs index 943f0a2..7747d46 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -262,7 +262,7 @@ vectoriseExpr :: forall prefix binds env t f.                -> Ex (Append prefix (Append binds env)) t                -> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t  vectoriseExpr prefix binds env = -  let wTarget :: Layout ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e +  let wTarget :: Layout True ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e                -> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env)        wTarget layout =           autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env) @@ -422,6 +422,10 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))  zeroTup SNil = ENil ext  zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl +  accumPromote :: forall dt env sto proxy r.                  proxy dt               -> Descr env sto @@ -974,6 +978,7 @@ drev des = \case          (subenvNone (select SMerge des))          (ENil ext) +  -- TODO: either remove EBuilds1 entirely or rewrite it to work with an array of tapes instead of a vectorised tape    EBuild1 _ ne (orige :: Ex _ eltty)      | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne  -- allowed to ignore ne2 here because ne has a discrete result      , let eltty = typeOf orige -> @@ -1050,6 +1055,90 @@ drev des = \case           ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ)))      }} +  EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) +    | Ret (she0 :: Bindings _ _ she_binds) she1 _ _ <- drev des she  -- allowed to ignore she2 here because she has a discrete result +    , let eltty = typeOf orige +    , shty :: STy shty <- tTup (sreplicate ndim tIx) +    , Refl <- indexTupD1Id ndim -> +    deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> +    let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in +    subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> +    accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> +    case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 -> +    case assertSubenvEmpty sub of { Refl -> +    let tapety = tapeTy (bindingsBinds e0) in +    let collectexpr = bindingsCollect e0 in +    -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in +    Ret (she0 `BPush` (shty, she1) +              `BPush` (STArr ndim tapety +                      ,EBuild ext ndim +                         (EVar ext shty IZ) +                         (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) +                                                                              &. #sh (shty `SCons` SNil) +                                                                              &. #she0 (bindingsBinds she0) +                                                                              &. #d1env (sD1eEnv des) +                                                                              &. #d1env' (sD1eEnv usedDes)) +                                                                             (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                                                             (#ix :++: #sh :++: #she0 :++: #d1env)) +                                                                   e0)) $ +                            collectexpr (autoWeak (#ix (shty `SCons` SNil) +                                                   &. #sh (shty `SCons` SNil) +                                                   &. #she0 (bindingsBinds she0) +                                                   &. #e0 (bindingsBinds e0) +                                                   &. #d1env (sD1eEnv des) +                                                   &. #d1env' (sD1eEnv usedDes)) +                                                  (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                                  (#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env))))) +        (EBuild ext ndim +           (EVar ext shty (IS IZ)) +           (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ)) +                                    (EVar ext shty IZ)) $ +            let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ +            in letBinds rebinds $ +                 weakenExpr (autoWeak (#ix (shty `SCons` SNil) +                                       &. #sh (shty `SCons` SNil) +                                       &. #she0 (bindingsBinds she0) +                                       &. #e0 (bindingsBinds e0) +                                       &. #tape (tapety `SCons` SNil) +                                       &. #tapearr (STArr ndim tapety `SCons` SNil) +                                       &. #prerebinds prerebinds +                                       &. #d1env (sD1eEnv des) +                                       &. #d1env' (sD1eEnv usedDes)) +                                      (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                      ((#e0 :++: #prerebinds) :++: #tape :++: #ix :++: #tapearr :++: #sh :++: #she0 :++: #d1env)) +                            e1)) +        (subenvCompose subMergeUsed proSub) +        (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_binds) : Tup (Replicate ndim TIx) : Append she_binds (D2AcE (Select env sto "accum"))) (d2ace envPro) in +         ESnd ext $ +          uninvertTup (d2e envPro) (STArr ndim STNil) $ +            makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ +              EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ +                -- the cotangent for this element +                ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) +                                        (EVar ext shty IZ)) $ +                -- the tape for this element +                ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) +                                        (EVar ext shty (IS IZ))) $ +                let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ +                in letBinds rebinds $ +                     weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) +                                           &. #pro (d2ace envPro) +                                           &. #ebinds (bindingsBinds e0) +                                           &. #prerebinds prerebinds +                                           &. #tape (tapety `SCons` SNil) +                                           &. #ix (shty `SCons` SNil) +                                           &. #darr (STArr ndim (d2 eltty) `SCons` SNil) +                                           &. #tapearr (STArr ndim tapety `SCons` SNil) +                                           &. #sh (shty `SCons` SNil) +                                           &. #shebinds (bindingsBinds she0) +                                           &. #d2acUsed (d2ace (select SAccum usedDes)) +                                           &. #d2acEnv (d2ace (select SAccum des))) +                                          (#pro :++: #d :++: #ebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) +                                          ((#ebinds :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #shebinds :++: #d2acEnv) +                                 .> wPro (bindingsBinds e0)) +                                e2) +    }} +    EUnit _ e      | Ret e0 e1 sub e2 <- drev des e ->      Ret e0 @@ -1072,23 +1161,55 @@ drev des = \case      | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)          <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil      , STArr (SS n) eltty <- typeOf e -> -    Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1)) -        (weakenExpr WSink (EIdx1 ext e1 ei1)) +    Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)) +        (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) IZ) +                   (weakenExpr WSink ei1))          sub -        (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) -                               (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) +        (ELet ext (ebuildUp1 n (EFst ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)))) +                               (ESnd ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ))))                                 (EVar ext (STArr n (d2 eltty)) (IS IZ))) $           weakenExpr (WCopy (WSink .> WSink)) e2) +  EIdx _ n e ei +    -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. +    | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) +        <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil +    , STArr _ eltty <- typeOf e +    , Refl <- indexTupD1Id n -> +    Ret (binds `BPush` (STArr n (d1 eltty), e1)) +        (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ) +                    (weakenExpr WSink ei1)) +        sub +        (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ))) +                                (EVar ext (d2 eltty) (IS IZ))) $ +         weakenExpr (WCopy (WSink .> WSink)) e2) + +  EShape _ e +    -- Allowed to ignore e2 here because the output of EShape is discrete, +    -- hence we'd be passing a zero cotangent to e2 anyway. +    | Ret e0 e1 _ _ <- drev des e +    , STArr n _ <- typeOf e +    , Refl <- indexTupD1Id n -> +    Ret e0 +        (EShape ext e1) +        (subenvNone (select SMerge des)) +        (ENil ext) + +  ESum1Inner _ e +    | Ret e0 e1 sub e2 <- drev des e +    , STArr (SS n) t <- typeOf e -> +    Ret (e0 `BPush` (STArr (SS n) t, e1)) +        (ESum1Inner ext (EVar ext (STArr (SS n) t) IZ)) +        sub +        (ELet ext (EReplicate1Inner ext +                     (ESnd ext (EShape ext (EVar ext (STArr (SS n) t) (IS IZ)))) +                     (EVar ext (STArr n (d2 t)) IZ)) $ +         weakenExpr (WCopy (WSink .> WSink)) e2) +    -- These should be the next to be implemented, I think -  ESum1Inner{} -> err_unsupported "ESum" -  EReplicate1Inner{} -> err_unsupported "EReplicate" -  EShape{} -> err_unsupported "EShape" +  EReplicate1Inner{} -> err_unsupported "EReplicate1Inner"    EFold1Inner{} -> err_unsupported "EFold1Inner" -  EIdx{} -> err_unsupported "EIdx" -  EBuild{} -> err_unsupported "EBuild" -    EWith{} -> err_accum    EAccum{} -> err_accum diff --git a/src/Example.hs b/src/Example.hs index d1d04e3..fb4e851 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,6 +1,8 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  module Example where @@ -14,6 +16,16 @@ import Simplify  -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) +type family MergeEnv env where +  MergeEnv '[] = '[] +  MergeEnv (t : ts) = "merge" : MergeEnv ts + +mergeDescr :: KnownEnv env => Descr env (MergeEnv env) +mergeDescr = go knownEnv +  where go :: SList STy env -> Descr env (MergeEnv env) +        go SNil = DTop +        go (t `SCons` env) = go env `DPush` (t, SMerge) +  bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c  bin op a b = EOp ext op (EPair ext a b) | 
