diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 8 | ||||
| -rw-r--r-- | src/AST/Count.hs | 48 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 16 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 6 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 10 | ||||
| -rw-r--r-- | src/CHAD.hs | 183 | ||||
| -rw-r--r-- | src/Compile.hs | 13 | ||||
| -rw-r--r-- | src/Example.hs | 5 | ||||
| -rw-r--r-- | src/Interpreter.hs | 8 | ||||
| -rw-r--r-- | src/Language.hs | 39 | ||||
| -rw-r--r-- | src/Language/AST.hs | 43 | ||||
| -rw-r--r-- | src/Simplify.hs | 15 |
12 files changed, 242 insertions, 152 deletions
@@ -65,7 +65,7 @@ data Expr x env t where EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) @@ -79,7 +79,7 @@ data Expr x env t where -- values in some implementation-defined order. -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative - -> Expr x (t1 : t1 : env) (TPair t1 b) + -> Expr x (TPair t1 t1 : env) (TPair t1 b) -> Expr x env t1 -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output @@ -403,7 +403,7 @@ subst' f w = \case EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) @@ -411,7 +411,7 @@ subst' f w = \case EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) EZip x a b -> EZip x (subst' f w a) (subst' f w b) - EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index bc02417..ac8634e 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -321,13 +321,7 @@ projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex - (SsArr s1, SsArr s2) - | STArr n t <- typeOf ex -> - elet ex $ - EBuild ext n (EShape ext (evar IZ)) $ - projectSmallerSubstruc s1 s2 - (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) + (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex @@ -560,22 +554,23 @@ occCountX initialS topexpr k = case topexpr of EMap ext (mka (OccPush env' () s1)) (mkb env') EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1''' mka -> - withSome (scaleMany (Some env1''')) $ \env1'' -> - occEnvPop' env1'' $ \env1' s2 -> - occEnvPop' env1' $ \env1 s1 -> - let s0 = case s of + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of SsNone -> Some SsNone SsArr' s' -> Some s' in - withSome (Some s1 <> Some s2 <> s0) $ \sElt -> + withSome (s1 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsArr sElt) s $ EFold1Inner ext commut (projectSmallerSubstruc SsFull sElt $ - mka (OccPush (OccPush env' () sElt) () sElt)) + mka (OccPush env' () (SsPair sElt sElt))) (mkb env') (mkc env') ESum1Inner _ e -> handleReduction (ESum1Inner ext) e @@ -638,6 +633,20 @@ occCountX initialS topexpr k = case topexpr of withSome (Some env1 <> Some env2) $ \env -> k env $ \env' -> use (mkb env') $ mka env' + SsArr' (SsPair' SsNone s2) -> + occCountX SsNone a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ + emap (EPair ext (ENil ext) (evar IZ)) (mkb env') + SsArr' (SsPair' s1 SsNone) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ + emap (EPair ext (evar IZ) (ENil ext)) (mka env') SsArr' (SsPair' s1 s2) -> occCountX (SsArr s1) a $ \env1 mka -> occCountX (SsArr s2) b $ \env2 mkb -> @@ -665,7 +674,7 @@ occCountX initialS topexpr k = case topexpr of elet (mapExt (\_ -> ext) e3) $ EPair ext (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) (mapExt (\_ -> ext) (weakenExpr WSink e2)) (evar IZ)) in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> @@ -675,15 +684,14 @@ occCountX initialS topexpr k = case topexpr of -- If at least some of the additional stores are required, we need to keep this a mapAccum SsPair' _ (SsArr' sB) -> -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> - occEnvPop' env1_2' $ \env1_1' _ -> + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> occEnvPop' env1_1' $ \env1' _ -> occCountX SsFull e2 $ \env2 mkb -> occCountX SsFull e3 $ \env3 mkc -> withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) (mkb env') (mkc env') EFold1InnerD2 _ cm ef ebog ed -> diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 2c51b85..bbcfd9e 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -206,21 +206,20 @@ ppExpr' d val expr = case expr of EMap _ a b -> do let STArr _ t1 = typeOf b - name <- genNameIfUsedIn' "i" t1 IZ a + name <- genNameIfUsedIn t1 IZ a a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b return $ ppParen (d > 0) $ ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -254,14 +253,13 @@ ppExpr' d val expr = case expr of return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] EFold1InnerD1 _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf b) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1iD1" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] EFold1InnerD2 _ cm ef ebog ed -> do let STArr _ tB = typeOf ebog diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index d276e44..267dd87 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -34,10 +34,10 @@ splitLets' = \sub -> \case in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c - in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD2 x cm a b c -> let STArr _ tB = typeOf b STArr _ t2 = typeOf c @@ -56,12 +56,14 @@ splitLets' = \sub -> \case ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 71da793..7b896a3 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -213,9 +213,8 @@ idana env expr = case expr of EFold1Inner _ cm e1 e2 e3 -> do let t1 = typeOf e1 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ (_ :< sh) = v3 @@ -268,9 +267,8 @@ idana env expr = case expr of EFold1InnerD1 _ cm e1 e2 e3 -> do let t1 = typeOf e2 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ sh'@(_ :< sh) = v3 diff --git a/src/CHAD.hs b/src/CHAD.hs index 72ce36d..298d964 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1043,12 +1043,12 @@ drev des accumMap sd = \case (subenvNone (d2e (select SMerge des))) (ENil ext) - EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) + EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) | SpArr @_ @sdElt sdElt <- sd - , let eltty = typeOf orige + , let eltty = typeOf ef , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> - drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> let library = #ix (shty `SCons` SNil) &. #e0 (bindingsBinds e0) &. #propr (d1e provars) @@ -1060,15 +1060,11 @@ drev des accumMap sd = \case &. #darr (auto1 @(TArr ndim sdElt)) &. #tapearr (auto1 @(TArr ndim e_tape)) in Ret (proPrimalBinds - `bpush` EBuild ext ndim - (weakenExpr (wSinks (d1e provars)) (drevPrimal des she)) - (letBinds (fst (weakenBindingsE (autoWeak library - (#ix :++: #d1env) - (#ix :++: #propr :++: #d1env)) - e0)) $ - weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env) - (#e0 :++: #ix :++: #propr :++: #d1env)) - (EPair ext e1 e1tape)) + `bpush` weakenExpr (wSinks (d1e provars)) + (EBuild ext ndim + (drevPrimal des she) + (letBinds e0 $ + EPair ext e1 e1tape)) `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) (SEYesR (SENo (subenvAll (d1e provars)))) (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) @@ -1077,37 +1073,77 @@ drev des accumMap sd = \case ESnd ext $ wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty IZ)) $ -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ)) + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) (EVar ext shty (IS IZ))) $ - weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv) - (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) e2) - EMap{} -> undefined + EMap _ ef (earr :: Expr _ _ (TArr n a)) + | SpArr sdElt <- sd + , let STArr ndim t1 = typeOf earr + t2 = typeOf ef -> + drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 -> + case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds + ttape = typeOf ef1tape + library = #d1env (desD1E des) + &. #a0 (bindingsBinds ea0) + &. #atapebinds (subList (bindingsBinds ea0) easubtape) + &. #propr (d1e provars) + &. #x (d1 t1 `SCons` SNil) + &. #parr (STArr ndim (d1 t1) `SCons` SNil) + &. #tapearr (STArr ndim ttape `SCons` SNil) + &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil) + &. #dy (applySparse sdElt (d2 t2) `SCons` SNil) + &. #tape (ttape `SCons` SNil) + &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil) + &. #d2acEnv (d2ace (select SAccum des)) + &. #pro (d2ace provars) + in + subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a -> + Ret (bconcat ea0 proPrimalBinds' + `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1 + `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env)) + (letBinds ef0 $ + EPair ext ef1 ef1tape)) + (EVar ext (STArr ndim (d1 t1)) IZ) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ)) + (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars)))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ))) + subfa + (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout) $ + emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $ + elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv) + (#tape :++: #dy :++: #dytape :++: #pro :++: layout)) + ef2) + (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ)) + (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $ + plus_f_a + (ESnd ext (evar IZ)) + (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout)) + (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ)) + ea2))) + } EFold1Inner _ commut origef ex₀ earr | SpArr @_ @sdElt sdElt <- sd , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> - deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') -> - let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in - case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> - let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in - let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy primalTy = STPair (STArr ndim (d1 eltty)) bogTy - zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) - library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) &. #parr (auto1 @(TArr (S n) (D1 elt))) &. #px₀ (auto1 @(D1 elt)) &. #px (auto1 @(D1 elt)) @@ -1118,70 +1154,53 @@ drev des accumMap sd = \case &. #x₀abinds (bindingsBinds bindsx₀a) &. #fbinds (bindingsBinds ef0) &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) - &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) - &. #ftape (auto1 @(Tape e_tape)) - &. #primalzip (zipPrimalTy `SCons` SNil) - &. #efPrerebinds efPrerebinds - &. #propr (d1e envPro) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes) - &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des)) - &. #d2acPro (d2ace envPro) + &. #d2acPro (d2ace provars) &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> - subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> - Ret (bconcat bindsx₀a mergePrimalBindings' + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' `bpush` weakenExpr wOverPrimalBindings ex₀1 `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 `bpush` EFold1InnerD1 ext commut (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in - letBinds (fst (weakenBindingsE (autoWeak library - (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - layout) - ef0)) $ - elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: layout)) - ef1) $ - EPair ext - (evar IZ) - (EPair ext - (evar IZ) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout))))) + weakenExpr (autoWeak library (#xy :++: #d1env) layout) + (letBinds ef0 $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + ef1 + (EPair ext + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ)) + ef1tape))) (EVar ext (d1 eltty) (IS (IS IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) - (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) (EFst ext (EVar ext primalTy IZ)) subx₀af (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in elet - (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $ - makeAccumulators (autoWeak library #propr layout1) envPro $ - let layout2 = #d2acPro :++: layout1 in - EFold1InnerD2 ext commut - (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $ - elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $ - elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $ - letBinds (efRebinds (IS (IS IZ))) $ - let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in - elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ - weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 - .> wPro (subList (bindingsBinds ef0) subtapeEf)) - ef2) $ - EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (ezip - (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) - (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) - (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) - (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) - (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ plus_x₀a_f (plus_x₀_a (elet (EIdx0 ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) (eflatten (EFst ext (EFst ext (evar IZ)))))) $ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) @@ -1189,7 +1208,6 @@ drev des accumMap sd = \case (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) (ESnd ext (evar IZ))) - } EUnit _ e | SpArr sdElt <- sd @@ -1213,9 +1231,8 @@ drev des accumMap sd = \case (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) sub (ELet ext (EFold1Inner ext Commut - (sparsePlus (d2M eltty) sdElt' - (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) - (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (inj2 (ENil ext)) (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ weakenExpr (WCopy WSink) e2) @@ -1494,7 +1511,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) D1E provars :> env' -> Ex (Append (D2AcE provars) env') b -> Ex ( env') (TPair b (Tup (D2E provars)))) - -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' -> r) -> r drevLambda des accumMap (argty, argsto) sd origef k = @@ -1535,10 +1552,10 @@ drevLambda des accumMap (argty, argsto) sd origef k = uninvertTup (d2e envPro) (typeOf body) $ makeAccumulators wpro1 envPro $ body) - (letBinds (efRebinds (IS IZ)) $ + (letBinds (efRebinds IZ) $ weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) .> wPro (subList (bindingsBinds ef0) subtapeEf)) (getSparseArg ef2)) }} diff --git a/src/Compile.hs b/src/Compile.hs index d6ad7ec..8627905 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -840,11 +840,14 @@ compile' env = \case -- kvar <- if vecwid > 1 then genName' "k" else return "" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here @@ -854,6 +857,7 @@ compile' env = \case -- what comes out of the function anyway, so that's -- fine, but we do need to increment the array element. arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) <> funStmts <> pure (SAsg accvar funres)) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) @@ -997,12 +1001,14 @@ compile' env = \case jvar <- genName' "j" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun funresvar <- genName' "res" ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here @@ -1012,8 +1018,9 @@ compile' env = \case -- what comes out of the function anyway, so that's -- fine, but we do need to increment the array element. arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) <> funStmts - <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) diff --git a/src/Example.hs b/src/Example.hs index 2c51291..e996002 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -34,9 +34,8 @@ pipeline config term | Dict <- styKnown (d2 (typeOf term)) = simplifyFix $ pruneExpr knownEnv $ simplifyFix $ unMonoid $ - chad' config knownEnv $ - simplifyFix $ - term + simplifyFix $ chad' config knownEnv $ + simplifyFix $ term -- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () diff --git a/src/Interpreter.hs b/src/Interpreter.hs index d982261..e1c81cd 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -121,11 +121,11 @@ interpret'Rec env = \case arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b EFold1Inner _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] ESum1Inner _ e -> do arr <- interpret' env e let STArr _ (STScal t) = typeOf e @@ -162,14 +162,14 @@ interpret'Rec env = \case return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) EFold1InnerD1 _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr -- TODO: this is very inefficient, even for an interpreter; with mutable -- arrays this can be a lot better with no lists res <- arrayGenerateM sh $ \idx -> do - (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] return (y, arrayFromList (ShNil `ShCons` n) stores) return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> diff --git a/src/Language.hs b/src/Language.hs index 31b4b87..4886ddc 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} @@ -15,6 +16,8 @@ module Language ( Lookup, ) where +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + import Array import AST import AST.Sparse.Types @@ -105,15 +108,22 @@ build n a (v :-> b) = NEBuild n a v b map_ :: forall n a b env name. (KnownNat n, KnownTy a) => (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TArr n a) -> NExpr env (TArr n b) -map_ (v :-> a) b - | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) = - let_ #arg b $ - build knownNat (shape #arg) $ #i :-> - let_ v (#arg ! #i) $ - NEDrop (SS SZ) (NEDrop (SS SZ) a) +map_ (v :-> a) b = NEMap v a b fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) sum1i e = NESum1Inner e @@ -135,7 +145,20 @@ reshape = NEReshape fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) -fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3 +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index c9d05c9..3d6ede5 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -4,7 +4,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -17,7 +19,7 @@ module Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) import Array import AST @@ -50,7 +52,8 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) @@ -58,7 +61,7 @@ data NExpr env t where NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) - NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b) + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) @@ -96,11 +99,16 @@ data NExpr env t where NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) -type family Lookup name env where - Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup name ('(name, t) : env) = t - Lookup name (_ : env) = Lookup name env +type Lookup name env = Lookup1 (name == "_") name env +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env type family DropNth i env where DropNth Z (_ : env) = env @@ -209,7 +217,8 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEMap n a b -> EMap ext (lambda val n a) (go b) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) @@ -217,7 +226,7 @@ fromNamedExpr val = \case NEMinimum1Inner e -> EMinimum1Inner ext (go e) NEReshape n a b -> EReshape ext n (go a) (go b) - NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x @@ -275,3 +284,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env dropNthW SZ (_ `NPush` _) = WSink dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k diff --git a/src/Simplify.hs b/src/Simplify.hs index 1889adc..19d0c17 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -185,6 +185,21 @@ simplify'Rec = \case ELet _ e1 (ENil _) | STNil <- typeOf e1 -> acted $ simplify' e1 + -- map (\_ -> x) e ~> build (shape e) (\_ -> x) + EMap _ e1 e2 + | Occ Zero Zero <- occCount IZ e1 + , STArr n _ <- typeOf e2 -> + acted $ simplify' $ + EBuild ext n (EShape ext e2) $ + subst (\_ t' -> \case IZ -> error "Unused variable was used" + IS i -> EVar ext t' (IS i)) + e1 + + -- vertical fusion + EMap _ e1 (EMap _ e2 e3) -> + acted $ simplify' $ + EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3 + -- projection down-commuting EFst _ (ECase _ e1 e2 e3) -> acted $ simplify' $ |
