diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 2 | ||||
| -rw-r--r-- | src/AST/Count.hs | 55 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 38 | ||||
| -rw-r--r-- | src/CHAD.hs | 60 | ||||
| -rw-r--r-- | src/Example.hs | 36 | ||||
| -rw-r--r-- | src/Simplify.hs | 58 | 
6 files changed, 163 insertions, 86 deletions
@@ -147,6 +147,7 @@ data SOp a t where    OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)    OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)    ONot :: SOp (TScal TBool) (TScal TBool) +  OIf :: SOp (TScal TBool) (TEither TNil TNil)  deriving instance Show (SOp a t)  opt2 :: SOp a t -> STy t @@ -158,6 +159,7 @@ opt2 = \case    OLe _ -> STScal STBool    OEq _ -> STScal STBool    ONot -> STScal STBool +  OIf -> STEither STNil STNil  typeOf :: Expr x env t -> STy t  typeOf = \case diff --git a/src/AST/Count.hs b/src/AST/Count.hs new file mode 100644 index 0000000..baf132e --- /dev/null +++ b/src/AST/Count.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +module AST.Count where + +import AST + + +data Count = Zero | One | Many +  deriving (Show, Eq, Ord) + +instance Semigroup Count where +  Zero <> n = n +  n <> Zero = n +  _ <> _ = Many +instance Monoid Count where +  mempty = Zero + +data Occ = Occ { _occLexical :: Count +               , _occRuntime :: Count } +  deriving (Eq) +instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d) +instance Monoid Occ    where mempty = Occ mempty mempty + +-- | One of the two branches is taken +(<||>) :: Occ -> Occ -> Occ +Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) + +-- | This code is executed many times +scaleMany :: Occ -> Occ +scaleMany (Occ l _) = Occ l Many + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx = \case +  EVar _ _ i | idx2int i == idx2int idx -> Occ One One +             | otherwise -> mempty +  ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body +  EPair _ a b -> occCount idx a <> occCount idx b +  EFst _ e -> occCount idx e +  ESnd _ e -> occCount idx e +  ENil _ -> mempty +  EInl _ _ e -> occCount idx e +  EInr _ _ e -> occCount idx e +  ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b) +  EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b) +  EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e) +  EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b +  EConst{} -> mempty +  EIdx1 _ a b -> occCount idx a <> occCount idx b +  EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es +  EOp _ _ e -> occCount idx e +  EMOne _ _ e -> occCount idx e +  EMScope e -> occCount idx e +  EMReturn _ e -> occCount idx e +  EMBind a b -> occCount idx a <> occCount (IS idx) b +  EError{} -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index e793ce1..289294d 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -14,6 +14,7 @@ import Data.Foldable (toList)  import Data.Functor.Const  import AST +import AST.Count  data Val f env where @@ -42,6 +43,12 @@ genId = M (\i -> (i, i + 1))  genName :: M String  genName = ('x' :) . show <$> genId +genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String +genNameIfUsedIn ty idx ex +  | occCount idx ex == mempty = case ty of STNil -> return "()" +                                           _ -> return "_" +  | otherwise                 = genName +  ppExpr :: SList STy env -> Expr x env t -> String  ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""    where @@ -59,7 +66,7 @@ ppExpr' d val = \case    etop@ELet{} -> do      let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)          collect val' (ELet _ rhs body) = do -          name <- genName +          name <- genNameIfUsedIn (typeOf rhs) IZ body            (binds, core) <- collect (VPush (Const name) val') body            rhs' <- ppExpr' 0 val' rhs            return ((name, rhs') : binds, core) @@ -101,9 +108,10 @@ ppExpr' d val = \case    ECase _ e a b -> do      e' <- ppExpr' 0 val e -    name1 <- genName +    let STEither t1 t2 = typeOf e +    name1 <- genNameIfUsedIn t1 IZ a      a' <- ppExpr' 0 (VPush (Const name1) val) a -    name2 <- genName +    name2 <- genNameIfUsedIn t2 IZ b      b' <- ppExpr' 0 (VPush (Const name2) val) b      return $ showParen (d > 0) $        showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a' @@ -111,14 +119,14 @@ ppExpr' d val = \case    EBuild1 _ a b -> do      a' <- ppExpr' 11 val a -    name <- genName +    name <- genNameIfUsedIn (STScal STI64) IZ b      b' <- ppExpr' 0 (VPush (Const name) val) b      return $ showParen (d > 10) $        showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"    EBuild _ es e -> do      es' <- mapM (ppExpr' 0 val) es -    names <- mapM (const genName) es +    names <- mapM (const genName) es  -- TODO generate underscores      e' <- ppExpr' 0 (vpushN names val) e      return $ showParen (d > 10) $        showString "build [" @@ -128,8 +136,8 @@ ppExpr' d val = \case        . showString ("] -> ") . e' . showString ")"    EFold1 _ a b -> do -    name1 <- genName -    name2 <- genName +    name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a +    name2 <- genNameIfUsedIn (typeOf a) IZ a      a' <- ppExpr' 0 (VPush (Const name2) (VPush (Const name1) val)) a      b' <- ppExpr' 11 val b      return $ showParen (d > 10) $ @@ -185,10 +193,11 @@ ppExpr' d val = \case      e' <- ppExpr' 11 val e      return $ showParen (d > 10) $ showString ("return ") . e' -  etop@(EMBind _ EMBind{}) -> do +  etop@(EMBind _ _) -> do      let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)          collect val' (EMBind lhs cont) = do -          name <- genName +          let STEVM _ t = typeOf lhs +          name <- genNameIfUsedIn t IZ cont            (binds, core) <- collect (VPush (Const name) val') cont            lhs' <- ppExpr' 0 val' lhs            return ((name, lhs') : binds, core) @@ -201,11 +210,11 @@ ppExpr' d val = \case                          (map (\(name, rhs) -> showString (name ++ " <- ") . rhs) binds))        . showString " ; " . core . showString " }" -  EMBind a b -> do -    a' <- ppExpr' 0 val a -    name <- genName -    b' <- ppExpr' 0 (VPush (Const name) val) b -    return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b' +  -- EMBind a b -> do +  --   a' <- ppExpr' 0 val a +  --   name <- genNameIfUsedIn IZ b +  --   b' <- ppExpr' 0 (VPush (Const name) val) b +  --   return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'    EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) @@ -220,3 +229,4 @@ operator OLt{} = (Infix, "<")  operator OLe{} = (Infix, "<=")  operator OEq{} = (Infix, "==")  operator ONot = (Prefix, "not") +operator OIf = (Prefix, "ifB") diff --git a/src/CHAD.hs b/src/CHAD.hs index d0358b8..9a1c7d2 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -204,39 +204,44 @@ d1op (OLt t) e = EOp ext (OLt t) e  d1op (OLe t) e = EOp ext (OLe t) e  d1op (OEq t) e = EOp ext (OEq t) e  d1op ONot e = EOp ext ONot e +d1op OIf e = EOp ext OIf e + +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) +              | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a))  -- both primal and dual must be duplicable expressions -d2op :: SOp a t -> Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a) -d2op op e d = case op of -  OAdd _ -> EInr ext STNil (EPair ext d d) -  OMul t -> d2opBinArrangeInt t $ +d2op :: SOp a t -> D2Op a t +d2op op = case op of +  OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d) +  OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->      EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))                                (EOp ext (OMul t) (EPair ext (EFst ext e) d))) -  ONeg t -> d2opUnArrangeInt t $ EOp ext (ONeg t) d -  OLt t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  OLe t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  OEq t -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) -  ONot -> ENil ext +  ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d +  OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) +  OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) +  OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) +  ONot -> Linear $ \_ -> ENil ext +  OIf -> Linear $ \_ -> ENil ext    where      d2opUnArrangeInt :: SScalTy a -                     -> (D2s a ~ TScal a => Ex env (D2 (TScal a))) -                     -> Ex env (D2 (TScal a)) +                     -> (D2s a ~ TScal a => D2Op (TScal a) t) +                     -> D2Op (TScal a) t      d2opUnArrangeInt ty float = case ty of -      STI32 -> ENil ext -      STI64 -> ENil ext +      STI32 -> Linear $ \_ -> ENil ext +      STI64 -> Linear $ \_ -> ENil ext        STF32 -> float        STF64 -> float -      STBool -> ENil ext +      STBool -> Linear $ \_ -> ENil ext      d2opBinArrangeInt :: SScalTy a -                      -> (D2s a ~ TScal a => Ex env (D2 (TPair (TScal a) (TScal a)))) -                      -> Ex env (D2 (TPair (TScal a) (TScal a))) +                      -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) +                      -> D2Op (TPair (TScal a) (TScal a)) t      d2opBinArrangeInt ty float = case ty of -      STI32 -> EInl ext (STPair STNil STNil) (ENil ext) -      STI64 -> EInl ext (STPair STNil STNil) (ENil ext) +      STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) +      STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)        STF32 -> float        STF64 -> float -      STBool -> EInl ext (STPair STNil STNil) (ENil ext) +      STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)  freezeRet :: Ret env t            -> (forall env'. Ex env' (D2 t))  -- the incoming cotangent value @@ -359,10 +364,17 @@ drev senv = \case    EOp _ op e      | Ret e0 e1 e2 <- drev senv e -> -    Ret (e0 `BPush` (d1 (typeOf e), e1)) -        (d1op op $ EVar ext (d1 (typeOf e)) IZ) -        (ELet ext (d2op op (EVar ext (d1 (typeOf e)) (IS IZ)) -                           (EVar ext (d2 (opt2 op)) IZ)) -           (weakenExpr (WCopy (wSinks @[_,_])) e2)) +    case d2op op of +      Linear d2opfun -> +        Ret e0 +            (d1op op e1) +            (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) +               (weakenExpr (WCopy WSink) e2)) +      Nonlinear d2opfun -> +        Ret (e0 `BPush` (d1 (typeOf e), e1)) +            (d1op op $ EVar ext (d1 (typeOf e)) IZ) +            (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) +                               (EVar ext (d2 (opt2 op)) IZ)) +               (weakenExpr (WCopy (wSinks @[_,_])) e2))    e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e) diff --git a/src/Example.hs b/src/Example.hs index c8f12ba..f2e5966 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -58,3 +58,39 @@ ex2 =        (bin (OAdd STF32)          (EVar ext (STScal STF32) IZ)          (EVar ext (STScal STF32) (IS (IS IZ)))) + +-- x y |- if x < y then 2 * x else 3 + x +ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex3 = +  ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ)) +                                          (EVar ext (STScal STF32) IZ))) +    (bin (OMul STF32) (EConst ext STF32 2.0) +                      (EVar ext (STScal STF32) (IS (IS IZ)))) +    (bin (OAdd STF32) (EConst ext STF32 3.0) +                      (EVar ext (STScal STF32) (IS (IS IZ)))) + +-- x y |- if x < y then 2 * x + y * y else 3 + x +ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex4 = +  ECase ext (EOp ext OIf (bin (OLt STF32) (EVar ext (STScal STF32) (IS IZ)) +                                          (EVar ext (STScal STF32) IZ))) +    (bin (OAdd STF32) +       (bin (OMul STF32) (EConst ext STF32 2.0) +                         (EVar ext (STScal STF32) (IS (IS IZ)))) +       (bin (OMul STF32) (EVar ext (STScal STF32) (IS IZ)) +                         (EVar ext (STScal STF32) (IS IZ)))) +    (bin (OAdd STF32) (EConst ext STF32 3.0) +                      (EVar ext (STScal STF32) (IS (IS IZ)))) + +senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)] +senv5 = STScal STF32 `SCons` STEither (STScal STF32) (STScal STF32) `SCons` SNil + +-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} +ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) +ex5 = +  ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ)) +    (bin (OMul STF32) (EVar ext (STScal STF32) IZ) +                      (EVar ext (STScal STF32) (IS IZ))) +    (bin (OMul STF32) (EVar ext (STScal STF32) IZ) +                      (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) +                                        (EConst ext STF32 1.0))) diff --git a/src/Simplify.hs b/src/Simplify.hs index cb649d5..acc2392 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -6,8 +6,13 @@  module Simplify where  import AST +import AST.Count +simplifyN :: Int -> Ex env t -> Ex env t +simplifyN 0 = id +simplifyN n = simplifyN (n - 1) . simplify +  simplify :: Ex env t -> Ex env t  simplify = \case    -- inlining @@ -28,9 +33,11 @@ simplify = \case                               IS i -> EVar ext t (IS (IS i)))                body +  -- beta rules for products    EFst _ (EPair _ e _) -> simplify e    ESnd _ (EPair _ _ e) -> simplify e +  -- beta rules for coproducts    ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)    ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs) @@ -38,6 +45,9 @@ simplify = \case    -- TODO: constant folding for operations +  -- eta rule for return+bind +  EMBind (EMReturn _ a) b -> simplify (ELet ext a b) +    EVar _ t i -> EVar ext t i    ELet _ a b -> ELet ext (simplify a) (simplify b)    EPair _ a b -> EPair ext (simplify a) (simplify b) @@ -67,54 +77,6 @@ cheapExpr = \case    EConst{} -> True    _ -> False -data Count = Zero | One | Many -  deriving (Show, Eq, Ord) - -instance Semigroup Count where -  Zero <> n = n -  n <> Zero = n -  _ <> _ = Many -instance Monoid Count where -  mempty = Zero - -data Occ = Occ { _occLexical :: Count -               , _occRuntime :: Count } -instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d) -instance Monoid Occ    where mempty = Occ mempty mempty - --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) - --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l _) = Occ l Many - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = \case -  EVar _ _ i | idx2int i == idx2int idx -> Occ One One -             | otherwise -> mempty -  ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body -  EPair _ a b -> occCount idx a <> occCount idx b -  EFst _ e -> occCount idx e -  ESnd _ e -> occCount idx e -  ENil _ -> mempty -  EInl _ _ e -> occCount idx e -  EInr _ _ e -> occCount idx e -  ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b) -  EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b) -  EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e) -  EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b -  EConst{} -> mempty -  EIdx1 _ a b -> occCount idx a <> occCount idx b -  EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es -  EOp _ _ e -> occCount idx e -  EMOne _ _ e -> occCount idx e -  EMScope e -> occCount idx e -  EMReturn _ e -> occCount idx e -  EMBind a b -> occCount idx a <> occCount (IS idx) b -  EError{} -> mempty -  subst1 :: Expr x env a -> Expr x (a : env) t -> Expr x env t  subst1 repl = subst $ \x t -> \case IZ -> repl                                      IS i -> EVar x t i  | 
