diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST/Weaken.hs | 1 | ||||
| -rw-r--r-- | src/CHAD.hs | 100 | ||||
| -rw-r--r-- | src/Example.hs | 3 | ||||
| -rw-r--r-- | src/Lemmas.hs | 15 | 
4 files changed, 72 insertions, 47 deletions
| diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 56b7a74..dd121fa 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -18,6 +18,7 @@ data env :> env' where    WCopy :: env :> env' -> (t : env) :> (t : env')    WPop :: (t : env) :> env' -> env :> env'    WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 +  WClosed :: '[] :> env  deriving instance Show (env :> env')  (.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 diff --git a/src/CHAD.hs b/src/CHAD.hs index b1251aa..a5f9bb3 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -20,65 +20,74 @@ import Data.Some  import GHC.TypeLits (Symbol)  import AST +import Lemmas -data Bindings f env env' where -  BTop :: Bindings f env env -  BPush :: Bindings f env env' -> (STy t, f env' t) -> Bindings f env (t : env') +data Bindings f env binds where +  BTop :: Bindings f env '[] +  BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds)  deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')  infixl `BPush`  weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) -               -> env1 :> env2 -> Bindings f env1 env' -               -> (forall env2'. Bindings f env2 env2' -> env' :> env2' -> r) -> r -weakenBindings _ w BTop k = k BTop w -weakenBindings wf w (BPush b (t, x)) k = -  weakenBindings wf w b $ \b' w' -> k (BPush b' (t, wf w' x)) (WCopy w') +               -> env1 :> env2 +               -> Bindings f env1 binds +               -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) +weakenBindings _ w BTop = (BTop, w) +weakenBindings wf w (BPush b (t, x)) = +  let (b', w') = weakenBindings wf w b +  in (BPush b' (t, wf w' x), WCopy w') -sinkWithBindings :: Bindings f env env' -> env :> env' +sinkWithBindings :: Bindings f env binds -> env :> Append binds env  sinkWithBindings BTop = WId  sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b -bconcat :: Bindings f env1 env2 -> Bindings f env2 env3 -> Bindings f env1 env3 +bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1)  bconcat b1 BTop = b1 -bconcat b1 (BPush b2 (t, x)) = BPush (bconcat b1 b2) (t, x) +bconcat b1 (BPush (b2 :: Bindings f (Append binds1 env) binds2C) (t, x)) +  | Refl <- lemAppendAssoc @binds2C @binds1 @env +  = BPush (bconcat b1 b2) (t, x) -bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) -         -> Bindings f env env1 -> Bindings f env env2 -         -> (forall env12. Bindings f env env12 -> r) -> r -bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2') +-- bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) +--          -> Bindings f env env1 -> Bindings f env env2 +--          -> (forall env12. Bindings f env env12 -> r) -> r +-- bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2') -bsnoc :: STy t -> f env t -> Bindings f (t : env) env' -> Bindings f env env' -bsnoc t x b = bconcat (BTop `BPush` (t, x)) b +-- bsnoc :: STy t -> f env t -> Bindings f (t : env) env' -> Bindings f env env' +-- bsnoc t x b = bconcat (BTop `BPush` (t, x)) b -data TupBindsReconstruct f env1 env2 env3 = -  forall env4. -    TupBindsReconstruct (Bindings f env3 env4) -                        (env2 :> env4) +type family Tape binds where +  Tape '[] = TNil +  Tape (t : ts) = TPair t (Tape ts) -data TupBinds f env1 env2 = -  forall tape. -    TupBinds (STy tape) -             (forall env2'. env2 :> env2' -> Ex env2' tape) -             (forall env3. env1 :> env3 -> Idx env3 tape -> TupBindsReconstruct f env1 env2 env3) +-- TODO: The problem is that in the 3rd field of TupBinds, we should reconstruct a stack of let bindings from the tape, but we can't directly without having quadratic code size (due to nested projections). Instead we should produce _two_ Bindings there: one to an existential intermediate 'tempbinds', and another that picks up from there and creates 'binds'. +data TupBinds f env binds = +  TupBinds (SList STy binds) +           (forall env2. Append binds env :> env2 -> Ex env2 (Tape binds)) +           (forall env2. Idx env2 (Tape binds) -> Bindings f env2 binds) -tupBinds :: Bindings Ex env1 env2 -> TupBinds Ex env1 env2 -tupBinds BTop = TupBinds STNil (\_ -> ENil ext) (\w _ -> TupBindsReconstruct BTop w) +tupBinds :: Bindings Ex env binds -> TupBinds Ex env binds +tupBinds BTop = TupBinds SNil (\_ -> ENil ext) (\_ -> BTop)  tupBinds (BPush binds (t, _)) -  | TupBinds tape collect recon <- tupBinds binds -  = TupBinds (STPair tape t) -             (\w -> EPair ext (collect (w .> WSink)) -                              (EVar ext t (w @> IZ))) -             (\w tapeidx -> -               case recon (WSink .> w) IZ of -                 TupBindsReconstruct rebinds wunder -> -                   let rebinds1 = bsnoc tape (EFst ext (EVar ext (STPair tape t) tapeidx)) rebinds -                   in TupBindsReconstruct -                        (rebinds1 `BPush` -                          (t, ESnd ext (EVar ext (STPair tape t) -                                             (sinkWithBindings rebinds1 @> tapeidx)))) -                        (WCopy wunder)) +  | TupBinds tapelist collect recon <- tupBinds binds +  = TupBinds (SCons t tapelist) +             (\w -> EPair ext (EVar ext t (w @> IZ)) +                              (collect (w .> WSink))) +             (\tapeidx -> +                let b = recon tapeidx +                in BPush _ +                         (t, _ (sinkWithBindings b) tapeidx)) +             -- (\w tapeidx -> +             --   case recon (WSink .> w) IZ of +             --     TupBindsReconstruct rebinds wunder -> +             --       let rebinds1 = bsnoc tape (EFst ext (EVar ext (STPair tape t) tapeidx)) rebinds +             --       in TupBindsReconstruct +             --            (rebinds1 `BPush` +             --              (t, ESnd ext (EVar ext (STPair tape t) +             --                                 (sinkWithBindings rebinds1 @> tapeidx)))) +             --            (WCopy wunder)) +{-  letBinds :: Bindings Ex env env' -> Ex env' t -> Ex env t  letBinds BTop = id  letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs @@ -239,11 +248,11 @@ data Subenv env env' where  deriving instance Show (Subenv env env')  data Ret env0 sto t = -  forall env env0F. -    Ret (Bindings Ex (D1E env0) env) -        (Ex env (D1 t)) +  forall justenv env0F. +    Ret (Bindings Ex (D1E env0) justenv) +        (Ex (Append justenv (D1E env0)) (D1 t))          (Subenv (Select env0 sto "merge") env0F) -        (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0F)))) +        (Ex (D2 t : justenv) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0F))))  deriving instance Show (Ret env0 sto t)  data RetPair env0 sto env t = @@ -633,3 +642,4 @@ drev des policy = \case    where      d2acc = d2e (select SAccum des) +-} diff --git a/src/Example.hs b/src/Example.hs index 389248a..643b82f 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -103,8 +103,7 @@ descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (S  ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)  ex5 =    ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ)) -    (bin (OMul STF32) (EVar ext (STScal STF32) IZ) -                      (EVar ext (STScal STF32) (IS IZ))) +    (EVar ext (STScal STF32) IZ)      (bin (OMul STF32) (EVar ext (STScal STF32) IZ)                        (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))                                          (EConst ext STF32 1.0))) diff --git a/src/Lemmas.hs b/src/Lemmas.hs new file mode 100644 index 0000000..7dbf680 --- /dev/null +++ b/src/Lemmas.hs @@ -0,0 +1,15 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE PolyKinds #-} + +{-# LANGUAGE AllowAmbiguousTypes #-} +module Lemmas (module Lemmas, (:~:)(Refl)) where + +import Data.Type.Equality +import Unsafe.Coerce (unsafeCoerce) + +import AST.Weaken + + +lemAppendAssoc :: Append a (Append b c) :~: Append (Append a b) c +lemAppendAssoc = unsafeCoerce Refl | 
