diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 11:15:06 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-09 11:15:06 +0100 |
commit | 992249ebf159ba3783a9345430013e52294c26aa (patch) | |
tree | 2d1a8324310aebd60062fdb7d9ba785fe0298d0c | |
parent | cbe6472a14cc0887295034bb29546dd1a1f083fd (diff) |
Maximum/minimum
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 39 | ||||
-rw-r--r-- | src/Interpreter.hs | 16 | ||||
-rw-r--r-- | src/Language.hs | 6 | ||||
-rw-r--r-- | src/Language/AST.hs | 4 | ||||
-rw-r--r-- | test/Main.hs | 16 |
5 files changed, 67 insertions, 14 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 056fcb3..7d47e6d 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -98,6 +98,12 @@ dop = \case in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) +zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) +zeroScalarConst STF32 = EConst ext STF32 0.0 +zeroScalarConst STF64 = EConst ext STF64 0.0 +zeroScalarConst STI32 = EConst ext STI32 0 +zeroScalarConst STI64 = EConst ext STI64 0 + dfwdDN :: Ex env t -> Ex (DNE env) (DN t) dfwdDN = \case EVar _ t i -> EVar ext (dn t) (convIdx i) @@ -131,13 +137,8 @@ dfwdDN = \case (ESum1Inner ext (dfwdDN e)) EUnit _ e -> EUnit ext (dfwdDN e) EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) - EMaximum1Inner _ e -> - let STArr n (STScal t) = typeOf e - in scalTyCase t - -- TODO: do roughly the same as what CHAD does, but forward - (_ (dfwdDN e)) - _ - EMinimum1Inner _ e -> EMinimum1Inner ext (dfwdDN e) + EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e + EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) @@ -165,3 +166,27 @@ dfwdDN = \case where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" + + deriv_extremum :: ScalIsNumeric t ~ True + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t))) + deriv_extremum extremum e = + let STArr (SS n) (STScal t) = typeOf e + t2 = STPair (STScal t) (STScal t) + ta2 = STArr (SS n) t2 + tIxN = tTup (sreplicate (SS n) tIx) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $ + ezip (EVar ext (STArr n (STScal t)) IZ) + (ESum1Inner ext + {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -} + (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $ + ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $ + ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext + (EFst ext (EVar ext t2 IZ)) + (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ))) + (EFst ext (EVar ext tIxN (IS IZ))))))) + (ESnd ext (EVar ext t2 (IS IZ))) + (zeroScalarConst t)))) + (EMaximum1Inner ext (dfwdDN e)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 47514ae..3c1aad0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -103,13 +103,25 @@ interpret'Rec env = \case arr <- interpret' env e let STArr _ (STScal t) = typeOf e sh `ShCons` n = arrayShape arr - numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) EReplicate1Inner _ a b -> do n <- fromIntegral @Int64 @Int <$> interpret' env a arr <- interpret' env b let sh = arrayShape arr - arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx)) + return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx) + EMaximum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EMinimum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) diff --git a/src/Language.hs b/src/Language.hs index e8dc89f..88cb1de 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -94,6 +94,12 @@ unit = NEUnit replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) replicate1i n a = NEReplicate1Inner n a +maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +maximum1i e = NEMaximum1Inner e + +minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +minimum1i e = NEMinimum1Inner e + const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) const_ x = let ty = knownScalTy diff --git a/src/Language/AST.hs b/src/Language/AST.hs index f5203e9..4194913 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -49,6 +49,8 @@ data NExpr env t where NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) + NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) @@ -157,6 +159,8 @@ fromNamedExpr val = \case NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) + NEMaximum1Inner e -> EMaximum1Inner ext (go e) + NEMinimum1Inner e -> EMinimum1Inner ext (go e) NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) diff --git a/test/Main.hs b/test/Main.hs index d3e55b3..3a598c0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -132,7 +132,10 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property -adTest = flip adTestGen (genEnv (knownEnv @env)) +adTest = adTestCon (const True) + +adTestCon :: forall env. KnownEnv env => (SList Value env -> Bool) -> Ex env (TScal TF64) -> Property +adTestCon constr term = adTestGen term (Gen.filter constr (genEnv (knownEnv @env))) adTestGen :: forall env. KnownEnv env => Ex env (TScal TF64) -> Gen (SList Value env) -> Property @@ -198,10 +201,13 @@ tests = checkSequential $ Group "AD" idx0 $ sum1i . sum1i $ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx) - -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $ - -- idx0 $ sum1i . sum1i $ - -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :-> - -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx) + ,("maximum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + idx0 $ sum1i $ maximum1i #x) + + ,("minimum", adTestCon (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $ + fromNamed $ lambda @(TArr N2 (TScal TF64)) #x $ body $ + idx0 $ sum1i $ minimum1i #x) ,("neural", adTestGen Example.neural $ do let tR = STScal STF64 |