aboutsummaryrefslogtreecommitdiff
path: root/src/ForwardAD/DualNumbers.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/ForwardAD/DualNumbers.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/ForwardAD/DualNumbers.hs')
-rw-r--r--src/ForwardAD/DualNumbers.hs231
1 files changed, 0 insertions, 231 deletions
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
deleted file mode 100644
index a1e9d0d..0000000
--- a/src/ForwardAD/DualNumbers.hs
+++ /dev/null
@@ -1,231 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-
--- I want to bring various type variables in scope using type annotations in
--- patterns, but I don't want to have to mention all the other type parameters
--- of the types in question as well then. Partial type signatures (with '_') are
--- useful here.
-{-# LANGUAGE PartialTypeSignatures #-}
-{-# OPTIONS -Wno-partial-type-signatures #-}
-module ForwardAD.DualNumbers (
- dfwdDN,
- DN, DNS, DNE, dn, dne,
-) where
-
-import AST
-import Data
-import ForwardAD.DualNumbers.Types
-
-
-dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
-dnPreservesTupIx SZ = Refl
-dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl
-
-convIdx :: Idx env t -> Idx (DNE env) (DN t)
-convIdx IZ = IZ
-convIdx (IS i) = IS (convIdx i)
-
-scalTyCase :: SScalTy t
- -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
- -> (DN (TScal t) ~ TScal t => r)
- -> r
-scalTyCase STF32 k1 _ = k1
-scalTyCase STF64 k1 _ = k1
-scalTyCase STI32 _ k2 = k2
-scalTyCase STI64 _ k2 = k2
-scalTyCase STBool _ k2 = k2
-
-floatingDual :: ScalIsFloating t ~ True
- => SScalTy t
- -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r
-floatingDual STF32 k = k
-floatingDual STF64 k = k
-
--- | Argument does not need to be duplicable.
-dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
-dop = \case
- OAdd t -> scalTyCase t
- (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
- (EOp ext (OAdd t))
- OMul t -> scalTyCase t
- (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
- (EOp ext (OMul t))
- ONeg t -> scalTyCase t
- (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
- (EOp ext (ONeg t))
- OLt t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
- (EOp ext (OLt t))
- OLe t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
- (EOp ext (OLe t))
- OEq t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
- (EOp ext (OEq t))
- ONot -> EOp ext ONot
- OAnd -> EOp ext OAnd
- OOr -> EOp ext OOr
- OIf -> EOp ext OIf
- ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
- OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
- ORecip t -> floatingDual t $ unFloat (\(x, dx) ->
- EPair ext (recip' t x)
- (mul t (neg t (recip' t (mul t x x))) dx))
- OExp t -> floatingDual t $ unFloat (\(x, dx) ->
- EPair ext (EOp ext (OExp t) x) (mul t (EOp ext (OExp t) x) dx))
- OLog t -> floatingDual t $ unFloat (\(x, dx) ->
- EPair ext (EOp ext (OLog t) x)
- (mul t (recip' t x) dx))
- OIDiv t -> scalTyCase t
- (case t of {})
- (EOp ext (OIDiv t))
- OMod t -> scalTyCase t
- (case t of {})
- (EOp ext (OMod t))
- where
- add :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
- add t a b = EOp ext (OAdd t) (EPair ext a b)
-
- mul :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
- mul t a b = EOp ext (OMul t) (EPair ext a b)
-
- neg :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
- neg t = EOp ext (ONeg t)
-
- recip' :: ScalIsFloating t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
- recip' t = EOp ext (ORecip t)
-
- unFloat :: DN a ~ TPair a a
- => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
- -> Ex env (DN a) -> Ex env (DN b)
- unFloat f e =
- ELet ext e $
- let var = EVar ext (typeOf e) IZ
- in f (EFst ext var, ESnd ext var)
-
- binFloat :: (a ~ TPair s s, DN s ~ TPair s s)
- => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
- -> Ex env (DN a) -> Ex env (DN b)
- binFloat f e =
- ELet ext e $
- let var = EVar ext (typeOf e) IZ
- in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
- (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
-
-zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t)
-zeroScalarConst STI32 = EConst ext STI32 0
-zeroScalarConst STI64 = EConst ext STI64 0
-zeroScalarConst STF32 = EConst ext STF32 0.0
-zeroScalarConst STF64 = EConst ext STF64 0.0
-
-dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
-dfwdDN = \case
- EVar _ t i -> EVar ext (dn t) (convIdx i)
- ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b)
- EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b)
- EFst _ e -> EFst ext (dfwdDN e)
- ESnd _ e -> ESnd ext (dfwdDN e)
- ENil _ -> ENil ext
- EInl _ t e -> EInl ext (dn t) (dfwdDN e)
- EInr _ t e -> EInr ext (dn t) (dfwdDN e)
- ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
- ENothing _ t -> ENothing ext (dn t)
- EJust _ e -> EJust ext (dfwdDN e)
- EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
- ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2)
- ELInl _ t e -> ELInl ext (dn t) (dfwdDN e)
- ELInr _ t e -> ELInr ext (dn t) (dfwdDN e)
- ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c)
- EConstArr _ n t x -> scalTyCase t
- (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
- (EConstArr ext n t x))
- (EConstArr ext n t x)
- EBuild _ n a b
- | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
- EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b)
- EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c)
- ESum1Inner _ e ->
- let STArr n (STScal t) = typeOf e
- pairty = (STPair (STScal t) (STScal t))
- in scalTyCase t
- (ELet ext (dfwdDN e) $
- ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
- (EVar ext (STArr n pairty) IZ)))
- (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
- (EVar ext (STArr n pairty) IZ))))
- (ESum1Inner ext (dfwdDN e))
- EUnit _ e -> EUnit ext (dfwdDN e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
- EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
- EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
- EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b)
- EReshape _ n esh e
- | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e)
- EConst _ t x -> scalTyCase t
- (EPair ext (EConst ext t x) (EConst ext t 0.0))
- (EConst ext t x)
- EIdx0 _ e -> EIdx0 ext (dfwdDN e)
- EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
- EIdx _ a b
- | STArr n _ <- typeOf a
- , Refl <- dnPreservesTupIx n
- -> EIdx ext (dfwdDN a) (dfwdDN b)
- EShape _ e
- | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n)
- -> EShape ext (dfwdDN e)
- EOp _ op e -> dop op (dfwdDN e)
- ECustom _ _ _ _ pr _ _ e1 e2 ->
- ELet ext (dfwdDN e1) $
- ELet ext (weakenExpr WSink (dfwdDN e2)) $
- weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr)
- ERecompute _ e -> dfwdDN e
- EError _ t s -> EError ext (dn t) s
-
- EWith{} -> err_accum
- EAccum{} -> err_accum
- EDeepZero{} -> err_monoid
- EZero{} -> err_monoid
- EPlus{} -> err_monoid
- EOneHot{} -> err_monoid
-
- EFold1InnerD1{} -> err_targetlang "EFold1InnerD1"
- EFold1InnerD2{} -> err_targetlang "EFold1InnerD2"
- where
- err_accum = error "Accumulator operations unsupported in the source program"
- err_monoid = error "Monoid operations unsupported in the source program"
- err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program"
-
- deriv_extremum :: ScalIsNumeric t ~ True
- => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t)))
- -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t)))
- deriv_extremum extremum e =
- let STArr (SS n) (STScal t) = typeOf e
- t2 = STPair (STScal t) (STScal t)
- ta2 = STArr (SS n) t2
- tIxN = tTup (sreplicate (SS n) tIx)
- in scalTyCase t
- (ELet ext (dfwdDN e) $
- ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $
- ezip (EVar ext (STArr n (STScal t)) IZ)
- (ESum1Inner ext
- {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -}
- (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $
- ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $
- ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext
- (EFst ext (EVar ext t2 IZ))
- (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ)))
- (EFst ext (EVar ext tIxN (IS IZ)))))))
- (ESnd ext (EVar ext t2 (IS IZ)))
- (zeroScalarConst t))))
- (extremum (dfwdDN e))