diff options
-rw-r--r-- | src/AST.hs | 9 | ||||
-rw-r--r-- | src/AST/Count.hs | 21 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 8 | ||||
-rw-r--r-- | src/AST/Weaken/Auto.hs | 11 | ||||
-rw-r--r-- | src/CHAD.hs | 238 | ||||
-rw-r--r-- | src/Data.hs | 7 | ||||
-rw-r--r-- | src/Simplify.hs | 6 |
7 files changed, 212 insertions, 88 deletions
@@ -92,6 +92,7 @@ data Expr x env t where EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -102,7 +103,7 @@ data Expr x env t where -- accumulation effect EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t)) - EAccum :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum n t) -> Expr x env TNil + EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil -- partiality EError :: STy a -> String -> Expr x env a @@ -152,6 +153,7 @@ typeOf = \case EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) EBuild _ es e -> STArr (vecLength es) (typeOf e) EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + EUnit _ e -> STArr SZ (typeOf e) EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -160,7 +162,7 @@ typeOf = \case EOp _ op _ -> opt2 op EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ -> STNil + EAccum1 _ _ _ -> STNil EError t _ -> t @@ -214,13 +216,14 @@ subst' f w = \case EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e) EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) + EUnit x e -> EUnit x (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) EOp x op e -> EOp x op (subst' f w e) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum1 e1 e2 e3 -> EAccum1 (subst' f w e1) (subst' f w e2) (subst' f w e3) EError t s -> EError t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 7e70a7d..289c1fb 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -76,17 +76,17 @@ scaleManyOccEnv :: OccEnv env -> OccEnv env scaleManyOccEnv OccEnd = OccEnd scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) +occEnvPop :: OccEnv (t : env) -> OccEnv env +occEnvPop (OccPush o _) = o +occEnvPop OccEnd = OccEnd + occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv +occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv where - unpush :: OccEnv (t : env) -> OccEnv env - unpush (OccPush o _) = o - unpush OccEnd = OccEnd - - unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env - unpushN _ OccEnd = OccEnd - unpushN SZ e = e - unpushN (SS n) (OccPush e _) = unpushN n e + occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env + occEnvPopN _ OccEnd = OccEnd + occEnvPopN SZ e = e + occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e occCountGeneral :: forall r env t x. (forall env'. Monoid (r env')) @@ -112,11 +112,12 @@ occCountGeneral onehot unpush unpushN alter many = go EBuild1 _ a b -> go a <> many (unpush (go b)) EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e)) EFold1 _ a b -> many (unpush (unpush (go a))) <> go b + EUnit _ e -> go e EConst{} -> mempty EIdx0 _ e -> go e EIdx1 _ a b -> go a <> go b EIdx _ e es -> go e <> foldMap go es EOp _ _ e -> go e EWith a b -> go a <> unpush (go b) - EAccum a b e -> go a <> go b <> go e + EAccum1 a b e -> go a <> go b <> go e EError{} -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index ba1b756..1dc9dd3 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -133,6 +133,10 @@ ppExpr' d val = \case showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' . showString ") " . b' + EUnit _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "unit " . e' + EConst _ ty v -> return $ showString $ case ty of STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v @@ -176,12 +180,12 @@ ppExpr' d val = \case showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") . e2' . showString ")" - EAccum e1 e2 e3 -> do + EAccum1 e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ showParen (d > 10) $ - showString "accum " . e1' . showString " " . e2' . showString " " . e3' + showString "accum1 " . e1' . showString " " . e2' . showString " " . e3' EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 0bf5780..444c540 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -18,7 +18,7 @@ {-# OPTIONS_GHC -Wno-partial-type-signatures #-} module AST.Weaken.Auto ( autoWeak, - ($..), auto, + (&.), auto, auto1, Layout(..), ) where @@ -56,9 +56,12 @@ instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel na auto :: KnownListSpine list => SList (Const ()) list auto = knownListSpine -infixr $.. -($..) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) -($..) = ssegmentsAppend +auto1 :: SList (Const ()) '[t] +auto1 = Const () `SCons` SNil + +infixr &. +(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +(&.) = ssegmentsAppend where ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) ssegmentsAppend SSegNil l2 = l2 diff --git a/src/CHAD.hs b/src/CHAD.hs index aedda5b..9786e1e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -245,15 +245,45 @@ type family Vectorise n list where Vectorise _ '[] = '[] Vectorise n (t : ts) = TArr n t : Vectorise n ts +vectoriseEnv :: SNat n -> SList STy env -> SList STy (Vectorise n env) +vectoriseEnv _ SNil = SNil +vectoriseEnv n (t `SCons` env) = STArr n t `SCons` vectoriseEnv n env + vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t) vectoriseIdx IZ = IZ vectoriseIdx (IS i) = IS (vectoriseIdx i) +vectoriseExpr :: forall prefix binds env t f. + SList f prefix + -> SList STy binds + -> SList f env + -> Ex (Append prefix (Append binds env)) t + -> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t +vectoriseExpr prefix binds env = + let wTarget :: Layout ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e + -> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env) + wTarget layout = + autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env) + layout + (#ix :++: #pre :++: #vbinds :++: #env) + in + subst $ \_ t i -> + case splitIdx @(Append binds env) prefix i of + Left iPre -> EVar ext t (wTarget #pre @> iPre) + Right i' -> + case splitIdx @env binds i' of + Left iBinds -> + EIdx0 ext $ + EIdx1 ext (EVar ext (STArr (SS SZ) t) (wTarget #vbinds @> vectoriseIdx iBinds)) + (EVar ext tIx IZ) + Right iEnv -> EVar ext t (wTarget #env @> iEnv) + vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds) vectorise1Binds _ _ BTop = BTop vectorise1Binds env n (bs `BPush` (t, e)) = let bs' = vectorise1Binds env n bs e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n)) + -- TODO: use vectoriseExpr here (subst (\_ t' i -> case splitIdx @env (bindingsBinds bs) i of Left i1 -> let i1' = IS (wRaiseAbove (bindingsBinds bs') env @> vectoriseIdx i1) @@ -285,7 +315,7 @@ type family D2s t where D2s TBool = TNil type family D2Ac t where - D2Ac (TArr n t) = TAccum n t + D2Ac (TArr n t) = TAccum n (D2 t) type family D1E env where D1E '[] = '[] @@ -327,7 +357,7 @@ d2 (STScal t) = case t of d2 STAccum{} = error "Accumulators not allowed in input program" d2ac :: STy t -> STy (D2Ac t) -d2ac (STArr n t) = STAccum n t +d2ac (STArr n t) = STAccum n (d2 t) d2ac _ = error "Only arrays may appear in the accumulator environment" conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) @@ -346,7 +376,7 @@ zero :: STy t -> Ex env (D2 t) zero STNil = ENil ext zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) -zero STArr{} = error "TODO arrays" +zero (STArr n t) = EBuild ext (vecGenerate n (\_ -> EConst ext STI64 0)) (zero t) zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -374,7 +404,10 @@ plus (STEither t1 t2) a b = (ECase ext (EVar ext t (IS IZ)) (EError t "plus r+l") (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus STArr{} _ _ = error "TODO arrays" +plus STArr{} _ _ = error "TODO plus on arrays" + -- 'zero' creates an empty array; this should be a new primitive that + -- (operationally) intelligently memcpy's the non-overlapping part and does + -- a parallel add on the overlapping part. plus (STScal t) a b = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -426,7 +459,7 @@ accumPromote :: forall dt env sto proxy r. -> (forall shbinds. SList STy shbinds -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append envPro (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) + :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) -- ^ A weakening that converts a computation in the -- revised environment to one in the original environment -- extended with some accumulators. @@ -443,11 +476,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = mergesub envpro (\shbinds -> - autoWeak (#pro envpro $.. #d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto @'[D2 dt]) $.. #shb shbinds $.. #acc (auto @'[D2Ac t]) $.. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -455,7 +488,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero -> k (storepl `DPush` (t, SAccum)) (SENo mergesub) - (STAccum arrn arrt `SCons` envpro) + (STArr arrn arrt `SCons` envpro) (\(shbinds :: SList _ shbinds) -> let shbindsC = slistMap (\_ -> Const ()) shbinds in @@ -467,30 +500,49 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum arrn arrt) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) - -- Used "merge" values must either _be_ an array (and hence be caught by - -- the prior case), or contain _no_ arrays at all (TODO: generalise this) - (_, SMerge, Occ _ c) | c > Zero, containsTArr t -> - error $ "Closure variable of 'build'-like thing contains a composite type containing an array: " ++ show t - - -- What's left are normal "merge" values that don't contain arrays; those - -- remain as-is - (_, SMerge, _) -> - k (storepl `DPush` (t, SMerge)) - (SEYes mergesub) - envpro - wf - where - containsTArr :: STy t' -> Bool - containsTArr = \case - STNil -> False - STPair a b -> containsTArr a || containsTArr b - STEither a b -> containsTArr a || containsTArr b - STArr{} -> True - STScal{} -> False - STAccum{} -> error "An accumulator in merge storage?" + -- Used "merge" values must be an array, so reject everything else. (TODO: generalise this) + (_, SMerge, Occ _ c) + | c > Zero -> + error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t + | otherwise -> + k (storepl `DPush` (t, SMerge)) + (SEYes mergesub) + envpro + wf + -- where + -- containsTArr :: STy t' -> Bool + -- containsTArr = \case + -- STNil -> False + -- STPair a b -> containsTArr a || containsTArr b + -- STEither a b -> containsTArr a || containsTArr b + -- STArr{} -> True + -- STScal{} -> False + -- STAccum{} -> error "An accumulator in merge storage?" + +type family InvTup core env where + InvTup core '[] = core + InvTup core (t : ts) = InvTup (TPair core t) ts + +makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators SNil e = e +makeAccumulators (STArr n t `SCons` envpro) e = + makeAccumulators envpro $ + EWith (zero (STArr n t)) e +makeAccumulators (t `SCons` _) _ = error $ "makeAccumulators: Not only arrays in envpro: " ++ show t + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) -- | @env'@ is a subset of @env@: each element of @env@ is either included in -- @env'@ ('SEYes') or not included in @env'@ ('SENo'). @@ -526,6 +578,10 @@ subList SNil SETop = SNil subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) subList (SCons _ xs) (SENo sub) = subList xs sub +subenvAll :: SList f env -> Subenv env env +subenvAll SNil = SETop +subenvAll (SCons _ env) = SEYes (subenvAll env) + subenvNone :: SList f env -> Subenv env '[] subenvNone SNil = SETop subenvNone (SCons _ env) = SENo (subenvNone env) @@ -582,6 +638,11 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) +assertSubenvEmpty :: Subenv env env' -> env' :~: '[] +assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl +assertSubenvEmpty SETop = Refl +assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" + popFromScope :: Descr env0 sto -> STy a @@ -617,7 +678,7 @@ rebaseRetPair :: forall env b1 b2 env0 sto t f. rebaseRetPair descr b1 b2 (RetPair p sub d) | Refl <- lemAppendAssoc @b2 @b1 @env = RetPair p sub (weakenExpr (autoWeak - (#d (auto @'[D2 t]) $.. #b2 b2 $.. #b1 b1 $.. #tl (d2ace (select SAccum descr))) + (#d (auto1 @(D2 t)) &. #b2 b2 &. #b1 b1 &. #tl (d2ace (select SAccum descr))) (#d :++: (#b2 :++: #tl)) (#d :++: ((#b2 :++: #b1) :++: #tl))) d) @@ -759,10 +820,10 @@ drev des = \case (weakenExpr wbody0' body1) subBoth (ELet ext - (weakenExpr (autoWeak (#d (auto @'[D2 t]) - $.. #body (bindingsBinds body0) - $.. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0)) - $.. #tl (d2ace (select SAccum des))) + (weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #body (bindingsBinds body0) + &. #rhs (SCons (typeOf rhs1) (bindingsBinds rhs0)) + &. #tl (d2ace (select SAccum des))) (#d :++: #body :++: #tl) (#d :++: #body :++: #rhs :++: #tl)) body2') $ @@ -867,12 +928,12 @@ drev des = \case ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ ELet ext - (weakenExpr (autoWeak (#d (auto @'[D2 t]) - $.. #a0 (bindingsBinds a0) - $.. #prea0 prerebinds - $.. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) - $.. #binds (tPrimal `SCons` bindingsBinds e0) - $.. #tl (d2ace (select SAccum des))) + (weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #a0 (bindingsBinds a0) + &. #prea0 prerebinds + &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) + &. #binds (tPrimal `SCons` bindingsBinds e0) + &. #tl (d2ace (select SAccum des))) (#d :++: #a0 :++: #tl) (#d :++: (#a0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) a2') $ @@ -886,12 +947,12 @@ drev des = \case ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ ELet ext - (weakenExpr (autoWeak (#d (auto @'[D2 t]) - $.. #b0 (bindingsBinds b0) - $.. #preb0 prerebinds - $.. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) - $.. #binds (tPrimal `SCons` bindingsBinds e0) - $.. #tl (d2ace (select SAccum des))) + (weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #b0 (bindingsBinds b0) + &. #preb0 prerebinds + &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) + &. #binds (tPrimal `SCons` bindingsBinds e0) + &. #tl (d2ace (select SAccum des))) (#d :++: #b0 :++: #tl) (#d :++: (#b0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) b2') $ @@ -937,38 +998,81 @@ drev des = \case (ENil ext) EBuild1 _ ne e - -- TODO: use occCountAll to determine which variables from @env are used in - -- 'e', and promote those to SAccum storage in 'des' - | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne - , Ret e0 e1 sub e2 <- drev (des `DPush` (tIx, SMerge)) e - , let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 -> + | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne -> + accumPromote (typeOf e) des (occEnvPop (occCountAll e)) $ \vdes proSub envPro wPro -> + case drev (vdes `DPush` (tIx, SMerge)) e of { Ret e0 e1 sub e2 -> + case assertSubenvEmpty sub of { Refl -> + case assertSubenvEmpty proSub of { Refl -> + let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 in Ret (bconcat (ne0 `BPush` (tIx, ne1)) (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) (EBuild1 ext (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) - $.. #i (auto @'[TIx]) - $.. #ne0 (bindingsBinds ne0) - $.. #tl (sD1eEnv des)) - (#ne0 :++: #tl) - ((#ve0 :++: #i :++: #ne0) :++: #tl)) - ne1) + &. #binds (tIx `SCons` bindingsBinds ne0) + &. #tl (sD1eEnv des)) + #binds + ((#ve0 :++: #binds) :++: #tl)) + (EVar ext tIx IZ)) (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of Left ibind -> - let ibind' = WSink - .> wRaiseAbove (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) - (sD1eEnv des) - .> wRaiseAbove (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0) - @> vectoriseIdx ibind + let ibind' = + autoWeak (#ix (auto1 @TIx) + &. #ve0 (bindingsBinds ve0) + &. #binds (tIx `SCons` bindingsBinds ne0) + &. #tl (sD1eEnv des)) + #ve0 + (#ix :++: (#ve0 :++: #binds) :++: #tl) + @> vectoriseIdx ibind in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind') (EVar ext tIx IZ)) Right IZ -> EVar ext tIx IZ -- build lambda index argument Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv))) e1)) - (subenvNone (select SMerge des)) - _ + nsub + (ELet ext + (makeAccumulators envPro $ + EBuild1 ext + (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) + &. #pro (d2ace envPro) + &. #d (auto1 @(D2 t)) + &. #binds (tIx `SCons` bindingsBinds ne0) + &. #tl (d2ace (select SAccum des))) + #binds + (#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl)) + (EVar ext tIx IZ)) + -- TODO: use vectoriseExpr + (_ (weakenExpr (wPro (bindingsBinds e0)) e2))) $ + ELet ext (ENil ext) $ + weakenExpr (autoWeak (#nil (auto1 @TNil) + &. #d (auto1 @(D2 t)) + &. #nilarr (auto1 @(TArr (S Z) TNil)) + &. #ve0 (bindingsBinds ve0) + &. #n (auto1 @TIx) + &. #binds (bindingsBinds ne0) + &. #tl (d2ace (select SAccum des))) + (#nil :++: #binds :++: #tl) + (#nil :++: #nilarr :++: #d :++: (#ve0 :++: #n :++: #binds) :++: #tl)) + ne2) + }}} + + EUnit _ e + | Ret e0 e1 sub e2 <- drev des e -> + Ret e0 + (EUnit ext e1) + sub + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EIdx0 _ e + | Ret e0 e1 sub e2 <- drev des e + , STArr _ t <- typeOf e -> + Ret e0 + (EIdx0 ext e1) + sub + (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $ + weakenExpr (WCopy WSink) e2) -- These should be the next to be implemented, I think - EIdx0{} -> err_unsupported "EIdx0" EIdx1{} -> err_unsupported "EIdx1" EFold1{} -> err_unsupported "EFold1" @@ -976,7 +1080,7 @@ drev des = \case EBuild{} -> err_unsupported "EBuild" EWith{} -> err_accum - EAccum{} -> err_accum + EAccum1{} -> err_accum where err_accum = error "Accumulator operations unsupported in the source program" diff --git a/src/Data.hs b/src/Data.hs index a3f4c3c..8c39c6c 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -52,3 +52,10 @@ deriving instance Traversable (Vec n) vecLength :: Vec n t -> SNat n vecLength VNil = SZ vecLength (_ :< v) = SS (vecLength v) + +vecGenerate :: SNat n -> (forall i. SNat i -> t) -> Vec n t +vecGenerate = \n f -> go n f SZ + where + go :: SNat n -> (forall i. SNat i -> t) -> SNat i' -> Vec n t + go SZ _ _ = VNil + go (SS n) f i = f i :< go n f (SS i) diff --git a/src/Simplify.hs b/src/Simplify.hs index af0ca4c..f2fc54a 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -73,13 +73,14 @@ simplify' = \case EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b) EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e) EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b) + EUnit _ e -> EUnit ext (simplify' e) EConst _ t v -> EConst ext t v EIdx0 _ e -> EIdx0 ext (simplify' e) EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es) EOp _ op e -> EOp ext op (simplify' e) EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2) - EAccum e1 e2 e3 -> EAccum (simplify' e1) (simplify' e2) (simplify' e3) + EAccum1 e1 e2 e3 -> EAccum1 (simplify' e1) (simplify' e2) (simplify' e3) EError t s -> EError t s cheapExpr :: Expr x env t -> Bool @@ -105,13 +106,14 @@ hasAdds = \case EBuild1 _ a b -> hasAdds a || hasAdds b EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e EFold1 _ a b -> hasAdds a || hasAdds b + EUnit _ e -> hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es) EOp _ _ e -> hasAdds e EWith a b -> hasAdds a || hasAdds b - EAccum _ _ _ -> True + EAccum1 _ _ _ -> True EError _ _ -> False checkAccumInScope :: SList STy env -> Bool |