summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-30 22:45:46 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-30 22:45:46 +0200
commit1f7ed2ee02222108684cfde8078e7a182f734a61 (patch)
tree976175ede4ec12a6e4a65d5e45e0b1ee8eeff5e6
parent172887fb577526de92b0653b5d3153114f8ce02a (diff)
WIP Build1
-rw-r--r--src/AST.hs12
-rw-r--r--src/AST/Count.hs106
-rw-r--r--src/AST/Pretty.hs4
-rw-r--r--src/AST/Weaken.hs2
-rw-r--r--src/CHAD.hs59
-rw-r--r--src/Simplify.hs2
6 files changed, 157 insertions, 28 deletions
diff --git a/src/AST.hs b/src/AST.hs
index d9acd99..c191651 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -15,6 +15,7 @@
{-# LANGUAGE EmptyCase #-}
module AST (module AST, module AST.Weaken) where
+import Data.Bifunctor (first)
import Data.Functor.Const
import Data.Kind (Type)
import Data.Int
@@ -55,6 +56,9 @@ deriving instance Show (SScalTy t)
type TIx = TScal TI64
+tIx :: STy TIx
+tIx = STScal STI64
+
type family ScalRep t where
ScalRep TI32 = Int32
ScalRep TI64 = Int64
@@ -92,6 +96,7 @@ data Expr x env t where
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
+ EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
@@ -150,6 +155,7 @@ typeOf = \case
EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
EConst _ t _ -> STScal t
+ EIdx0 _ e | STArr _ t <- typeOf e -> t
EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
EIdx _ e _ | STArr _ t <- typeOf e -> t
EOp _ op _ -> opt2 op
@@ -210,6 +216,7 @@ subst' f w = \case
EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
EConst x t v -> EConst x t v
+ EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
EOp x op e -> EOp x op (subst' f w e)
@@ -254,6 +261,11 @@ idx2int :: Idx env t -> Int
idx2int IZ = 0
idx2int (IS n) = 1 + idx2int n
+splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
+splitIdx SNil i = Right i
+splitIdx (SCons _ _) IZ = Left IZ
+splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
+
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index f66b809..7e70a7d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -1,10 +1,17 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
module AST.Count where
+import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))
import AST
@@ -35,24 +42,81 @@ scaleMany :: Occ -> Occ
scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx = \case
- EVar _ _ i | idx2int i == idx2int idx -> Occ One One
- | otherwise -> mempty
- ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
- EPair _ a b -> occCount idx a <> occCount idx b
- EFst _ e -> occCount idx e
- ESnd _ e -> occCount idx e
- ENil _ -> mempty
- EInl _ _ e -> occCount idx e
- EInr _ _ e -> occCount idx e
- ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
- EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
- EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
- EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
- EConst{} -> mempty
- EIdx1 _ a b -> occCount idx a <> occCount idx b
- EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
- EOp _ _ e -> occCount idx e
- EWith a b -> occCount idx a <> occCount (IS idx) b
- EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e
- EError{} -> mempty
+occCount idx =
+ getConst . occCountGeneral
+ (\i o -> if idx2int i == idx2int idx then Const o else mempty)
+ (\(Const o) -> Const o)
+ (\_ (Const o) -> Const o)
+ (\(Const o1) (Const o2) -> Const (o1 <||> o2))
+ (\(Const o) -> Const (scaleMany o))
+
+
+data OccEnv env where
+ OccEnd :: OccEnv env -- not necessarily top!
+ OccPush :: OccEnv env -> Occ -> OccEnv (t : env)
+
+instance Semigroup (OccEnv env) where
+ OccEnd <> e = e
+ e <> OccEnd = e
+ OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')
+
+instance Monoid (OccEnv env) where
+ mempty = OccEnd
+
+onehotOccEnv :: Idx env t -> Occ -> OccEnv env
+onehotOccEnv IZ v = OccPush OccEnd v
+onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty
+
+(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
+OccEnd <||>! e = e
+e <||>! OccEnd = e
+OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')
+
+scaleManyOccEnv :: OccEnv env -> OccEnv env
+scaleManyOccEnv OccEnd = OccEnd
+scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)
+
+occCountAll :: Expr x env t -> OccEnv env
+occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv
+ where
+ unpush :: OccEnv (t : env) -> OccEnv env
+ unpush (OccPush o _) = o
+ unpush OccEnd = OccEnd
+
+ unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
+ unpushN _ OccEnd = OccEnd
+ unpushN SZ e = e
+ unpushN (SS n) (OccPush e _) = unpushN n e
+
+occCountGeneral :: forall r env t x.
+ (forall env'. Monoid (r env'))
+ => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot
+ -> (forall env' a. r (a : env') -> r env') -- ^ unpush
+ -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN
+ -> (forall env'. r env' -> r env' -> r env') -- ^ alternation
+ -> (forall env'. r env' -> r env') -- ^ scale-many
+ -> Expr x env t -> r env
+occCountGeneral onehot unpush unpushN alter many = go
+ where
+ go :: Monoid (r env') => Expr x env' t' -> r env'
+ go = \case
+ EVar _ _ i -> onehot i (Occ One One)
+ ELet _ rhs body -> go rhs <> unpush (go body)
+ EPair _ a b -> go a <> go b
+ EFst _ e -> go e
+ ESnd _ e -> go e
+ ENil _ -> mempty
+ EInl _ _ e -> go e
+ EInr _ _ e -> go e
+ ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))
+ EBuild1 _ a b -> go a <> many (unpush (go b))
+ EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e))
+ EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
+ EConst{} -> mempty
+ EIdx0 _ e -> go e
+ EIdx1 _ a b -> go a <> go b
+ EIdx _ e es -> go e <> foldMap go es
+ EOp _ _ e -> go e
+ EWith a b -> go a <> unpush (go b)
+ EAccum a b e -> go a <> go b <> go e
+ EError{} -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 3473131..ba1b756 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -136,6 +136,10 @@ ppExpr' d val = \case
EConst _ ty v -> return $ showString $ case ty of
STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
+ EIdx0 _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "idx0 " . e'
+
EIdx1 _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 9 val b
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 432b687..78577ee 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -28,7 +28,7 @@ deriving instance Show (Idx env t)
data env :> env' where
WId :: env :> env
- WSink :: env :> (t : env)
+ WSink :: forall t env. env :> (t : env)
WCopy :: env :> env' -> (t : env) :> (t : env')
WPop :: (t : env) :> env' -> env :> env'
WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
diff --git a/src/CHAD.hs b/src/CHAD.hs
index e209b67..a6dd9ff 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -27,7 +27,6 @@ module CHAD (
) where
import Data.Bifunctor (first, second)
-import Data.Functor.Const
import Data.Kind (Type)
import GHC.TypeLits (Symbol)
@@ -242,6 +241,28 @@ letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+type family Vectorise n list where
+ Vectorise _ '[] = '[]
+ Vectorise n (t : ts) = TArr n t : Vectorise n ts
+
+vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t)
+vectoriseIdx IZ = IZ
+vectoriseIdx (IS i) = IS (vectoriseIdx i)
+
+vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds)
+vectorise1Binds _ _ BTop = BTop
+vectorise1Binds env n (bs `BPush` (t, e)) =
+ let bs' = vectorise1Binds env n bs
+ e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n))
+ (subst (\_ t' i -> case splitIdx @env (bindingsBinds bs) i of
+ Left i1 ->
+ let i1' = IS (wRaiseAbove (bindingsBinds bs') env @> vectoriseIdx i1)
+ in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t') i1')
+ (EVar ext tIx (WSink .> sinkWithBindings bs' @> n)))
+ Right i2 -> EVar ext t' (IS (sinkWithBindings bs' @> i2)))
+ e)
+ in bs' `BPush` (STArr (SS SZ) t, e')
+
type family D1 t where
D1 TNil = TNil
D1 (TPair a b) = TPair (D1 a) (D1 b)
@@ -588,9 +609,9 @@ select s@SMerge (DPush des (_, SAccum)) = select s des
select s@SAccum (DPush des (_, SMerge)) = select s des
select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
-sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env)
+sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
-sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d)
+sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
d2e :: SList STy env -> SList STy (D2E env)
d2e SNil = SNil
@@ -806,13 +827,39 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
+ EBuild1 _ ne e
+ -- TODO: use occCountAll to determine which variables from @env are used in
+ -- 'e', and promote those to SAccum storage in 'des'
+ | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne
+ , Ret e0 e1 sub e2 <- drev (des `DPush` (tIx, SMerge)) e
+ , let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 ->
+ Ret (bconcat (ne0 `BPush` (tIx, ne1))
+ (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))
+ (EBuild1 ext
+ (weakenExpr (wStack @(D1E env) (wSinks (bindingsBinds ve0) .> WSink @TIx @ne_binds))
+ ne1)
+ (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of
+ Left ibind ->
+ let ibind' = WSink
+ .> wRaiseAbove (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0))
+ (sD1eEnv des)
+ .> wRaiseAbove (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)
+ @> vectoriseIdx ibind
+ in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind')
+ (EVar ext tIx IZ))
+ Right IZ -> EVar ext tIx IZ -- build lambda index argument
+ Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv)))
+ e1))
+ (subenvNone (select SMerge des))
+ _
+
-- These should be the next to be implemented, I think
- EBuild1{} -> err_unsupported "EBuild1"
- EFold1{} -> err_unsupported "EFold1"
+ EIdx0{} -> err_unsupported "EIdx0"
EIdx1{} -> err_unsupported "EIdx1"
+ EFold1{} -> err_unsupported "EFold1"
- EBuild{} -> err_unsupported "EBuild"
EIdx{} -> err_unsupported "EIdx"
+ EBuild{} -> err_unsupported "EBuild"
EWith{} -> err_accum
EAccum{} -> err_accum
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 39b3afd..af0ca4c 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -74,6 +74,7 @@ simplify' = \case
EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e)
EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)
EConst _ t v -> EConst ext t v
+ EIdx0 _ e -> EIdx0 ext (simplify' e)
EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es)
EOp _ op e -> EOp ext op (simplify' e)
@@ -105,6 +106,7 @@ hasAdds = \case
EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e
EFold1 _ a b -> hasAdds a || hasAdds b
EConst _ _ _ -> False
+ EIdx0 _ e -> hasAdds e
EIdx1 _ a b -> hasAdds a || hasAdds b
EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es)
EOp _ _ e -> hasAdds e