diff options
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs new file mode 100644 index 0000000..f9239e9 --- /dev/null +++ b/src/ForwardAD/DualNumbers.hs @@ -0,0 +1,206 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module ForwardAD.DualNumbers ( + dfwdDN, + DN, DNS, DNE, dn, dne, +) where + +import AST +import Data + + +-- | Dual-numbers transformation +type family DN t where + DN TNil = TNil + DN (TPair a b) = TPair (DN a) (DN b) + DN (TEither a b) = TEither (DN a) (DN b) + DN (TMaybe t) = TMaybe (DN t) + DN (TArr n t) = TArr n (DN t) + DN (TScal t) = DNS t + +type family DNS t where + DNS TF32 = TPair (TScal TF32) (TScal TF32) + DNS TF64 = TPair (TScal TF64) (TScal TF64) + DNS TI32 = TScal TI32 + DNS TI64 = TScal TI64 + DNS TBool = TScal TBool + +type family DNE env where + DNE '[] = '[] + DNE (t : ts) = DN t : DNE ts + +dn :: STy t -> STy (DN t) +dn STNil = STNil +dn (STPair a b) = STPair (dn a) (dn b) +dn (STEither a b) = STEither (dn a) (dn b) +dn (STMaybe t) = STMaybe (dn t) +dn (STArr n t) = STArr n (dn t) +dn (STScal t) = case t of + STF32 -> STPair (STScal STF32) (STScal STF32) + STF64 -> STPair (STScal STF64) (STScal STF64) + STI32 -> STScal STI32 + STI64 -> STScal STI64 + STBool -> STScal STBool +dn STAccum{} = error "Accum in source program" + +dne :: SList STy env -> SList STy (DNE env) +dne SNil = SNil +dne (t `SCons` env) = dn t `SCons` dne env + +dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +dnPreservesTupIx SZ = Refl +dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl + +convIdx :: Idx env t -> Idx (DNE env) (DN t) +convIdx IZ = IZ +convIdx (IS i) = IS (convIdx i) + +scalTyCase :: SScalTy t + -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) + -> (DN (TScal t) ~ TScal t => r) + -> r +scalTyCase STF32 k1 _ = k1 +scalTyCase STF64 k1 _ = k1 +scalTyCase STI32 _ k2 = k2 +scalTyCase STI64 _ k2 = k2 +scalTyCase STBool _ k2 = k2 + +-- | Argument does not need to be duplicable. +dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) +dop = \case + OAdd t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) + (EOp ext (OAdd t)) + OMul t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) + (EOp ext (OMul t)) + ONeg t -> scalTyCase t + (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) + (EOp ext (ONeg t)) + OLt t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) + (EOp ext (OLt t)) + OLe t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) + (EOp ext (OLe t)) + OEq t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) + (EOp ext (OEq t)) + ONot -> EOp ext ONot + OIf -> EOp ext OIf + where + add :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + add t a b = EOp ext (OAdd t) (EPair ext a b) + + mul :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + mul t a b = EOp ext (OMul t) (EPair ext a b) + + neg :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + neg t = EOp ext (ONeg t) + + unFloat :: DN a ~ TPair a a + => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + unFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext var, ESnd ext var) + + binFloat :: (a ~ TPair s s, DN s ~ TPair s s) + => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + binFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) + (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) + +dfwdDN :: Ex env t -> Ex (DNE env) (DN t) +dfwdDN = \case + EVar _ t i -> EVar ext (dn t) (convIdx i) + ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b) + EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b) + EFst _ e -> EFst ext (dfwdDN e) + ESnd _ e -> ESnd ext (dfwdDN e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (dn t) (dfwdDN e) + EInr _ t e -> EInr ext (dn t) (dfwdDN e) + ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ENothing _ t -> ENothing ext (dn t) + EJust _ e -> EJust ext (dfwdDN e) + EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + EConstArr _ n t x -> scalTyCase t + (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) + (EConstArr ext n t x)) + (EConstArr ext n t x) + EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b) + EBuild _ n a b + | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EFold1Inner _ a b -> EFold1Inner ext (dfwdDN a) (dfwdDN b) + ESum1Inner _ e -> + let STArr n (STScal t) = typeOf e + pairty = (STPair (STScal t) (STScal t)) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ))) + (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ)))) + (ESum1Inner ext (dfwdDN e)) + EUnit _ e -> EUnit ext (dfwdDN e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) + EConst _ t x -> scalTyCase t + (EPair ext (EConst ext t x) (EConst ext t 0.0)) + (EConst ext t x) + EIdx0 _ e -> EIdx0 ext (dfwdDN e) + EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) + EIdx _ n a b + | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b) + EShape _ e + | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) + EOp _ op e -> dop op (dfwdDN e) + EError t s -> EError (dn t) s + + EWith{} -> err_accum + EAccum{} -> err_accum + EZero{} -> err_monoid + EPlus{} -> err_monoid + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + +emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr = + let STArr n t = typeOf arr + in ELet ext arr $ + EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ + ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) f + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip a b = + let STArr n t1 = typeOf a + STArr _ t2 = typeOf b + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ + EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) |