summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-08 22:17:56 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-08 22:17:56 +0100
commitcbe6472a14cc0887295034bb29546dd1a1f083fd (patch)
tree2fa6a20f584d58ac3b89074673990a16cdc7d5b2
parent4fcdb7118e0084f192753ea6c70394352a27d5ed (diff)
WIP maximum/minimum
-rw-r--r--src/AST.hs6
-rw-r--r--src/AST/Count.hs2
-rw-r--r--src/AST/Pretty.hs8
-rw-r--r--src/CHAD.hs24
-rw-r--r--src/ForwardAD/DualNumbers.hs7
-rw-r--r--src/Simplify.hs4
6 files changed, 51 insertions, 0 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 08a5bba..28c5b37 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -87,6 +87,8 @@ data Expr x env t where
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
+ EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
+ EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
@@ -206,6 +208,8 @@ typeOf = \case
ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
+ EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -282,6 +286,8 @@ subst' f w = \case
ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
+ EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
+ EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 364773a..22a4da6 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -119,6 +119,8 @@ occCountGeneral onehot unpush alter many = go WId
ESum1Inner _ e -> re e
EUnit _ e -> re e
EReplicate1Inner _ a b -> re a <> re b
+ EMaximum1Inner _ e -> re e
+ EMinimum1Inner _ e -> re e
EConst{} -> mempty
EIdx0 _ e -> re e
EIdx1 _ a b -> re a <> re b
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index a2232ee..4d9aeec 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -141,6 +141,14 @@ ppExpr' d val = \case
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b'
+ EMaximum1Inner _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "maximum1i " . e'
+
+ EMinimum1Inner _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "minimum1i " . e'
+
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ showsPrec d v
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 2f05807..a08fe80 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1069,6 +1069,9 @@ drev des = \case
(EVar ext (STArr n (d2 t)) IZ)) $
weakenExpr (WCopy (WSink .> WSink)) e2)
+ EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+
-- These should be the next to be implemented, I think
EFold1Inner{} -> err_unsupported "EFold1Inner"
@@ -1086,3 +1089,24 @@ drev des = \case
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
+
+ deriv_extremum :: ScalIsNumeric t' ~ True
+ => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t')))
+ -> Ex env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t'))
+ deriv_extremum extremum e
+ | Ret e0 subtape e1 sub e2 <- drev des e
+ , at@(STArr (SS n) t@(STScal st)) <- typeOf e
+ , let at' = STArr n t
+ , let tIxN = tTup (sreplicate (SS n) tIx) =
+ Ret (e0 `BPush` (at, e1)
+ `BPush` (at', extremum (EVar ext at IZ)))
+ (SEYes (SEYes subtape))
+ (EVar ext at' IZ)
+ sub
+ (ELet ext (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $
+ ECase ext (EOp ext OIf (EOp ext (OEq st) (EPair ext
+ (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ))
+ (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))))
+ (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))
+ (EZero t)) $
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index f2ded6e..056fcb3 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -131,6 +131,13 @@ dfwdDN = \case
(ESum1Inner ext (dfwdDN e))
EUnit _ e -> EUnit ext (dfwdDN e)
EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
+ EMaximum1Inner _ e ->
+ let STArr n (STScal t) = typeOf e
+ in scalTyCase t
+ -- TODO: do roughly the same as what CHAD does, but forward
+ (_ (dfwdDN e))
+ _
+ EMinimum1Inner _ e -> EMinimum1Inner ext (dfwdDN e)
EConst _ t x -> scalTyCase t
(EPair ext (EConst ext t x) (EConst ext t 0.0))
(EConst ext t x)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index e32ba8c..3e14aaf 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -115,6 +115,8 @@ simplify' = \case
ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
EUnit _ e -> EUnit ext <$> simplify' e
EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
+ EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
+ EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
EConst _ t v -> pure $ EConst ext t v
EIdx0 _ e -> EIdx0 ext <$> simplify' e
EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
@@ -166,6 +168,8 @@ hasAdds = \case
ESum1Inner _ e -> hasAdds e
EUnit _ e -> hasAdds e
EReplicate1Inner _ a b -> hasAdds a || hasAdds b
+ EMaximum1Inner _ e -> hasAdds e
+ EMinimum1Inner _ e -> hasAdds e
ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e