diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 22:45:46 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 22:45:46 +0200 |
commit | 1f7ed2ee02222108684cfde8078e7a182f734a61 (patch) | |
tree | 976175ede4ec12a6e4a65d5e45e0b1ee8eeff5e6 | |
parent | 172887fb577526de92b0653b5d3153114f8ce02a (diff) |
WIP Build1
-rw-r--r-- | src/AST.hs | 12 | ||||
-rw-r--r-- | src/AST/Count.hs | 106 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 4 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 2 | ||||
-rw-r--r-- | src/CHAD.hs | 59 | ||||
-rw-r--r-- | src/Simplify.hs | 2 |
6 files changed, 157 insertions, 28 deletions
@@ -15,6 +15,7 @@ {-# LANGUAGE EmptyCase #-} module AST (module AST, module AST.Weaken) where +import Data.Bifunctor (first) import Data.Functor.Const import Data.Kind (Type) import Data.Int @@ -55,6 +56,9 @@ deriving instance Show (SScalTy t) type TIx = TScal TI64 +tIx :: STy TIx +tIx = STScal STI64 + type family ScalRep t where ScalRep TI32 = Int32 ScalRep TI64 = Int64 @@ -92,6 +96,7 @@ data Expr x env t where -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) + EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t @@ -150,6 +155,7 @@ typeOf = \case EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t EConst _ t _ -> STScal t + EIdx0 _ e | STArr _ t <- typeOf e -> t EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t EIdx _ e _ | STArr _ t <- typeOf e -> t EOp _ op _ -> opt2 op @@ -210,6 +216,7 @@ subst' f w = \case EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e) EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) EConst x t v -> EConst x t v + EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) EOp x op e -> EOp x op (subst' f w e) @@ -254,6 +261,11 @@ idx2int :: Idx env t -> Int idx2int IZ = 0 idx2int (IS n) = 1 + idx2int n +splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) +splitIdx SNil i = Right i +splitIdx (SCons _ _) IZ = Left IZ +splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) + class KnownScalTy t where knownScalTy :: SScalTy t instance KnownScalTy TI32 where knownScalTy = STI32 instance KnownScalTy TI64 where knownScalTy = STI64 diff --git a/src/AST/Count.hs b/src/AST/Count.hs index f66b809..7e70a7d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -1,10 +1,17 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} module AST.Count where +import Data.Functor.Const import GHC.Generics (Generic, Generically(..)) import AST @@ -35,24 +42,81 @@ scaleMany :: Occ -> Occ scaleMany (Occ l _) = Occ l Many occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = \case - EVar _ _ i | idx2int i == idx2int idx -> Occ One One - | otherwise -> mempty - ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body - EPair _ a b -> occCount idx a <> occCount idx b - EFst _ e -> occCount idx e - ESnd _ e -> occCount idx e - ENil _ -> mempty - EInl _ _ e -> occCount idx e - EInr _ _ e -> occCount idx e - ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b) - EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b) - EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e) - EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b - EConst{} -> mempty - EIdx1 _ a b -> occCount idx a <> occCount idx b - EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es - EOp _ _ e -> occCount idx e - EWith a b -> occCount idx a <> occCount (IS idx) b - EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e - EError{} -> mempty +occCount idx = + getConst . occCountGeneral + (\i o -> if idx2int i == idx2int idx then Const o else mempty) + (\(Const o) -> Const o) + (\_ (Const o) -> Const o) + (\(Const o1) (Const o2) -> Const (o1 <||> o2)) + (\(Const o) -> Const (scaleMany o)) + + +data OccEnv env where + OccEnd :: OccEnv env -- not necessarily top! + OccPush :: OccEnv env -> Occ -> OccEnv (t : env) + +instance Semigroup (OccEnv env) where + OccEnd <> e = e + e <> OccEnd = e + OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') + +instance Monoid (OccEnv env) where + mempty = OccEnd + +onehotOccEnv :: Idx env t -> Occ -> OccEnv env +onehotOccEnv IZ v = OccPush OccEnd v +onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty + +(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env +OccEnd <||>! e = e +e <||>! OccEnd = e +OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') + +scaleManyOccEnv :: OccEnv env -> OccEnv env +scaleManyOccEnv OccEnd = OccEnd +scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) + +occCountAll :: Expr x env t -> OccEnv env +occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv + where + unpush :: OccEnv (t : env) -> OccEnv env + unpush (OccPush o _) = o + unpush OccEnd = OccEnd + + unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env + unpushN _ OccEnd = OccEnd + unpushN SZ e = e + unpushN (SS n) (OccPush e _) = unpushN n e + +occCountGeneral :: forall r env t x. + (forall env'. Monoid (r env')) + => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot + -> (forall env' a. r (a : env') -> r env') -- ^ unpush + -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN + -> (forall env'. r env' -> r env' -> r env') -- ^ alternation + -> (forall env'. r env' -> r env') -- ^ scale-many + -> Expr x env t -> r env +occCountGeneral onehot unpush unpushN alter many = go + where + go :: Monoid (r env') => Expr x env' t' -> r env' + go = \case + EVar _ _ i -> onehot i (Occ One One) + ELet _ rhs body -> go rhs <> unpush (go body) + EPair _ a b -> go a <> go b + EFst _ e -> go e + ESnd _ e -> go e + ENil _ -> mempty + EInl _ _ e -> go e + EInr _ _ e -> go e + ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) + EBuild1 _ a b -> go a <> many (unpush (go b)) + EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e)) + EFold1 _ a b -> many (unpush (unpush (go a))) <> go b + EConst{} -> mempty + EIdx0 _ e -> go e + EIdx1 _ a b -> go a <> go b + EIdx _ e es -> go e <> foldMap go es + EOp _ _ e -> go e + EWith a b -> go a <> unpush (go b) + EAccum a b e -> go a <> go b <> go e + EError{} -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 3473131..ba1b756 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -136,6 +136,10 @@ ppExpr' d val = \case EConst _ ty v -> return $ showString $ case ty of STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v + EIdx0 _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "idx0 " . e' + EIdx1 _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 9 val b diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 432b687..78577ee 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -28,7 +28,7 @@ deriving instance Show (Idx env t) data env :> env' where WId :: env :> env - WSink :: env :> (t : env) + WSink :: forall t env. env :> (t : env) WCopy :: env :> env' -> (t : env) :> (t : env') WPop :: (t : env) :> env' -> env :> env' WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 diff --git a/src/CHAD.hs b/src/CHAD.hs index e209b67..a6dd9ff 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -27,7 +27,6 @@ module CHAD ( ) where import Data.Bifunctor (first, second) -import Data.Functor.Const import Data.Kind (Type) import GHC.TypeLits (Symbol) @@ -242,6 +241,28 @@ letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t letBinds BTop = id letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs +type family Vectorise n list where + Vectorise _ '[] = '[] + Vectorise n (t : ts) = TArr n t : Vectorise n ts + +vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t) +vectoriseIdx IZ = IZ +vectoriseIdx (IS i) = IS (vectoriseIdx i) + +vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds) +vectorise1Binds _ _ BTop = BTop +vectorise1Binds env n (bs `BPush` (t, e)) = + let bs' = vectorise1Binds env n bs + e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n)) + (subst (\_ t' i -> case splitIdx @env (bindingsBinds bs) i of + Left i1 -> + let i1' = IS (wRaiseAbove (bindingsBinds bs') env @> vectoriseIdx i1) + in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t') i1') + (EVar ext tIx (WSink .> sinkWithBindings bs' @> n))) + Right i2 -> EVar ext t' (IS (sinkWithBindings bs' @> i2))) + e) + in bs' `BPush` (STArr (SS SZ) t, e') + type family D1 t where D1 TNil = TNil D1 (TPair a b) = TPair (D1 a) (D1 b) @@ -588,9 +609,9 @@ select s@SMerge (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) -sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env) +sD1eEnv :: Descr env sto -> SList STy (D1E env) sD1eEnv DTop = SNil -sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d) +sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) d2e :: SList STy env -> SList STy (D2E env) d2e SNil = SNil @@ -806,13 +827,39 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) + EBuild1 _ ne e + -- TODO: use occCountAll to determine which variables from @env are used in + -- 'e', and promote those to SAccum storage in 'des' + | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne + , Ret e0 e1 sub e2 <- drev (des `DPush` (tIx, SMerge)) e + , let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 -> + Ret (bconcat (ne0 `BPush` (tIx, ne1)) + (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) + (EBuild1 ext + (weakenExpr (wStack @(D1E env) (wSinks (bindingsBinds ve0) .> WSink @TIx @ne_binds)) + ne1) + (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of + Left ibind -> + let ibind' = WSink + .> wRaiseAbove (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) + (sD1eEnv des) + .> wRaiseAbove (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0) + @> vectoriseIdx ibind + in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind') + (EVar ext tIx IZ)) + Right IZ -> EVar ext tIx IZ -- build lambda index argument + Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv))) + e1)) + (subenvNone (select SMerge des)) + _ + -- These should be the next to be implemented, I think - EBuild1{} -> err_unsupported "EBuild1" - EFold1{} -> err_unsupported "EFold1" + EIdx0{} -> err_unsupported "EIdx0" EIdx1{} -> err_unsupported "EIdx1" + EFold1{} -> err_unsupported "EFold1" - EBuild{} -> err_unsupported "EBuild" EIdx{} -> err_unsupported "EIdx" + EBuild{} -> err_unsupported "EBuild" EWith{} -> err_accum EAccum{} -> err_accum diff --git a/src/Simplify.hs b/src/Simplify.hs index 39b3afd..af0ca4c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -74,6 +74,7 @@ simplify' = \case EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e) EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b) EConst _ t v -> EConst ext t v + EIdx0 _ e -> EIdx0 ext (simplify' e) EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es) EOp _ op e -> EOp ext op (simplify' e) @@ -105,6 +106,7 @@ hasAdds = \case EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e EFold1 _ a b -> hasAdds a || hasAdds b EConst _ _ _ -> False + EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es) EOp _ _ e -> hasAdds e |