{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- I want to bring various type variables in scope using type annotations in
-- patterns, but I don't want to have to mention all the other type parameters
-- of the types in question as well then. Partial type signatures (with '_') are
-- useful here.
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS -Wno-partial-type-signatures #-}
module ForwardAD.DualNumbers (
  dfwdDN,
  DN, DNS, DNE, dn, dne,
) where

import AST
import Data
import ForwardAD.DualNumbers.Types


-- | Evidence that 'DN' leaves an @n@-tuple of index types unchanged, by
-- induction on @n@.
dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
dnPreservesTupIx SZ = Refl
dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl

-- | Re-type a variable index for the dual-number environment. The position
-- in the environment is unchanged; only the types are mapped.
convIdx :: Idx env t -> Idx (DNE env) (DN t)
convIdx IZ = IZ
convIdx (IS i) = IS (convIdx i)

-- | Case distinction on a scalar type.
--
-- The first continuation is chosen for 'STF32' and 'STF64', where the dual
-- type of the scalar is a pair of that scalar. The second continuation is
-- chosen for 'STI32', 'STI64' and 'STBool', where the dual type is the scalar
-- itself.
scalTyCase :: SScalTy t
           -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
           -> (DN (TScal t) ~ TScal t => r)
           -> r
scalTyCase STF32 k1 _ = k1
scalTyCase STF64 k1 _ = k1
scalTyCase STI32 _ k2 = k2
scalTyCase STI64 _ k2 = k2
scalTyCase STBool _ k2 = k2

-- | Like the first branch of 'scalTyCase', for a scalar type that is already
-- known to be floating: brings the pair-shaped dual type into scope.
floatingDual :: ScalIsFloating t ~ True
             => SScalTy t
             -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r)
             -> r
floatingDual STF32 k = k
floatingDual STF64 k = k

