summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-09 11:15:06 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-09 11:15:06 +0100
commit992249ebf159ba3783a9345430013e52294c26aa (patch)
tree2d1a8324310aebd60062fdb7d9ba785fe0298d0c
parentcbe6472a14cc0887295034bb29546dd1a1f083fd (diff)
Maximum/minimum
-rw-r--r--src/ForwardAD/DualNumbers.hs39
-rw-r--r--src/Interpreter.hs16
-rw-r--r--src/Language.hs6
-rw-r--r--src/Language/AST.hs4
-rw-r--r--test/Main.hs16
5 files changed, 67 insertions, 14 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 056fcb3..7d47e6d 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -98,6 +98,12 @@ dop = \case
in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
(EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
+zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
+zeroScalarConst STF32 = EConst ext STF32 0.0
+zeroScalarConst STF64 = EConst ext STF64 0.0
+zeroScalarConst STI32 = EConst ext STI32 0
+zeroScalarConst STI64 = EConst ext STI64 0
+
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case
EVar _ t i -> EVar ext (dn t) (convIdx i)
@@ -131,13 +137,8 @@ dfwdDN = \case
(ESum1Inner ext (dfwdDN e))
EUnit _ e -> EUnit ext (dfwdDN e)
EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
- EMaximum1Inner _ e ->
- let STArr n (STScal t) = typeOf e
- in scalTyCase t
- -- TODO: do roughly the same as what CHAD does, but forward
- (_ (dfwdDN e))
- _
- EMinimum1Inner _ e -> EMinimum1Inner ext (dfwdDN e)
+ EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
EConst _ t x -> scalTyCase t
(EPair ext (EConst ext t x) (EConst ext t 0.0))
(EConst ext t x)
@@ -165,3 +166,27 @@ dfwdDN = \case
where
err_accum = error "Accumulator operations unsupported in the source program"
err_monoid = error "Monoid operations unsupported in the source program"
+
+ deriv_extremum :: ScalIsNumeric t ~ True
+ => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
+ -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t)))
+ deriv_extremum extremum e =
+ let STArr (SS n) (STScal t) = typeOf e
+ t2 = STPair (STScal t) (STScal t)
+ ta2 = STArr (SS n) t2
+ tIxN = tTup (sreplicate (SS n) tIx)
+ in scalTyCase t
+ (ELet ext (dfwdDN e) $
+ ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $
+ ezip (EVar ext (STArr n (STScal t)) IZ)
+ (ESum1Inner ext
+ {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -}
+ (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $
+ ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $
+ ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext
+ (EFst ext (EVar ext t2 IZ))
+ (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ)))
+ (EFst ext (EVar ext tIxN (IS IZ)))))))
+ (ESnd ext (EVar ext t2 (IS IZ)))
+ (zeroScalarConst t))))
+ (EMaximum1Inner ext (dfwdDN e))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 47514ae..3c1aad0 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -103,13 +103,25 @@ interpret'Rec env = \case
arr <- interpret' env e
let STArr _ (STScal t) = typeOf e
sh `ShCons` n = arrayShape arr
- numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
EReplicate1Inner _ a b -> do
n <- fromIntegral @Int64 @Int <$> interpret' env a
arr <- interpret' env b
let sh = arrayShape arr
- arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
+ return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx)
+ EMaximum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $
+ arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
+ EMinimum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ return $
+ arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]])
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
diff --git a/src/Language.hs b/src/Language.hs
index e8dc89f..88cb1de 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -94,6 +94,12 @@ unit = NEUnit
replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t))
replicate1i n a = NEReplicate1Inner n a
+maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+maximum1i e = NEMaximum1Inner e
+
+minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+minimum1i e = NEMinimum1Inner e
+
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x =
let ty = knownScalTy
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index f5203e9..4194913 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -49,6 +49,8 @@ data NExpr env t where
NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEUnit :: NExpr env t -> NExpr env (TArr Z t)
NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
+ NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+ NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
-- expression operations
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
@@ -157,6 +159,8 @@ fromNamedExpr val = \case
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
+ NEMaximum1Inner e -> EMaximum1Inner ext (go e)
+ NEMinimum1Inner e -> EMinimum1Inner ext (go e)
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
diff --git a/test/Main.hs b/test/Main.hs
index d3e55b3..3a598c0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -132,7 +132,10 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
-adTest = flip adTestGen (genEnv (knownEnv @env))
+adTest = adTestCon (const True)
+
+adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property
+adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env)))
adTestGen :: forall env. KnownEnv env
=> Ex env (TScal TF64) -> Gen (SList Value env) -> Property
@@ -198,10 +201,13 @@ tests = checkSequential $ Group "AD"
idx0 $ sum1i . sum1i $
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)
- -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
- -- idx0 $ sum1i . sum1i $
- -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
- -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)
+ ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
+ fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ idx0 $ sum1i $ maximum1i #x)
+
+ ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
+ fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $
+ idx0 $ sum1i $ minimum1i #x)
,("neural", adTestGen Example.neural $ do
let tR = STScal STF64