summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-30 22:45:46 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-30 22:45:46 +0200
commit1f7ed2ee02222108684cfde8078e7a182f734a61 (patch)
tree976175ede4ec12a6e4a65d5e45e0b1ee8eeff5e6 /src/CHAD.hs
parent172887fb577526de92b0653b5d3153114f8ce02a (diff)
WIP Build1
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs59
1 files changed, 53 insertions, 6 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index e209b67..a6dd9ff 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -27,7 +27,6 @@ module CHAD (
) where
import Data.Bifunctor (first, second)
-import Data.Functor.Const
import Data.Kind (Type)
import GHC.TypeLits (Symbol)
@@ -242,6 +241,28 @@ letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+type family Vectorise n list where
+ Vectorise _ '[] = '[]
+ Vectorise n (t : ts) = TArr n t : Vectorise n ts
+
+vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t)
+vectoriseIdx IZ = IZ
+vectoriseIdx (IS i) = IS (vectoriseIdx i)
+
+vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds)
+vectorise1Binds _ _ BTop = BTop
+vectorise1Binds env n (bs `BPush` (t, e)) =
+ let bs' = vectorise1Binds env n bs
+ e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n))
+ (subst (\_ t' i -> case splitIdx @env (bindingsBinds bs) i of
+ Left i1 ->
+ let i1' = IS (wRaiseAbove (bindingsBinds bs') env @> vectoriseIdx i1)
+ in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t') i1')
+ (EVar ext tIx (WSink .> sinkWithBindings bs' @> n)))
+ Right i2 -> EVar ext t' (IS (sinkWithBindings bs' @> i2)))
+ e)
+ in bs' `BPush` (STArr (SS SZ) t, e')
+
type family D1 t where
D1 TNil = TNil
D1 (TPair a b) = TPair (D1 a) (D1 b)
@@ -588,9 +609,9 @@ select s@SMerge (DPush des (_, SAccum)) = select s des
select s@SAccum (DPush des (_, SMerge)) = select s des
select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
-sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env)
+sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
-sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d)
+sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
d2e :: SList STy env -> SList STy (D2E env)
d2e SNil = SNil
@@ -806,13 +827,39 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
+ EBuild1 _ ne e
+ -- TODO: use occCountAll to determine which variables from @env are used in
+ -- 'e', and promote those to SAccum storage in 'des'
+ | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne
+ , Ret e0 e1 sub e2 <- drev (des `DPush` (tIx, SMerge)) e
+ , let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 ->
+ Ret (bconcat (ne0 `BPush` (tIx, ne1))
+ (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))
+ (EBuild1 ext
+ (weakenExpr (wStack @(D1E env) (wSinks (bindingsBinds ve0) .> WSink @TIx @ne_binds))
+ ne1)
+ (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of
+ Left ibind ->
+ let ibind' = WSink
+ .> wRaiseAbove (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0))
+ (sD1eEnv des)
+ .> wRaiseAbove (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)
+ @> vectoriseIdx ibind
+ in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind')
+ (EVar ext tIx IZ))
+ Right IZ -> EVar ext tIx IZ -- build lambda index argument
+ Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv)))
+ e1))
+ (subenvNone (select SMerge des))
+ _
+
-- These should be the next to be implemented, I think
- EBuild1{} -> err_unsupported "EBuild1"
- EFold1{} -> err_unsupported "EFold1"
+ EIdx0{} -> err_unsupported "EIdx0"
EIdx1{} -> err_unsupported "EIdx1"
+ EFold1{} -> err_unsupported "EFold1"
- EBuild{} -> err_unsupported "EBuild"
EIdx{} -> err_unsupported "EIdx"
+ EBuild{} -> err_unsupported "EBuild"
EWith{} -> err_accum
EAccum{} -> err_accum