diff options
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 231 |
1 files changed, 0 insertions, 231 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs deleted file mode 100644 index a1e9d0d..0000000 --- a/src/ForwardAD/DualNumbers.hs +++ /dev/null @@ -1,231 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD.DualNumbers ( - dfwdDN, - DN, DNS, DNE, dn, dne, -) where - -import AST -import Data -import ForwardAD.DualNumbers.Types - - -dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) -dnPreservesTupIx SZ = Refl -dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl - -convIdx :: Idx env t -> Idx (DNE env) (DN t) -convIdx IZ = IZ -convIdx (IS i) = IS (convIdx i) - -scalTyCase :: SScalTy t - -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) - -> (DN (TScal t) ~ TScal t => r) - -> r -scalTyCase STF32 k1 _ = k1 -scalTyCase STF64 k1 _ = k1 -scalTyCase STI32 _ k2 = k2 -scalTyCase STI64 _ k2 = k2 -scalTyCase STBool _ k2 = k2 - -floatingDual :: ScalIsFloating t ~ True - => SScalTy t - -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r -floatingDual STF32 k = k -floatingDual STF64 k = k - --- | Argument does not need to be duplicable. -dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) -dop = \case - OAdd t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) - (EOp ext (OAdd t)) - OMul t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) - (EOp ext (OMul t)) - ONeg t -> scalTyCase t - (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) - (EOp ext (ONeg t)) - OLt t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) - (EOp ext (OLt t)) - OLe t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) - (EOp ext (OLe t)) - OEq t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) - (EOp ext (OEq t)) - ONot -> EOp ext ONot - OAnd -> EOp ext OAnd - OOr -> EOp ext OOr - OIf -> EOp ext OIf - ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) - OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) - ORecip t -> floatingDual t $ unFloat (\(x, dx) -> - EPair ext (recip' t x) - (mul t (neg t (recip' t (mul t x x))) dx)) - OExp t -> floatingDual t $ unFloat (\(x, dx) -> - EPair ext (EOp ext (OExp t) x) (mul t (EOp ext (OExp t) x) dx)) - OLog t -> floatingDual t $ unFloat (\(x, dx) -> - EPair ext (EOp ext (OLog t) x) - (mul t (recip' t x) dx)) - OIDiv t -> scalTyCase t - (case t of {}) - (EOp ext (OIDiv t)) - OMod t -> scalTyCase t - (case t of {}) - (EOp ext (OMod t)) - where - add :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - add t a b = EOp ext (OAdd t) (EPair ext a b) - - mul :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - mul t a b = EOp ext (OMul t) (EPair ext a b) - - neg :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) - neg t = EOp ext (ONeg t) - - recip' :: ScalIsFloating t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) - recip' t = EOp ext (ORecip t) - - unFloat :: DN a ~ TPair a a - => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) - -> Ex env (DN a) -> Ex env (DN b) - unFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext var, ESnd ext var) - - binFloat :: (a ~ TPair s s, DN s ~ TPair s s) - => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b)) - -> Ex env (DN a) -> Ex env (DN b) - binFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) - (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) - -zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) -zeroScalarConst STI32 = EConst ext STI32 0 -zeroScalarConst STI64 = EConst ext STI64 0 -zeroScalarConst STF32 = EConst ext STF32 0.0 -zeroScalarConst STF64 = EConst ext STF64 0.0 - -dfwdDN :: Ex env t -> Ex (DNE env) (DN t) -dfwdDN = \case - EVar _ t i -> EVar ext (dn t) (convIdx i) - ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b) - EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b) - EFst _ e -> EFst ext (dfwdDN e) - ESnd _ e -> ESnd ext (dfwdDN e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext (dn t) (dfwdDN e) - EInr _ t e -> EInr ext (dn t) (dfwdDN e) - ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) - ENothing _ t -> ENothing ext (dn t) - EJust _ e -> EJust ext (dfwdDN e) - EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) - ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) - ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) - ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) - ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) - EConstArr _ n t x -> scalTyCase t - (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) - (EConstArr ext n t x)) - (EConstArr ext n t x) - EBuild _ n a b - | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) - EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) - EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) - ESum1Inner _ e -> - let STArr n (STScal t) = typeOf e - pairty = (STPair (STScal t) (STScal t)) - in scalTyCase t - (ELet ext (dfwdDN e) $ - ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ))) - (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ)))) - (ESum1Inner ext (dfwdDN e)) - EUnit _ e -> EUnit ext (dfwdDN e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e - EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) - EReshape _ n esh e - | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) - EConst _ t x -> scalTyCase t - (EPair ext (EConst ext t x) (EConst ext t 0.0)) - (EConst ext t x) - EIdx0 _ e -> EIdx0 ext (dfwdDN e) - EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) - EIdx _ a b - | STArr n _ <- typeOf a - , Refl <- dnPreservesTupIx n - -> EIdx ext (dfwdDN a) (dfwdDN b) - EShape _ e - | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) - -> EShape ext (dfwdDN e) - EOp _ op e -> dop op (dfwdDN e) - ECustom _ _ _ _ pr _ _ e1 e2 -> - ELet ext (dfwdDN e1) $ - ELet ext (weakenExpr WSink (dfwdDN e2)) $ - weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) - ERecompute _ e -> dfwdDN e - EError _ t s -> EError ext (dn t) s - - EWith{} -> err_accum - EAccum{} -> err_accum - EDeepZero{} -> err_monoid - EZero{} -> err_monoid - EPlus{} -> err_monoid - EOneHot{} -> err_monoid - - EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" - EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" - - deriv_extremum :: ScalIsNumeric t ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) - -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t))) - deriv_extremum extremum e = - let STArr (SS n) (STScal t) = typeOf e - t2 = STPair (STScal t) (STScal t) - ta2 = STArr (SS n) t2 - tIxN = tTup (sreplicate (SS n) tIx) - in scalTyCase t - (ELet ext (dfwdDN e) $ - ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $ - ezip (EVar ext (STArr n (STScal t)) IZ) - (ESum1Inner ext - {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -} - (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $ - ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $ - ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext - (EFst ext (EVar ext t2 IZ)) - (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ))) - (EFst ext (EVar ext tIxN (IS IZ))))))) - (ESnd ext (EVar ext t2 (IS IZ))) - (zeroScalarConst t)))) - (extremum (dfwdDN e)) |
