summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs9
-rw-r--r--src/AST/Count.hs21
-rw-r--r--src/AST/Pretty.hs8
-rw-r--r--src/AST/Weaken/Auto.hs11
-rw-r--r--src/CHAD.hs238
-rw-r--r--src/Data.hs7
-rw-r--r--src/Simplify.hs6
7 files changed, 212 insertions, 88 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 15e6d43..6c90be3 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -92,6 +92,7 @@ data Expr x env t where
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
@@ -102,7 +103,7 @@ data Expr x env t where
-- accumulation effect
EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t))
- EAccum :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum n t) -> Expr x env TNil
+ EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
-- partiality
EError :: STy a -> String -> Expr x env a
@@ -152,6 +153,7 @@ typeOf = \case
EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
EBuild _ es e -> STArr (vecLength es) (typeOf e)
EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EUnit _ e -> STArr SZ (typeOf e)
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -160,7 +162,7 @@ typeOf = \case
EOp _ op _ -> opt2 op
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ -> STNil
+ EAccum1 _ _ _ -> STNil
EError t _ -> t
@@ -214,13 +216,14 @@ subst' f w = \case
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ EUnit x e -> EUnit x (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
EOp x op e -> EOp x op (subst' f w e)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EAccum1 e1 e2 e3 -> EAccum1 (subst' f w e1) (subst' f w e2) (subst' f w e3)
EError t s -> EError t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 7e70a7d..289c1fb 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -76,17 +76,17 @@ scaleManyOccEnv :: OccEnv env -> OccEnv env
scaleManyOccEnv OccEnd = OccEnd
scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)
+occEnvPop :: OccEnv (t : env) -> OccEnv env
+occEnvPop (OccPush o _) = o
+occEnvPop OccEnd = OccEnd
+
occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv
+occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv
where
- unpush :: OccEnv (t : env) -> OccEnv env
- unpush (OccPush o _) = o
- unpush OccEnd = OccEnd
-
- unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
- unpushN _ OccEnd = OccEnd
- unpushN SZ e = e
- unpushN (SS n) (OccPush e _) = unpushN n e
+ occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
+ occEnvPopN _ OccEnd = OccEnd
+ occEnvPopN SZ e = e
+ occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e
occCountGeneral :: forall r env t x.
(forall env'. Monoid (r env'))
@@ -112,11 +112,12 @@ occCountGeneral onehot unpush unpushN alter many = go
EBuild1 _ a b -> go a <> many (unpush (go b))
EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e))
EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
+ EUnit _ e -> go e
EConst{} -> mempty
EIdx0 _ e -> go e
EIdx1 _ a b -> go a <> go b
EIdx _ e es -> go e <> foldMap go es
EOp _ _ e -> go e
EWith a b -> go a <> unpush (go b)
- EAccum a b e -> go a <> go b <> go e
+ EAccum1 a b e -> go a <> go b <> go e
EError{} -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index ba1b756..1dc9dd3 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -133,6 +133,10 @@ ppExpr' d val = \case
showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
. showString ") " . b'
+ EUnit _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "unit " . e'
+
EConst _ ty v -> return $ showString $ case ty of
STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
@@ -176,12 +180,12 @@ ppExpr' d val = \case
showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
. e2' . showString ")"
- EAccum e1 e2 e3 -> do
+ EAccum1 e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ showParen (d > 10) $
- showString "accum " . e1' . showString " " . e2' . showString " " . e3'
+ showString "accum1 " . e1' . showString " " . e2' . showString " " . e3'
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
index 0bf5780..444c540 100644
--- a/src/AST/Weaken/Auto.hs
+++ b/src/AST/Weaken/Auto.hs
@@ -18,7 +18,7 @@
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module AST.Weaken.Auto (
autoWeak,
- ($..), auto,
+ (&.), auto, auto1,
Layout(..),
) where
@@ -56,9 +56,12 @@ instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel na
auto :: KnownListSpine list => SList (Const ()) list
auto = knownListSpine
-infixr $..
-($..) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2)
-($..) = ssegmentsAppend
+auto1 :: SList (Const ()) '[t]
+auto1 = Const () `SCons` SNil
+
+infixr &.
+(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2)
+(&.) = ssegmentsAppend
where
ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b)
ssegmentsAppend SSegNil l2 = l2
diff --git a/src/CHAD.hs b/src/CHAD.hs
index aedda5b..9786e1e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -245,15 +245,45 @@ type family Vectorise n list where
Vectorise _ '[] = '[]
Vectorise n (t : ts) = TArr n t : Vectorise n ts
+vectoriseEnv :: SNat n -> SList STy env -> SList STy (Vectorise n env)
+vectoriseEnv _ SNil = SNil
+vectoriseEnv n (t `SCons` env) = STArr n t `SCons` vectoriseEnv n env
+
vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t)
vectoriseIdx IZ = IZ
vectoriseIdx (IS i) = IS (vectoriseIdx i)
+vectoriseExpr :: forall prefix binds env t f.
+ SList f prefix
+ -> SList STy binds
+ -> SList f env
+ -> Ex (Append prefix (Append binds env)) t
+ -> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t
+vectoriseExpr prefix binds env =
+ let wTarget :: Layout ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e
+ -> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env)
+ wTarget layout =
+ autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env)
+ layout
+ (#ix :++: #pre :++: #vbinds :++: #env)
+ in
+ subst $ \_ t i ->
+ case splitIdx @(Append binds env) prefix i of
+ Left iPre -> EVar ext t (wTarget #pre @> iPre)
+ Right i' ->
+ case splitIdx @env binds i' of
+ Left iBinds ->
+ EIdx0 ext $
+ EIdx1 ext (EVar ext (STArr (SS SZ) t) (wTarget #vbinds @> vectoriseIdx iBinds))
+ (EVar ext tIx IZ)
+ Right iEnv -> EVar ext t (wTarget #env @> iEnv)
+
vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds)
vectorise1Binds _ _ BTop = BTop
vectorise1Binds env n (bs `BPush` (t, e)) =
let bs' = vectorise1Binds env n bs
e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n))
+ -- TODO: use vectoriseExpr here
(subst (\_ t' i -> case splitIdx @env (bindingsBinds bs) i of
Left i1 ->
let i1' = IS (wRaiseAbove (bindingsBinds bs') env @> vectoriseIdx i1)
@@ -285,7 +315,7 @@ type family D2s t where
D2s TBool = TNil
type family D2Ac t where
- D2Ac (TArr n t) = TAccum n t
+ D2Ac (TArr n t) = TAccum n (D2 t)
type family D1E env where
D1E '[] = '[]
@@ -327,7 +357,7 @@ d2 (STScal t) = case t of
d2 STAccum{} = error "Accumulators not allowed in input program"
d2ac :: STy t -> STy (D2Ac t)
-d2ac (STArr n t) = STAccum n t
+d2ac (STArr n t) = STAccum n (d2 t)
d2ac _ = error "Only arrays may appear in the accumulator environment"
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
@@ -346,7 +376,7 @@ zero :: STy t -> Ex env (D2 t)
zero STNil = ENil ext
zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
-zero STArr{} = error "TODO arrays"
+zero (STArr n t) = EBuild ext (vecGenerate n (\_ -> EConst ext STI64 0)) (zero t)
zero (STScal t) = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -374,7 +404,10 @@ plus (STEither t1 t2) a b =
(ECase ext (EVar ext t (IS IZ))
(EError t "plus r+l")
(EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus STArr{} _ _ = error "TODO arrays"
+plus STArr{} _ _ = error "TODO plus on arrays"
+ -- 'zero' creates an empty array; this should be a new primitive that
+ -- (operationally) intelligently memcpy's the non-overlapping part and does
+ -- a parallel add on the overlapping part.
plus (STScal t) a b = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -426,7 +459,7 @@ accumPromote :: forall dt env sto proxy r.
-> (forall shbinds.
SList STy shbinds
-> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
- :> Append envPro (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))
+ :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum"))))
-- ^ A weakening that converts a computation in the
-- revised environment to one in the original environment
-- extended with some accumulators.
@@ -443,11 +476,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
mergesub
envpro
(\shbinds ->
- autoWeak (#pro envpro $.. #d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum descr)))
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
(#pro :++: #d :++: #shb :++: #acc :++: #tl)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum storepl)))
+ .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum storepl)))
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
@@ -455,7 +488,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
(STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero ->
k (storepl `DPush` (t, SAccum))
(SENo mergesub)
- (STAccum arrn arrt `SCons` envpro)
+ (STArr arrn arrt `SCons` envpro)
(\(shbinds :: SList _ shbinds) ->
let shbindsC = slistMap (\_ -> Const ()) shbinds
in
@@ -467,30 +500,49 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum arrn arrt) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
- -- Used "merge" values must either _be_ an array (and hence be caught by
- -- the prior case), or contain _no_ arrays at all (TODO: generalise this)
- (_, SMerge, Occ _ c) | c > Zero, containsTArr t ->
- error $ "Closure variable of 'build'-like thing contains a composite type containing an array: " ++ show t
-
- -- What's left are normal "merge" values that don't contain arrays; those
- -- remain as-is
- (_, SMerge, _) ->
- k (storepl `DPush` (t, SMerge))
- (SEYes mergesub)
- envpro
- wf
- where
- containsTArr :: STy t' -> Bool
- containsTArr = \case
- STNil -> False
- STPair a b -> containsTArr a || containsTArr b
- STEither a b -> containsTArr a || containsTArr b
- STArr{} -> True
- STScal{} -> False
- STAccum{} -> error "An accumulator in merge storage?"
+ -- Used "merge" values must be an array, so reject everything else. (TODO: generalise this)
+ (_, SMerge, Occ _ c)
+ | c > Zero ->
+ error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t
+ | otherwise ->
+ k (storepl `DPush` (t, SMerge))
+ (SEYes mergesub)
+ envpro
+ wf
+ -- where
+ -- containsTArr :: STy t' -> Bool
+ -- containsTArr = \case
+ -- STNil -> False
+ -- STPair a b -> containsTArr a || containsTArr b
+ -- STEither a b -> containsTArr a || containsTArr b
+ -- STArr{} -> True
+ -- STScal{} -> False
+ -- STAccum{} -> error "An accumulator in merge storage?"
+
+type family InvTup core env where
+ InvTup core '[] = core
+ InvTup core (t : ts) = InvTup (TPair core t) ts
+
+makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators SNil e = e
+makeAccumulators (STArr n t `SCons` envpro) e =
+ makeAccumulators envpro $
+ EWith (zero (STArr n t)) e
+makeAccumulators (t `SCons` _) _ = error $ "makeAccumulators: Not only arrays in envpro: " ++ show t
+
+uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
+uninvertTup SNil _ e = EPair ext e (ENil ext)
+uninvertTup (t `SCons` list) tcore e =
+ ELet ext (uninvertTup list (STPair tcore t) e) $
+ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
+ in EPair ext
+ (EFst ext (EFst ext (EVar ext recT IZ)))
+ (EPair ext
+ (ESnd ext (EVar ext recT IZ))
+ (ESnd ext (EFst ext (EVar ext recT IZ))))
-- | @env'@ is a subset of @env@: each element of @env@ is either included in
-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
@@ -526,6 +578,10 @@ subList SNil SETop = SNil
subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
subList (SCons _ xs) (SENo sub) = subList xs sub
+subenvAll :: SList f env -> Subenv env env
+subenvAll SNil = SETop
+subenvAll (SCons _ env) = SEYes (subenvAll env)
+
subenvNone :: SList f env -> Subenv env '[]
subenvNone SNil = SETop
subenvNone (SCons _ env) = SENo (subenvNone env)
@@ -582,6 +638,11 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e =
in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
+assertSubenvEmpty :: Subenv env env' -> env' :~: '[]
+assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
+assertSubenvEmpty SETop = Refl
+assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
+
popFromScope
:: Descr env0 sto
-> STy a
@@ -617,7 +678,7 @@ rebaseRetPair :: forall env b1 b2 env0 sto t f.
rebaseRetPair descr b1 b2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
RetPair p sub (weakenExpr (autoWeak
- (#d (auto @'[D2 t]) $.. #b2 b2 $.. #b1 b1 $.. #tl (d2ace (select SAccum descr)))
+ (#d (auto1 @(D2 t)) &. #b2 b2 &. #b1 b1 &. #tl (d2ace (select SAccum descr)))
(#d :++: (#b2 :++: #tl))
(#d :++: ((#b2 :++: #b1) :++: #tl)))
d)
@@ -759,10 +820,10 @@ drev des = \case
(weakenExpr wbody0' body1)
subBoth
(ELet ext
- (weakenExpr (autoWeak (#d (auto @'[D2 t])
- $.. #body (bindingsBinds body0)
- $.. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0))
- $.. #tl (d2ace (select SAccum des)))
+ (weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #body (bindingsBinds body0)
+ &. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0))
+ &. #tl (d2ace (select SAccum des)))
(#d :++: #body :++: #tl)
(#d :++: #body :++: #rhs :++: #tl))
body2') $
@@ -867,12 +928,12 @@ drev des = \case
ELet ext
(EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
ELet ext
- (weakenExpr (autoWeak (#d (auto @'[D2 t])
- $.. #a0 (bindingsBinds a0)
- $.. #prea0 prerebinds
- $.. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
- $.. #binds (tPrimal `SCons` bindingsBinds e0)
- $.. #tl (d2ace (select SAccum des)))
+ (weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #a0 (bindingsBinds a0)
+ &. #prea0 prerebinds
+ &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #binds (tPrimal `SCons` bindingsBinds e0)
+ &. #tl (d2ace (select SAccum des)))
(#d :++: #a0 :++: #tl)
(#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl))
a2') $
@@ -886,12 +947,12 @@ drev des = \case
ELet ext
(EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
ELet ext
- (weakenExpr (autoWeak (#d (auto @'[D2 t])
- $.. #b0 (bindingsBinds b0)
- $.. #preb0 prerebinds
- $.. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
- $.. #binds (tPrimal `SCons` bindingsBinds e0)
- $.. #tl (d2ace (select SAccum des)))
+ (weakenExpr (autoWeak (#d (auto1 @(D2 t))
+ &. #b0 (bindingsBinds b0)
+ &. #preb0 prerebinds
+ &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil)
+ &. #binds (tPrimal `SCons` bindingsBinds e0)
+ &. #tl (d2ace (select SAccum des)))
(#d :++: #b0 :++: #tl)
(#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl))
b2') $
@@ -937,38 +998,81 @@ drev des = \case
(ENil ext)
EBuild1 _ ne e
- -- TODO: use occCountAll to determine which variables from @env are used in
- -- 'e', and promote those to SAccum storage in 'des'
- | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne
- , Ret e0 e1 sub e2 <- drev (des `DPush` (tIx, SMerge)) e
- , let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 ->
+ | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne ->
+ accumPromote (typeOf e) des (occEnvPop (occCountAll e)) $ \vdes proSub envPro wPro ->
+ case drev (vdes `DPush` (tIx, SMerge)) e of { Ret e0 e1 sub e2 ->
+ case assertSubenvEmpty sub of { Refl ->
+ case assertSubenvEmpty proSub of { Refl ->
+ let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 in
Ret (bconcat (ne0 `BPush` (tIx, ne1))
(fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))
(EBuild1 ext
(weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)
- $.. #i (auto @'[TIx])
- $.. #ne0 (bindingsBinds ne0)
- $.. #tl (sD1eEnv des))
- (#ne0 :++: #tl)
- ((#ve0 :++: #i :++: #ne0) :++: #tl))
- ne1)
+ &. #binds (tIx `SCons` bindingsBinds ne0)
+ &. #tl (sD1eEnv des))
+ #binds
+ ((#ve0 :++: #binds) :++: #tl))
+ (EVar ext tIx IZ))
(subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of
Left ibind ->
- let ibind' = WSink
- .> wRaiseAbove (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0))
- (sD1eEnv des)
- .> wRaiseAbove (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)
- @> vectoriseIdx ibind
+ let ibind' =
+ autoWeak (#ix (auto1 @TIx)
+ &. #ve0 (bindingsBinds ve0)
+ &. #binds (tIx `SCons` bindingsBinds ne0)
+ &. #tl (sD1eEnv des))
+ #ve0
+ (#ix :++: (#ve0 :++: #binds) :++: #tl)
+ @> vectoriseIdx ibind
in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind')
(EVar ext tIx IZ))
Right IZ -> EVar ext tIx IZ -- build lambda index argument
Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv)))
e1))
- (subenvNone (select SMerge des))
- _
+ nsub
+ (ELet ext
+ (makeAccumulators envPro $
+ EBuild1 ext
+ (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)
+ &. #pro (d2ace envPro)
+ &. #d (auto1 @(D2 t))
+ &. #binds (tIx `SCons` bindingsBinds ne0)
+ &. #tl (d2ace (select SAccum des)))
+ #binds
+ (#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl))
+ (EVar ext tIx IZ))
+ -- TODO: use vectoriseExpr
+ (_ (weakenExpr (wPro (bindingsBinds e0)) e2))) $
+ ELet ext (ENil ext) $
+ weakenExpr (autoWeak (#nil (auto1 @TNil)
+ &. #d (auto1 @(D2 t))
+ &. #nilarr (auto1 @(TArr (S Z) TNil))
+ &. #ve0 (bindingsBinds ve0)
+ &. #n (auto1 @TIx)
+ &. #binds (bindingsBinds ne0)
+ &. #tl (d2ace (select SAccum des)))
+ (#nil :++: #binds :++: #tl)
+ (#nil :++: #nilarr :++: #d :++: (#ve0 :++: #n :++: #binds) :++: #tl))
+ ne2)
+ }}}
+
+ EUnit _ e
+ | Ret e0 e1 sub e2 <- drev des e ->
+ Ret e0
+ (EUnit ext e1)
+ sub
+ (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $
+ weakenExpr (WCopy WSink) e2)
+
+ EIdx0 _ e
+ | Ret e0 e1 sub e2 <- drev des e
+ , STArr _ t <- typeOf e ->
+ Ret e0
+ (EIdx0 ext e1)
+ sub
+ (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $
+ weakenExpr (WCopy WSink) e2)
-- These should be the next to be implemented, I think
- EIdx0{} -> err_unsupported "EIdx0"
EIdx1{} -> err_unsupported "EIdx1"
EFold1{} -> err_unsupported "EFold1"
@@ -976,7 +1080,7 @@ drev des = \case
EBuild{} -> err_unsupported "EBuild"
EWith{} -> err_accum
- EAccum{} -> err_accum
+ EAccum1{} -> err_accum
where
err_accum = error "Accumulator operations unsupported in the source program"
diff --git a/src/Data.hs b/src/Data.hs
index a3f4c3c..8c39c6c 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -52,3 +52,10 @@ deriving instance Traversable (Vec n)
vecLength :: Vec n t -> SNat n
vecLength VNil = SZ
vecLength (_ :< v) = SS (vecLength v)
+
+vecGenerate :: SNat n -> (forall i. SNat i -> t) -> Vec n t
+vecGenerate = \n f -> go n f SZ
+ where
+ go :: SNat n -> (forall i. SNat i -> t) -> SNat i' -> Vec n t
+ go SZ _ _ = VNil
+ go (SS n) f i = f i :< go n f (SS i)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index af0ca4c..f2fc54a 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -73,13 +73,14 @@ simplify' = \case
EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b)
EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e)
EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)
+ EUnit _ e -> EUnit ext (simplify' e)
EConst _ t v -> EConst ext t v
EIdx0 _ e -> EIdx0 ext (simplify' e)
EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es)
EOp _ op e -> EOp ext op (simplify' e)
EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)
- EAccum e1 e2 e3 -> EAccum (simplify' e1) (simplify' e2) (simplify' e3)
+ EAccum1 e1 e2 e3 -> EAccum1 (simplify' e1) (simplify' e2) (simplify' e3)
EError t s -> EError t s
cheapExpr :: Expr x env t -> Bool
@@ -105,13 +106,14 @@ hasAdds = \case
EBuild1 _ a b -> hasAdds a || hasAdds b
EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e
EFold1 _ a b -> hasAdds a || hasAdds b
+ EUnit _ e -> hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e
EIdx1 _ a b -> hasAdds a || hasAdds b
EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es)
EOp _ _ e -> hasAdds e
EWith a b -> hasAdds a || hasAdds b
- EAccum _ _ _ -> True
+ EAccum1 _ _ _ -> True
EError _ _ -> False
checkAccumInScope :: SList STy env -> Bool