{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- I want to bring various type variables in scope using type annotations in
-- patterns, but I don't want to have to mention all the other type parameters
-- of the types in question as well then. Partial type signatures (with '_') are
-- useful here.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

-- | Forward-mode automatic differentiation as a dual-numbers program
-- transformation: 'dfwd' rewrites an expression so that every floating-point
-- scalar is paired with its tangent, and 'FD' / 'FDE' describe the
-- corresponding change of types and environments.
module ForwardAD (
  dfwd,
  FD,
  FDS,
  FDE,
  fd,
) where

import AST
import Data


-- | Dual-numbers transformation on types. Scalars are handled by 'FDS'
-- (floating-point scalars become (primal, tangent) pairs); every other type
-- constructor is mapped structurally.
type family FD t where
  FD TNil = TNil
  FD (TPair a b) = TPair (FD a) (FD b)
  FD (TEither a b) = TEither (FD a) (FD b)
  FD (TMaybe t) = TMaybe (FD t)
  FD (TArr n t) = TArr n (FD t)
  FD (TScal t) = FDS t

-- | Dual-numbers transformation on scalar types: 'TF32' and 'TF64' become a
-- pair of two copies of themselves (primal and tangent); integer and Boolean
-- scalars are left unchanged.
type family FDS t where
  FDS TF32 = TPair (TScal TF32) (TScal TF32)
  FDS TF64 = TPair (TScal TF64) (TScal TF64)
  FDS TI32 = TScal TI32
  FDS TI64 = TScal TI64
  FDS TBool = TScal TBool

-- | 'FD' applied to every entry of a type-level environment.
type family FDE env where
  FDE '[] = '[]
  FDE (t : ts) = FD t : FDE ts

-- | Singleton-level counterpart of 'FD': computes the type representation of
-- the transformed type.
--
-- Calls 'error' on accumulator types, which must not occur in source programs.
fd :: STy t -> STy (FD t)
fd STNil = STNil
fd (STPair a b) = STPair (fd a) (fd b)
fd (STEither a b) = STEither (fd a) (fd b)
fd (STMaybe t) = STMaybe (fd t)
fd (STArr n t) = STArr n (fd t)
fd (STScal t) = case t of
  STF32 -> STPair (STScal STF32) (STScal STF32)
  STF64 -> STPair (STScal STF64) (STScal STF64)
  STI32 -> STScal STI32
  STI64 -> STScal STI64
  STBool -> STScal STBool
fd STAccum{} = error "Accum in source program"

-- | Proof, by induction on @n@, that 'FD' leaves an @n@-tuple of 'TIx'
-- unchanged. Used to reuse index tuples (in 'EBuild', 'EIdx', 'EShape')
-- verbatim after the transformation.
-- NOTE(review): relies on 'FD' of 'TIx' reducing to 'TIx', which depends on
-- the definition of 'TIx' in AST — confirm there if this ever stops compiling.
fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
fdPreservesTupIx SZ = Refl
fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl

-- | Translate a de Bruijn index into @env@ to the same position in @FDE env@.
convIdx :: Idx env t -> Idx (FDE env) (FD t)
convIdx IZ = IZ
convIdx (IS i) = IS (convIdx i)

-- | Case split on a scalar type.
--
-- The first continuation is used for 'TF32' and 'TF64'. It receives evidence
-- that the type is numeric, that its representation is 'Fractional', and that
-- 'FD' turns it into a (primal, tangent) pair.
--
-- The second continuation is used for 'TI32', 'TI64' and 'TBool', for which
-- 'FD' is the identity.
scalTyCase :: SScalTy t
           -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r)
           -> (FD (TScal t) ~ TScal t => r)
           -> r
scalTyCase STF32 k1 _ = k1
scalTyCase STF64 k1 _ = k1
scalTyCase STI32 _ k2 = k2
scalTyCase STI64 _ k2 = k2
scalTyCase STBool _ k2 = k2

