diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 12 | ||||
| -rw-r--r-- | src/AST/Count.hs | 106 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 4 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 2 | ||||
| -rw-r--r-- | src/CHAD.hs | 59 | ||||
| -rw-r--r-- | src/Simplify.hs | 2 | 
6 files changed, 157 insertions, 28 deletions
| @@ -15,6 +15,7 @@  {-# LANGUAGE EmptyCase #-}  module AST (module AST, module AST.Weaken) where +import Data.Bifunctor (first)  import Data.Functor.Const  import Data.Kind (Type)  import Data.Int @@ -55,6 +56,9 @@ deriving instance Show (SScalTy t)  type TIx = TScal TI64 +tIx :: STy TIx +tIx = STScal STI64 +  type family ScalRep t where    ScalRep TI32 = Int32    ScalRep TI64 = Int64 @@ -92,6 +96,7 @@ data Expr x env t where    -- expression operations    EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) +  EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t    EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)    EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t    EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t @@ -150,6 +155,7 @@ typeOf = \case    EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t    EConst _ t _ -> STScal t +  EIdx0 _ e | STArr _ t <- typeOf e -> t    EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t    EIdx _ e _ | STArr _ t <- typeOf e -> t    EOp _ op _ -> opt2 op @@ -210,6 +216,7 @@ subst' f w = \case    EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)    EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)    EConst x t v -> EConst x t v +  EIdx0 x e -> EIdx0 x (subst' f w e)    EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)    EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)    EOp x op e -> EOp x op (subst' f w e) @@ -254,6 +261,11 @@ idx2int :: Idx env t -> Int  idx2int IZ = 0  idx2int (IS n) = 1 + idx2int n +splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) +splitIdx SNil i = Right i +splitIdx (SCons _ _) IZ = Left IZ +splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) +  class KnownScalTy t where knownScalTy :: SScalTy t  instance KnownScalTy TI32 where knownScalTy = STI32  instance KnownScalTy TI64 where knownScalTy = STI64 diff --git a/src/AST/Count.hs b/src/AST/Count.hs index f66b809..7e70a7d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -1,10 +1,17 @@ +{-# LANGUAGE DataKinds #-}  {-# LANGUAGE DeriveGeneric #-}  {-# LANGUAGE DerivingStrategies #-}  {-# LANGUAGE DerivingVia #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-}  module AST.Count where +import Data.Functor.Const  import GHC.Generics (Generic, Generically(..))  import AST @@ -35,24 +42,81 @@ scaleMany :: Occ -> Occ  scaleMany (Occ l _) = Occ l Many  occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = \case -  EVar _ _ i | idx2int i == idx2int idx -> Occ One One -             | otherwise -> mempty -  ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body -  EPair _ a b -> occCount idx a <> occCount idx b -  EFst _ e -> occCount idx e -  ESnd _ e -> occCount idx e -  ENil _ -> mempty -  EInl _ _ e -> occCount idx e -  EInr _ _ e -> occCount idx e -  ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b) -  EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b) -  EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e) -  EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b -  EConst{} -> mempty -  EIdx1 _ a b -> occCount idx a <> occCount idx b -  EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es -  EOp _ _ e -> occCount idx e -  EWith a b -> occCount idx a <> occCount (IS idx) b -  EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e -  EError{} -> mempty +occCount idx = +  getConst . occCountGeneral +    (\i o -> if idx2int i == idx2int idx then Const o else mempty) +    (\(Const o) -> Const o) +    (\_ (Const o) -> Const o) +    (\(Const o1) (Const o2) -> Const (o1 <||> o2)) +    (\(Const o) -> Const (scaleMany o)) + + +data OccEnv env where +  OccEnd :: OccEnv env  -- not necessarily top! +  OccPush :: OccEnv env -> Occ -> OccEnv (t : env) + +instance Semigroup (OccEnv env) where +  OccEnd <> e = e +  e <> OccEnd = e +  OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') + +instance Monoid (OccEnv env) where +  mempty = OccEnd + +onehotOccEnv :: Idx env t -> Occ -> OccEnv env +onehotOccEnv IZ v = OccPush OccEnd v +onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty + +(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env +OccEnd <||>! e = e +e <||>! OccEnd = e +OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') + +scaleManyOccEnv :: OccEnv env -> OccEnv env +scaleManyOccEnv OccEnd = OccEnd +scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) + +occCountAll :: Expr x env t -> OccEnv env +occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv +  where +    unpush :: OccEnv (t : env) -> OccEnv env +    unpush (OccPush o _) = o +    unpush OccEnd = OccEnd + +    unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env +    unpushN _ OccEnd = OccEnd +    unpushN SZ e = e +    unpushN (SS n) (OccPush e _) = unpushN n e + +occCountGeneral :: forall r env t x. +                   (forall env'. Monoid (r env')) +                => (forall env' a. Idx env' a -> Occ -> r env')  -- ^ one-hot +                -> (forall env' a. r (a : env') -> r env')  -- ^ unpush +                -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env')  -- ^ unpushN +                -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation +                -> (forall env'. r env' -> r env')  -- ^ scale-many +                -> Expr x env t -> r env +occCountGeneral onehot unpush unpushN alter many = go +  where +    go :: Monoid (r env') => Expr x env' t' -> r env' +    go = \case +      EVar _ _ i -> onehot i (Occ One One) +      ELet _ rhs body -> go rhs <> unpush (go body) +      EPair _ a b -> go a <> go b +      EFst _ e -> go e +      ESnd _ e -> go e +      ENil _ -> mempty +      EInl _ _ e -> go e +      EInr _ _ e -> go e +      ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) +      EBuild1 _ a b -> go a <> many (unpush (go b)) +      EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e)) +      EFold1 _ a b -> many (unpush (unpush (go a))) <> go b +      EConst{} -> mempty +      EIdx0 _ e -> go e +      EIdx1 _ a b -> go a <> go b +      EIdx _ e es -> go e <> foldMap go es +      EOp _ _ e -> go e +      EWith a b -> go a <> unpush (go b) +      EAccum a b e -> go a <> go b <> go e +      EError{} -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 3473131..ba1b756 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -136,6 +136,10 @@ ppExpr' d val = \case    EConst _ ty v -> return $ showString $ case ty of      STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v +  EIdx0 _ e -> do +    e' <- ppExpr' 11 val e +    return $ showParen (d > 10) $ showString "idx0 " . e' +    EIdx1 _ a b -> do      a' <- ppExpr' 9 val a      b' <- ppExpr' 9 val b diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 432b687..78577ee 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -28,7 +28,7 @@ deriving instance Show (Idx env t)  data env :> env' where    WId :: env :> env -  WSink :: env :> (t : env) +  WSink :: forall t env. env :> (t : env)    WCopy :: env :> env' -> (t : env) :> (t : env')    WPop :: (t : env) :> env' -> env :> env'    WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 diff --git a/src/CHAD.hs b/src/CHAD.hs index e209b67..a6dd9ff 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -27,7 +27,6 @@ module CHAD (  ) where  import Data.Bifunctor (first, second) -import Data.Functor.Const  import Data.Kind (Type)  import GHC.TypeLits (Symbol) @@ -242,6 +241,28 @@ letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t  letBinds BTop = id  letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs +type family Vectorise n list where +  Vectorise _ '[] = '[] +  Vectorise n (t : ts) = TArr n t : Vectorise n ts + +vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t) +vectoriseIdx IZ = IZ +vectoriseIdx (IS i) = IS (vectoriseIdx i) + +vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds) +vectorise1Binds _ _ BTop = BTop +vectorise1Binds env n (bs `BPush` (t, e)) = +  let bs' = vectorise1Binds env n bs +      e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n)) +             (subst (\_ t' i -> case splitIdx @env (bindingsBinds bs) i of +                                  Left i1 -> +                                    let i1' = IS (wRaiseAbove (bindingsBinds bs') env @> vectoriseIdx i1) +                                    in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t') i1') +                                                            (EVar ext tIx (WSink .> sinkWithBindings bs' @> n))) +                                  Right i2 -> EVar ext t' (IS (sinkWithBindings bs' @> i2))) +                    e) +  in bs' `BPush` (STArr (SS SZ) t, e') +  type family D1 t where    D1 TNil = TNil    D1 (TPair a b) = TPair (D1 a) (D1 b) @@ -588,9 +609,9 @@ select s@SMerge (DPush des (_, SAccum)) = select s des  select s@SAccum (DPush des (_, SMerge)) = select s des  select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) -sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env) +sD1eEnv :: Descr env sto -> SList STy (D1E env)  sD1eEnv DTop = SNil -sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d) +sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)  d2e :: SList STy env -> SList STy (D2E env)  d2e SNil = SNil @@ -806,13 +827,39 @@ drev des = \case          (subenvNone (select SMerge des))          (ENil ext) +  EBuild1 _ ne e +    -- TODO: use occCountAll to determine which variables from @env are used in +    -- 'e', and promote those to SAccum storage in 'des' +    | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne +    , Ret e0 e1 sub e2 <- drev (des `DPush` (tIx, SMerge)) e +    , let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 -> +    Ret (bconcat (ne0 `BPush` (tIx, ne1)) +                 (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) +        (EBuild1 ext +           (weakenExpr (wStack @(D1E env) (wSinks (bindingsBinds ve0) .> WSink @TIx @ne_binds)) +                       ne1) +           (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of +                               Left ibind -> +                                 let ibind' = WSink +                                              .> wRaiseAbove (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) +                                                             (sD1eEnv des) +                                              .> wRaiseAbove (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0) +                                              @> vectoriseIdx ibind +                                 in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind') +                                                         (EVar ext tIx IZ)) +                               Right IZ -> EVar ext tIx IZ  -- build lambda index argument +                               Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv))) +                  e1)) +        (subenvNone (select SMerge des)) +        _ +    -- These should be the next to be implemented, I think -  EBuild1{} -> err_unsupported "EBuild1" -  EFold1{} -> err_unsupported "EFold1" +  EIdx0{} -> err_unsupported "EIdx0"    EIdx1{} -> err_unsupported "EIdx1" +  EFold1{} -> err_unsupported "EFold1" -  EBuild{} -> err_unsupported "EBuild"    EIdx{} -> err_unsupported "EIdx" +  EBuild{} -> err_unsupported "EBuild"    EWith{} -> err_accum    EAccum{} -> err_accum diff --git a/src/Simplify.hs b/src/Simplify.hs index 39b3afd..af0ca4c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -74,6 +74,7 @@ simplify' = \case    EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e)    EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)    EConst _ t v -> EConst ext t v +  EIdx0 _ e -> EIdx0 ext (simplify' e)    EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)    EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es)    EOp _ op e -> EOp ext op (simplify' e) @@ -105,6 +106,7 @@ hasAdds = \case    EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e    EFold1 _ a b -> hasAdds a || hasAdds b    EConst _ _ _ -> False +  EIdx0 _ e -> hasAdds e    EIdx1 _ a b -> hasAdds a || hasAdds b    EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es)    EOp _ _ e -> hasAdds e | 
