diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:49:45 +0100 |
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2025-11-10 21:50:25 +0100 |
| commit | 174af2ba568de66e0d890825b8bda930b8e7bb96 (patch) | |
| tree | 5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/CHAD/ForwardAD | |
| parent | 92bca235e3aaa287286b6af082d3fce585825a35 (diff) | |
Move module hierarchy under CHAD.
Diffstat (limited to 'src/CHAD/ForwardAD')
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers.hs | 231 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers/Types.hs | 48 |
2 files changed, 279 insertions, 0 deletions
diff --git a/src/CHAD/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs new file mode 100644 index 0000000..a71efc8 --- /dev/null +++ b/src/CHAD/ForwardAD/DualNumbers.hs @@ -0,0 +1,231 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module CHAD.ForwardAD.DualNumbers ( + dfwdDN, + DN, DNS, DNE, dn, dne, +) where + +import CHAD.AST +import CHAD.Data +import CHAD.ForwardAD.DualNumbers.Types + + +dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +dnPreservesTupIx SZ = Refl +dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl + +convIdx :: Idx env t -> Idx (DNE env) (DN t) +convIdx IZ = IZ +convIdx (IS i) = IS (convIdx i) + +scalTyCase :: SScalTy t + -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) + -> (DN (TScal t) ~ TScal t => r) + -> r +scalTyCase STF32 k1 _ = k1 +scalTyCase STF64 k1 _ = k1 +scalTyCase STI32 _ k2 = k2 +scalTyCase STI64 _ k2 = k2 +scalTyCase STBool _ k2 = k2 + +floatingDual :: ScalIsFloating t ~ True + => SScalTy t + -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r +floatingDual STF32 k = k +floatingDual STF64 k = k + +-- | Argument does not need to be duplicable. +dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) +dop = \case + OAdd t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) + (EOp ext (OAdd t)) + OMul t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) + (EOp ext (OMul t)) + ONeg t -> scalTyCase t + (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) + (EOp ext (ONeg t)) + OLt t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) + (EOp ext (OLt t)) + OLe t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) + (EOp ext (OLe t)) + OEq t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) + (EOp ext (OEq t)) + ONot -> EOp ext ONot + OAnd -> EOp ext OAnd + OOr -> EOp ext OOr + OIf -> EOp ext OIf + ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) + OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) + ORecip t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (recip' t x) + (mul t (neg t (recip' t (mul t x x))) dx)) + OExp t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OExp t) x) (mul t (EOp ext (OExp t) x) dx)) + OLog t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OLog t) x) + (mul t (recip' t x) dx)) + OIDiv t -> scalTyCase t + (case t of {}) + (EOp ext (OIDiv t)) + OMod t -> scalTyCase t + (case t of {}) + (EOp ext (OMod t)) + where + add :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + add t a b = EOp ext (OAdd t) (EPair ext a b) + + mul :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + mul t a b = EOp ext (OMul t) (EPair ext a b) + + neg :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + neg t = EOp ext (ONeg t) + + recip' :: ScalIsFloating t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + recip' t = EOp ext (ORecip t) + + unFloat :: DN a ~ TPair a a + => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + unFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext var, ESnd ext var) + + binFloat :: (a ~ TPair s s, DN s ~ TPair s s) + => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + binFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) + (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) + +zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) +zeroScalarConst STI32 = EConst ext STI32 0 +zeroScalarConst STI64 = EConst ext STI64 0 +zeroScalarConst STF32 = EConst ext STF32 0.0 +zeroScalarConst STF64 = EConst ext STF64 0.0 + +dfwdDN :: Ex env t -> Ex (DNE env) (DN t) +dfwdDN = \case + EVar _ t i -> EVar ext (dn t) (convIdx i) + ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b) + EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b) + EFst _ e -> EFst ext (dfwdDN e) + ESnd _ e -> ESnd ext (dfwdDN e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (dn t) (dfwdDN e) + EInr _ t e -> EInr ext (dn t) (dfwdDN e) + ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ENothing _ t -> ENothing ext (dn t) + EJust _ e -> EJust ext (dfwdDN e) + EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) + ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) + ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) + ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) + EConstArr _ n t x -> scalTyCase t + (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) + (EConstArr ext n t x)) + (EConstArr ext n t x) + EBuild _ n a b + | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) + ESum1Inner _ e -> + let STArr n (STScal t) = typeOf e + pairty = (STPair (STScal t) (STScal t)) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ))) + (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ)))) + (ESum1Inner ext (dfwdDN e)) + EUnit _ e -> EUnit ext (dfwdDN e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) + EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e + EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) + EReshape _ n esh e + | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) + EConst _ t x -> scalTyCase t + (EPair ext (EConst ext t x) (EConst ext t 0.0)) + (EConst ext t x) + EIdx0 _ e -> EIdx0 ext (dfwdDN e) + EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) + EIdx _ a b + | STArr n _ <- typeOf a + , Refl <- dnPreservesTupIx n + -> EIdx ext (dfwdDN a) (dfwdDN b) + EShape _ e + | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) + -> EShape ext (dfwdDN e) + EOp _ op e -> dop op (dfwdDN e) + ECustom _ _ _ _ pr _ _ e1 e2 -> + ELet ext (dfwdDN e1) $ + ELet ext (weakenExpr WSink (dfwdDN e2)) $ + weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) + ERecompute _ e -> dfwdDN e + EError _ t s -> EError ext (dn t) s + + EWith{} -> err_accum + EAccum{} -> err_accum + EDeepZero{} -> err_monoid + EZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" + + deriv_extremum :: ScalIsNumeric t ~ True + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t))) + deriv_extremum extremum e = + let STArr (SS n) (STScal t) = typeOf e + t2 = STPair (STScal t) (STScal t) + ta2 = STArr (SS n) t2 + tIxN = tTup (sreplicate (SS n) tIx) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $ + ezip (EVar ext (STArr n (STScal t)) IZ) + (ESum1Inner ext + {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -} + (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $ + ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $ + ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext + (EFst ext (EVar ext t2 IZ)) + (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ))) + (EFst ext (EVar ext tIxN (IS IZ))))))) + (ESnd ext (EVar ext t2 (IS IZ))) + (zeroScalarConst t)))) + (extremum (dfwdDN e)) diff --git a/src/CHAD/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs new file mode 100644 index 0000000..5d5dd9e --- /dev/null +++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.ForwardAD.DualNumbers.Types where + +import CHAD.AST.Types +import CHAD.Data + + +-- | Dual-numbers transformation +type family DN t where + DN TNil = TNil + DN (TPair a b) = TPair (DN a) (DN b) + DN (TEither a b) = TEither (DN a) (DN b) + DN (TLEither a b) = TLEither (DN a) (DN b) + DN (TMaybe t) = TMaybe (DN t) + DN (TArr n t) = TArr n (DN t) + DN (TScal t) = DNS t + +type family DNS t where + DNS TF32 = TPair (TScal TF32) (TScal TF32) + DNS TF64 = TPair (TScal TF64) (TScal TF64) + DNS TI32 = TScal TI32 + DNS TI64 = TScal TI64 + DNS TBool = TScal TBool + +type family DNE env where + DNE '[] = '[] + DNE (t : ts) = DN t : DNE ts + +dn :: STy t -> STy (DN t) +dn STNil = STNil +dn (STPair a b) = STPair (dn a) (dn b) +dn (STEither a b) = STEither (dn a) (dn b) +dn (STLEither a b) = STLEither (dn a) (dn b) +dn (STMaybe t) = STMaybe (dn t) +dn (STArr n t) = STArr n (dn t) +dn (STScal t) = case t of + STF32 -> STPair (STScal STF32) (STScal STF32) + STF64 -> STPair (STScal STF64) (STScal STF64) + STI32 -> STScal STI32 + STI64 -> STScal STI64 + STBool -> STScal STBool +dn STAccum{} = error "Accum in source program" + +dne :: SList STy env -> SList STy (DNE env) +dne SNil = SNil +dne (t `SCons` env) = dn t `SCons` dne env |