-- | Dual-number version of a primitive operation, applied to an
-- already-transformed argument.
--
-- Argument does not need to be duplicable: where its components are needed
-- more than once, 'unFloat' and 'binFloat' first bind it with 'ELet'.
dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
dop = \case
  OAdd t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
    (EOp ext (OAdd t))
  -- Product rule: d(x * y) = dx * y + dy * x.
  OMul t -> scalTyCase t
    (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
    (EOp ext (OMul t))
  ONeg t -> scalTyCase t
    (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
    (EOp ext (ONeg t))
  -- Comparisons only look at the primal components.
  OLt t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
    (EOp ext (OLt t))
  OLe t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
    (EOp ext (OLe t))
  OEq t -> scalTyCase t
    (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
    (EOp ext (OEq t))
  ONot -> EOp ext ONot
  OAnd -> EOp ext OAnd
  OOr -> EOp ext OOr
  OIf -> EOp ext OIf
  -- Drops the tangent: only the primal component is passed on.
  ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
  -- The argument has no tangent component, so the result's tangent is zero.
  OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
  -- d(1 / x) = -(1 / (x * x)) * dx.
  ORecip t -> floatingDual t $ unFloat (\(x, dx) ->
    EPair ext (recip' t x) (mul t (neg t (recip' t (mul t x x))) dx))
  -- d(exp x) = exp x * dx. (The tangent is NOT exp dx.)
  OExp t -> floatingDual t $ unFloat (\(x, dx) ->
    EPair ext (EOp ext (OExp t) x) (mul t (EOp ext (OExp t) x) dx))
  -- d(log x) = (1 / x) * dx.
  OLog t -> floatingDual t $ unFloat (\(x, dx) ->
    EPair ext (EOp ext (OLog t) x) (mul t (recip' t x) dx))
  where
    -- Shorthands for building primitive-operation nodes on scalars.
    add :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    add t a b = EOp ext (OAdd t) (EPair ext a b)

    mul :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
    mul t a b = EOp ext (OMul t) (EPair ext a b)

    neg :: ScalIsNumeric t ~ True
        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    neg t = EOp ext (ONeg t)

    recip' :: ScalIsFloating t ~ True
           => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
    recip' t = EOp ext (ORecip t)

    -- Let-bind a dual number and hand its (primal, tangent) projections to
    -- the continuation. The continuation is polymorphic in the environment
    -- because it runs under the new binding.
    unFloat :: DN a ~ TPair a a
            => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
            -> Ex env (DN a) -> Ex env (DN b)
    unFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext var, ESnd ext var)

    -- Same as 'unFloat', for a pair of dual numbers: the continuation gets
    -- the (primal, tangent) projections of both components.
    binFloat :: (a ~ TPair s s, DN s ~ TPair s s)
             => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
             -> Ex env (DN a) -> Ex env (DN b)
    binFloat f e =
      ELet ext e $
        let var = EVar ext (typeOf e) IZ
        in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
             (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))

-- | The constant zero of a numeric scalar type.
zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
zeroScalarConst STI32 = EConst ext STI32 0
zeroScalarConst STI64 = EConst ext STI64 0
zeroScalarConst STF32 = EConst ext STF32 0.0
zeroScalarConst STF64 = EConst ext STF64 0.0

-- | Dual-numbers forward transformation of an expression.
--
-- Most constructors are mapped structurally. Floating-point constants get a
-- zero tangent, primitive operations go through 'dop', and the reductions
-- ('ESum1Inner', 'EMaximum1Inner', 'EMinimum1Inner') get dedicated
-- translations when the element type is floating.
--
-- Calls 'error' on accumulator operations ('EWith', 'EAccum') and monoid
-- operations ('EZero', 'EPlus', 'EOneHot').
dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
dfwdDN = \case
  EVar _ t i -> EVar ext (dn t) (convIdx i)
  ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b)
  EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b)
  EFst _ e -> EFst ext (dfwdDN e)
  ESnd _ e -> ESnd ext (dfwdDN e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (dn t) (dfwdDN e)
  EInr _ t e -> EInr ext (dn t) (dfwdDN e)
  ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  ENothing _ t -> ENothing ext (dn t)
  EJust _ e -> EJust ext (dfwdDN e)
  EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
  -- Floating constant arrays: pair every element with a zero tangent.
  EConstArr _ n t x -> scalTyCase t
    (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
          (EConstArr ext n t x))
    (EConstArr ext n t x)
  EBuild _ n a b
    | Refl <- dnPreservesTupIx n
    -> EBuild ext n (dfwdDN a) (dfwdDN b)
  EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
  -- Floating case: reduce primals and tangents separately, then zip the
  -- two results back together.
  ESum1Inner _ e ->
    let STArr n (STScal t) = typeOf e
        pairty = (STPair (STScal t) (STScal t))
    in scalTyCase t
         (ELet ext (dfwdDN e) $
            ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ)))
                 (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
                                       (EVar ext (STArr n pairty) IZ))))
         (ESum1Inner ext (dfwdDN e))
  EUnit _ e -> EUnit ext (dfwdDN e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
  EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
  EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
  -- Floating constants get a zero tangent.
  EConst _ t x -> scalTyCase t
    (EPair ext (EConst ext t x) (EConst ext t 0.0))
    (EConst ext t x)
  EIdx0 _ e -> EIdx0 ext (dfwdDN e)
  EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
  EIdx _ a b
    | STArr n _ <- typeOf a
    , Refl <- dnPreservesTupIx n
    -> EIdx ext (dfwdDN a) (dfwdDN b)
  EShape _ e
    | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
    -> EShape ext (dfwdDN e)
  EOp _ op e -> dop op (dfwdDN e)
  -- Bind both transformed arguments, then place the transformed 'pr'
  -- component under those two bindings.
  ECustom _ _ _ _ pr _ _ e1 e2 ->
    ELet ext (dfwdDN e1) $
    ELet ext (weakenExpr WSink (dfwdDN e2)) $
      weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
  EError t s -> EError (dn t) s

  EWith{} -> err_accum
  EAccum{} -> err_accum
  EZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid
  where
    err_accum = error "Accumulator operations unsupported in the source program"
    err_monoid = error "Monoid operations unsupported in the source program"

    -- Shared translation for 'EMaximum1Inner' and 'EMinimum1Inner'; the
    -- reduction itself is passed in as @extremum@.
    --
    -- Floating case: compute the extremum of the primals, then build the
    -- tangent by summing, along the inner dimension, the tangents of the
    -- elements whose primal equals that extremum (zero for all others).
    --
    -- Non-floating case: there is no tangent, so the supplied reduction is
    -- applied directly. It has to be @extremum@ and not a fixed
    -- 'EMaximum1Inner', otherwise a minimum would turn into a maximum.
    deriv_extremum :: ScalIsNumeric t ~ True
                   => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
                   -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t)))
    deriv_extremum extremum e =
      let STArr (SS n) (STScal t) = typeOf e
          t2 = STPair (STScal t) (STScal t)
          ta2 = STArr (SS n) t2
          tIxN = tTup (sreplicate (SS n) tIx)
      in scalTyCase t
           (ELet ext (dfwdDN e) $
            ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $
              ezip (EVar ext (STArr n (STScal t)) IZ)
                   (ESum1Inner ext
                      {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -}
                      (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $
                       ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $
                         ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext
                                      (EFst ext (EVar ext t2 IZ))
                                      (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ)))
                                                (EFst ext (EVar ext tIxN (IS IZ)))))))
                           (ESnd ext (EVar ext t2 (IS IZ)))
                           (zeroScalarConst t))))
           (extremum (dfwdDN e))