From 57779d4303f377004705c8da06a5ac46177950b2 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 4 Nov 2025 23:09:21 +0100 Subject: drevLambda works, TODO D[map] --- src/AST.hs | 8 ++-- src/AST/Count.hs | 26 +++++------ src/AST/Pretty.hs | 14 +++--- src/AST/SplitLets.hs | 6 ++- src/Analysis/Identity.hs | 10 ++--- src/CHAD.hs | 112 ++++++++++++++++++----------------------------- src/Compile.hs | 13 ++++-- src/Interpreter.hs | 8 ++-- src/Language.hs | 32 +++++++++++++- src/Language/AST.hs | 41 ++++++++++++----- 10 files changed, 148 insertions(+), 122 deletions(-) (limited to 'src') diff --git a/src/AST.hs b/src/AST.hs index 873a8a5..ca6cdd1 100644 --- a/src/AST.hs +++ b/src/AST.hs @@ -65,7 +65,7 @@ data Expr x env t where EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) @@ -79,7 +79,7 @@ data Expr x env t where -- values in some implementation-defined order. -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative - -> Expr x (t1 : t1 : env) (TPair t1 b) + -> Expr x (TPair t1 t1 : env) (TPair t1 b) -> Expr x env t1 -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output @@ -403,7 +403,7 @@ subst' f w = \case EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) @@ -411,7 +411,7 @@ subst' f w = \case EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) EZip x a b -> EZip x (subst' f w a) (subst' f w b) - EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index bc02417..a53822d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -560,22 +560,23 @@ occCountX initialS topexpr k = case topexpr of EMap ext (mka (OccPush env' () s1)) (mkb env') EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1''' mka -> - withSome (scaleMany (Some env1''')) $ \env1'' -> - occEnvPop' env1'' $ \env1' s2 -> - occEnvPop' env1' $ \env1 s1 -> - let s0 = case s of + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of SsNone -> Some SsNone SsArr' s' -> Some s' in - withSome (Some s1 <> Some s2 <> s0) $ \sElt -> + withSome (s1 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsArr sElt) s $ EFold1Inner ext commut (projectSmallerSubstruc SsFull sElt $ - mka (OccPush (OccPush env' () sElt) () sElt)) + mka (OccPush env' () (SsPair sElt sElt))) (mkb env') (mkc env') ESum1Inner _ e -> handleReduction (ESum1Inner ext) e @@ -665,7 +666,7 @@ occCountX initialS topexpr k = case topexpr of elet (mapExt (\_ -> ext) e3) $ EPair ext (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) (mapExt (\_ -> ext) (weakenExpr WSink e2)) (evar IZ)) in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> @@ -675,15 +676,14 @@ occCountX initialS topexpr k = case topexpr of -- If at least some of the additional stores are required, we need to keep this a mapAccum SsPair' _ (SsArr' sB) -> -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> - occEnvPop' env1_2' $ \env1_1' _ -> + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> occEnvPop' env1_1' $ \env1' _ -> occCountX SsFull e2 $ \env2 mkb -> occCountX SsFull e3 $ \env3 mkc -> withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) (mkb env') (mkc env') EFold1InnerD2 _ cm ef ebog ed -> diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 2c51b85..ecdaa88 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -213,14 +213,13 @@ ppExpr' d val expr = case expr of ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -254,14 +253,13 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] EFold1InnerD1 _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf b) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1iD1" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] EFold1InnerD2 _ cm ef ebog ed -> do let STArr _ tB = typeOf ebog diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index d276e44..267dd87 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -34,10 +34,10 @@ splitLets' = \sub -> \case in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c - in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD2 x cm a b c -> let STArr _ tB = typeOf b STArr _ t2 = typeOf c @@ -56,12 +56,14 @@ splitLets' = \sub -> \case ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 71da793..7b896a3 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -213,9 +213,8 @@ idana env expr = case expr of EFold1Inner _ cm e1 e2 e3 -> do let t1 = typeOf e1 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ (_ :< sh) = v3 @@ -268,9 +267,8 @@ idana env expr = case expr of EFold1InnerD1 _ cm e1 e2 e3 -> do let t1 = typeOf e2 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ sh'@(_ :< sh) = v3 diff --git a/src/CHAD.hs b/src/CHAD.hs index 72ce36d..9da5395 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1077,37 +1077,29 @@ drev des accumMap sd = \case ESnd ext $ wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty IZ)) $ -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ)) + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) (EVar ext shty (IS IZ))) $ - weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv) - (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) e2) - EMap{} -> undefined + EMap{} -> error "TODO: CHAD EMap" EFold1Inner _ commut origef ex₀ earr | SpArr @_ @sdElt sdElt <- sd , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> - deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') -> - let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in - case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> - let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in - let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) subEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy primalTy = STPair (STArr ndim (d1 eltty)) bogTy - zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) - library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) &. #parr (auto1 @(TArr (S n) (D1 elt))) &. #px₀ (auto1 @(D1 elt)) &. #px (auto1 @(D1 elt)) @@ -1118,70 +1110,52 @@ drev des accumMap sd = \case &. #x₀abinds (bindingsBinds bindsx₀a) &. #fbinds (bindingsBinds ef0) &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) - &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) - &. #ftape (auto1 @(Tape e_tape)) - &. #primalzip (zipPrimalTy `SCons` SNil) - &. #efPrerebinds efPrerebinds - &. #propr (d1e envPro) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes) - &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des)) - &. #d2acPro (d2ace envPro) + &. #d2acPro (d2ace provars) &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> - subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> - Ret (bconcat bindsx₀a mergePrimalBindings' + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' `bpush` weakenExpr wOverPrimalBindings ex₀1 `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 `bpush` EFold1InnerD1 ext commut (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in - letBinds (fst (weakenBindingsE (autoWeak library - (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - layout) - ef0)) $ - elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: layout)) - ef1) $ - EPair ext - (evar IZ) + letBinds (fst (weakenBindingsE (autoWeak library (#xy :++: #d1env) layout) ef0)) $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1) (EPair ext - (evar IZ) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout))))) + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: layout) @> IZ)) + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1tape))) (EVar ext (d1 eltty) (IS (IS IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) - (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) (EFst ext (EVar ext primalTy IZ)) subx₀af (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in elet - (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $ - makeAccumulators (autoWeak library #propr layout1) envPro $ - let layout2 = #d2acPro :++: layout1 in - EFold1InnerD2 ext commut - (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $ - elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $ - elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $ - letBinds (efRebinds (IS (IS IZ))) $ - let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in - elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ - weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 - .> wPro (subList (bindingsBinds ef0) subtapeEf)) - ef2) $ - EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (ezip - (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) - (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) - (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) - (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) - (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) subEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ plus_x₀a_f (plus_x₀_a (elet (EIdx0 ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) (eflatten (EFst ext (EFst ext (evar IZ)))))) $ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) @@ -1189,7 +1163,6 @@ drev des accumMap sd = \case (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) (ESnd ext (evar IZ))) - } EUnit _ e | SpArr sdElt <- sd @@ -1213,9 +1186,8 @@ drev des accumMap sd = \case (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) sub (ELet ext (EFold1Inner ext Commut - (sparsePlus (d2M eltty) sdElt' - (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) - (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (inj2 (ENil ext)) (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ weakenExpr (WCopy WSink) e2) @@ -1494,7 +1466,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) D1E provars :> env' -> Ex (Append (D2AcE provars) env') b -> Ex ( env') (TPair b (Tup (D2E provars)))) - -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' -> r) -> r drevLambda des accumMap (argty, argsto) sd origef k = @@ -1535,10 +1507,10 @@ drevLambda des accumMap (argty, argsto) sd origef k = uninvertTup (d2e envPro) (typeOf body) $ makeAccumulators wpro1 envPro $ body) - (letBinds (efRebinds (IS IZ)) $ + (letBinds (efRebinds IZ) $ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) .> wPro (subList (bindingsBinds ef0) subtapeEf)) (getSparseArg ef2)) }} diff --git a/src/Compile.hs b/src/Compile.hs index d6ad7ec..8627905 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -840,11 +840,14 @@ compile' env = \case -- kvar <- if vecwid > 1 then genName' "k" else return "" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here @@ -854,6 +857,7 @@ compile' env = \case -- what comes out of the function anyway, so that's -- fine, but we do need to increment the array element. arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) <> funStmts <> pure (SAsg accvar funres)) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) @@ -997,12 +1001,14 @@ compile' env = \case jvar <- genName' "j" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun funresvar <- genName' "res" ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here @@ -1012,8 +1018,9 @@ compile' env = \case -- what comes out of the function anyway, so that's -- fine, but we do need to increment the array element. arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) <> funStmts - <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d982261..e1c81cd 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -121,11 +121,11 @@ interpret'Rec env = \case arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b EFold1Inner _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] ESum1Inner _ e -> do arr <- interpret' env e let STArr _ (STScal t) = typeOf e @@ -162,14 +162,14 @@ interpret'Rec env = \case return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) EFold1InnerD1 _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr -- TODO: this is very inefficient, even for an interpreter; with mutable -- arrays this can be a lot better with no lists res <- arrayGenerateM sh $ \idx -> do - (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] return (y, arrayFromList (ShNil `ShCons` n) stores) return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> diff --git a/src/Language.hs b/src/Language.hs index 31b4b87..c1a6248 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} @@ -15,6 +16,8 @@ module Language ( Lookup, ) where +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + import Array import AST import AST.Sparse.Types @@ -113,7 +116,19 @@ map_ (v :-> a) b NEDrop (SS SZ) (NEDrop (SS SZ) a) fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) sum1i e = NESum1Inner e @@ -135,7 +150,20 @@ reshape = NEReshape fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) -fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3 +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index c9d05c9..a3b8130 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -4,7 +4,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -17,7 +19,7 @@ module Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) import Array import AST @@ -50,7 +52,7 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) @@ -58,7 +60,7 @@ data NExpr env t where NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) - NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b) + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) @@ -96,11 +98,16 @@ data NExpr env t where NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) -type family Lookup name env where - Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup name ('(name, t) : env) = t - Lookup name (_ : env) = Lookup name env +type Lookup name env = Lookup1 (name == "_") name env +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env type family DropNth i env where DropNth Z (_ : env) = env @@ -209,7 +216,7 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) @@ -217,7 +224,7 @@ fromNamedExpr val = \case NEMinimum1Inner e -> EMinimum1Inner ext (go e) NEReshape n a b -> EReshape ext n (go a) (go b) - NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x @@ -275,3 +282,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env dropNthW SZ (_ `NPush` _) = WSink dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k -- cgit v1.2.3-70-g09d2