diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 22:17:56 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 22:17:56 +0100 |
commit | cbe6472a14cc0887295034bb29546dd1a1f083fd (patch) | |
tree | 2fa6a20f584d58ac3b89074673990a16cdc7d5b2 | |
parent | 4fcdb7118e0084f192753ea6c70394352a27d5ed (diff) |
WIP maximum/minimum
-rw-r--r-- | src/AST.hs | 6 | ||||
-rw-r--r-- | src/AST/Count.hs | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 8 | ||||
-rw-r--r-- | src/CHAD.hs | 24 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 7 | ||||
-rw-r--r-- | src/Simplify.hs | 4 |
6 files changed, 51 insertions, 0 deletions
@@ -87,6 +87,8 @@ data Expr x env t where ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) + EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -206,6 +208,8 @@ typeOf = \case ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t + EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -282,6 +286,8 @@ subst' f w = \case ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) + EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) + EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 364773a..22a4da6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -119,6 +119,8 @@ occCountGeneral onehot unpush alter many = go WId ESum1Inner _ e -> re e EUnit _ e -> re e EReplicate1Inner _ a b -> re a <> re b + EMaximum1Inner _ e -> re e + EMinimum1Inner _ e -> re e EConst{} -> mempty EIdx0 _ e -> re e EIdx1 _ a b -> re a <> re b diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index a2232ee..4d9aeec 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -141,6 +141,14 @@ ppExpr' d val = \case b' <- ppExpr' 11 val b return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b' + EMaximum1Inner _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "maximum1i " . e' + + EMinimum1Inner _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "minimum1i " . e' + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ showsPrec d v diff --git a/src/CHAD.hs b/src/CHAD.hs index 2f05807..a08fe80 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1069,6 +1069,9 @@ drev des = \case (EVar ext (STArr n (d2 t)) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) e2) + EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e + EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + -- These should be the next to be implemented, I think EFold1Inner{} -> err_unsupported "EFold1Inner" @@ -1086,3 +1089,24 @@ drev des = \case err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s + + deriv_extremum :: ScalIsNumeric t' ~ True + => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) + -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) + deriv_extremum extremum e + | Ret e0 subtape e1 sub e2 <- drev des e + , at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + Ret (e0 `BPush` (at, e1) + `BPush` (at', extremum (EVar ext at IZ))) + (SEYes (SEYes subtape)) + (EVar ext at' IZ) + sub + (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ)))))) + (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ)))) + (EZero t)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index f2ded6e..056fcb3 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -131,6 +131,13 @@ dfwdDN = \case (ESum1Inner ext (dfwdDN e)) EUnit _ e -> EUnit ext (dfwdDN e) EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) + EMaximum1Inner _ e -> + let STArr n (STScal t) = typeOf e + in scalTyCase t + -- TODO: do roughly the same as what CHAD does, but forward + (_ (dfwdDN e)) + _ + EMinimum1Inner _ e -> EMinimum1Inner ext (dfwdDN e) EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) diff --git a/src/Simplify.hs b/src/Simplify.hs index e32ba8c..3e14aaf 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -115,6 +115,8 @@ simplify' = \case ESum1Inner _ e -> ESum1Inner ext <$> simplify' e EUnit _ e -> EUnit ext <$> simplify' e EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b + EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e + EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e EConst _ t v -> pure $ EConst ext t v EIdx0 _ e -> EIdx0 ext <$> simplify' e EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b @@ -166,6 +168,8 @@ hasAdds = \case ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b + EMaximum1Inner _ e -> hasAdds e + EMinimum1Inner _ e -> hasAdds e ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e |