-- | Forward derivative of a primitive operation, applied to an argument that
-- has already been transformed.
--
-- The argument does not need to be duplicable. In the floating-point cases it
-- is let-bound once (in @unFloat@ / @binFloat@), and only projections of that
-- variable are duplicated.
dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b)
dop = \case
  -- d(x + y) = dx + dy
  OAdd t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
    (EOp ext (OAdd t))
  -- Product rule: d(x * y) = dx * y + dy * x
  OMul t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
    (EOp ext (OMul t))
  -- d(-x) = -dx
  ONeg t -> scalTyCase t
    (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
    (EOp ext (ONeg t))
  -- Comparisons discard the tangents and compare only the primal parts.
  OLt t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
    (EOp ext (OLt t))
  OLe t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
    (EOp ext (OLe t))
  OEq t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
    (EOp ext (OEq t))
  -- Passed through unchanged.
  ONot -> EOp ext ONot
  OIf -> EOp ext OIf
  where
    -- Small expression builders for the primitive arithmetic operations.
    add :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    add t a b = EOp ext (OAdd t) (EPair ext a b)

    mul :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    mul t a b = EOp ext (OMul t) (EPair ext a b)

    neg :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    neg t = EOp ext (ONeg t)

    -- Unary float operation: let-bind the (primal, tangent) argument and hand
    -- its two projections to the continuation.
    unFloat :: FD a ~ TPair a a
            => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b))
            -> Ex env (FD a) -> Ex env (FD b)
    unFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext var, ESnd ext var)

    -- Binary float operation: the argument is a pair of dual numbers. Let-bind
    -- it and hand (x, dx) and (y, dy) to the continuation.
    binFloat :: (a ~ TPair s s, FD s ~ TPair s s)
             => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b))
             -> Ex env (FD a) -> Ex env (FD b)
    binFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
             (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))

-- | Forward-mode (dual-numbers) differentiation of an expression.
--
-- Structural constructs are transformed homomorphically. Floating-point
-- constants get a zero tangent, and primitive operations are handled by 'dop'.
--
-- Calls 'error' on accumulator and monoid operations (EWith, EAccum, EZero,
-- EPlus), which are unsupported in source programs.
dfwd :: Ex env t -> Ex (FDE env) (FD t)
dfwd = \case
  EVar _ t i -> EVar ext (fd t) (convIdx i)
  ELet _ a b -> ELet ext (dfwd a) (dfwd b)
  EPair _ a b -> EPair ext (dfwd a) (dfwd b)
  EFst _ e -> EFst ext (dfwd e)
  ESnd _ e -> ESnd ext (dfwd e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (fd t) (dfwd e)
  EInr _ t e -> EInr ext (fd t) (dfwd e)
  ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b)
  ENothing _ t -> ENothing ext (fd t)
  EJust _ e -> EJust ext (dfwd e)
  EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b)

  -- A constant float array becomes an array of (value, 0.0) pairs; other
  -- scalar arrays are unchanged.
  EConstArr _ n t x ->
    scalTyCase t
      (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
            (EConstArr ext n t x))
      (EConstArr ext n t x)

  EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b)
  EBuild _ n a b
    | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b)
  EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b)

  -- For float arrays, sum the primal components and the tangent components
  -- separately, then zip the two results back into dual numbers.
  ESum1Inner _ e ->
    let STArr n (STScal t) = typeOf e
        pairty = STPair (STScal t) (STScal t)
    in scalTyCase t
         (ELet ext (dfwd e) $
            ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ)))
                 (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ))))
         (ESum1Inner ext (dfwd e))

  EUnit _ e -> EUnit ext (dfwd e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b)

  -- A float constant gets a zero tangent.
  EConst _ t x ->
    scalTyCase t
      (EPair ext (EConst ext t x) (EConst ext t 0.0))
      (EConst ext t x)

  EIdx0 _ e -> EIdx0 ext (dfwd e)
  EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b)
  EIdx _ n a b
    | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b)
  EShape _ e
    | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e)
  EOp _ op e -> dop op (dfwd e)
  EError t s -> EError (fd t) s

  EWith{} -> err_accum
  EAccum{} -> err_accum
  EZero{} -> err_monoid
  EPlus{} -> err_monoid
  where
    err_accum = error "Accumulator operations unsupported in the source program"
    err_monoid = error "Monoid operations unsupported in the source program"

    -- Map a function (an expression with the element bound at IZ) over every
    -- element of an array by building a new array of the same shape.
    emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
    emap f arr =
      let STArr n t = typeOf arr
      in ELet ext arr $
           EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
             ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ))
                                  (EVar ext (tTup (sreplicate n tIx)) IZ)) $
               -- f only sees the element; sink it past the index and array.
               weakenExpr (WCopy (WSink .> WSink)) f

    -- Zip two arrays element-wise into an array of pairs. The result shape is
    -- taken from the first array.
    ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
    ezip a b =
      let STArr n t1 = typeOf a
          STArr _ t2 = typeOf b
      in ELet ext a $
           ELet ext (weakenExpr WSink b) $
             EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
               EPair ext
                 (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ)))
                             (EVar ext (tTup (sreplicate n tIx)) IZ))
                 (EIdx ext n (EVar ext (STArr n t2) (IS IZ))
                             (EVar ext (tTup (sreplicate n tIx)) IZ))