diff options
Diffstat (limited to 'src/ForwardAD.hs')
-rw-r--r-- | src/ForwardAD.hs | 202 |
1 files changed, 202 insertions, 0 deletions
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs new file mode 100644 index 0000000..0a9e12c --- /dev/null +++ b/src/ForwardAD.hs @@ -0,0 +1,202 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module ForwardAD ( + dfwd, + FD, FDS, FDE, fd, +) where + +import AST +import Data + + +-- | Dual-numbers transformation +type family FD t where + FD TNil = TNil + FD (TPair a b) = TPair (FD a) (FD b) + FD (TEither a b) = TEither (FD a) (FD b) + FD (TMaybe t) = TMaybe (FD t) + FD (TArr n t) = TArr n (FD t) + FD (TScal t) = FDS t + +type family FDS t where + FDS TF32 = TPair (TScal TF32) (TScal TF32) + FDS TF64 = TPair (TScal TF64) (TScal TF64) + FDS TI32 = TScal TI32 + FDS TI64 = TScal TI64 + FDS TBool = TScal TBool + +type family FDE env where + FDE '[] = '[] + FDE (t : ts) = FD t : FDE ts + +fd :: STy t -> STy (FD t) +fd STNil = STNil +fd (STPair a b) = STPair (fd a) (fd b) +fd (STEither a b) = STEither (fd a) (fd b) +fd (STMaybe t) = STMaybe (fd t) +fd (STArr n t) = STArr n (fd t) +fd (STScal t) = case t of + STF32 -> STPair (STScal STF32) (STScal STF32) + STF64 -> STPair (STScal STF64) (STScal STF64) + STI32 -> STScal STI32 + STI64 -> STScal STI64 + STBool -> STScal STBool +fd STAccum{} = error "Accum in source program" + +fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +fdPreservesTupIx SZ = Refl +fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl + +convIdx :: Idx env t -> Idx (FDE env) (FD t) +convIdx IZ = IZ +convIdx (IS i) = IS (convIdx i) + +scalTyCase :: SScalTy t + -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r) + -> (FD (TScal t) ~ TScal t => r) + -> r +scalTyCase STF32 k1 _ = k1 +scalTyCase STF64 k1 _ = k1 +scalTyCase STI32 _ k2 = k2 +scalTyCase STI64 _ k2 = k2 +scalTyCase STBool _ k2 = k2 + +-- | Argument does not need to be duplicable. +dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b) +dop = \case + OAdd t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) + (EOp ext (OAdd t)) + OMul t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) + (EOp ext (OMul t)) + ONeg t -> scalTyCase t + (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) + (EOp ext (ONeg t)) + OLt t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) + (EOp ext (OLt t)) + OLe t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) + (EOp ext (OLe t)) + OEq t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) + (EOp ext (OEq t)) + ONot -> EOp ext ONot + OIf -> EOp ext OIf + where + add :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + add t a b = EOp ext (OAdd t) (EPair ext a b) + + mul :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + mul t a b = EOp ext (OMul t) (EPair ext a b) + + neg :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + neg t = EOp ext (ONeg t) + + unFloat :: FD a ~ TPair a a + => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b)) + -> Ex env (FD a) -> Ex env (FD b) + unFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext var, ESnd ext var) + + binFloat :: (a ~ TPair s s, FD s ~ TPair s s) + => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b)) + -> Ex env (FD a) -> Ex env (FD b) + binFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) + (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) + +dfwd :: Ex env t -> Ex (FDE env) (FD t) +dfwd = \case + EVar _ t i -> EVar ext (fd t) (convIdx i) + ELet _ a b -> ELet ext (dfwd a) (dfwd b) + EPair _ a b -> EPair ext (dfwd a) (dfwd b) + EFst _ e -> EFst ext (dfwd e) + ESnd _ e -> ESnd ext (dfwd e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (fd t) (dfwd e) + EInr _ t e -> EInr ext (fd t) (dfwd e) + ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b) + ENothing _ t -> ENothing ext (fd t) + EJust _ e -> EJust ext (dfwd e) + EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b) + EConstArr _ n t x -> scalTyCase t + (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) + (EConstArr ext n t x)) + (EConstArr ext n t x) + EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b) + EBuild _ n a b + | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b) + EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b) + ESum1Inner _ e -> + let STArr n (STScal t) = typeOf e + pairty = (STPair (STScal t) (STScal t)) + in scalTyCase t + (ELet ext (dfwd e) $ + ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ))) + (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ)))) + (ESum1Inner ext (dfwd e)) + EUnit _ e -> EUnit ext (dfwd e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b) + EConst _ t x -> scalTyCase t + (EPair ext (EConst ext t x) (EConst ext t 0.0)) + (EConst ext t x) + EIdx0 _ e -> EIdx0 ext (dfwd e) + EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b) + EIdx _ n a b + | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b) + EShape _ e + | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e) + EOp _ op e -> dop op (dfwd e) + EError t s -> EError (fd t) s + + EWith{} -> err_accum + EAccum{} -> err_accum + EZero{} -> err_monoid + EPlus{} -> err_monoid + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + +emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr = + let STArr n t = typeOf arr + in ELet ext arr $ + EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ + ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) f + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip a b = + let STArr n t1 = typeOf a + STArr _ t2 = typeOf b + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ + EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) |