diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-08 12:37:51 +0100 |
commit | 83692cf41f76272423445c9cbbad65561ee3b50c (patch) | |
tree | 49f56f498a68722a7302b4ce0b41402a9b9da9ef | |
parent | 58d4d0b47f5e609e21132f48b727de37d06b6777 (diff) |
WIP custom derivatives
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST.hs | 12 | ||||
-rw-r--r-- | src/AST/Count.hs | 1 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 17 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 6 | ||||
-rw-r--r-- | src/AST/Weaken/Auto.hs | 7 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 48 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers/Types.hs | 46 | ||||
-rw-r--r-- | src/Simplify.hs | 6 |
9 files changed, 96 insertions, 48 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 8ff3a21..94d7423 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -28,6 +28,7 @@ library Example.Format ForwardAD ForwardAD.DualNumbers + ForwardAD.DualNumbers.Types Interpreter -- Interpreter.AccumOld Interpreter.Rep @@ -26,6 +26,7 @@ import AST.Types import AST.Weaken import CHAD.Types import Data +import ForwardAD.DualNumbers.Types -- | This index is flipped around from the usual direction: the smallest index @@ -96,6 +97,14 @@ data Expr x env t where EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + -- custom derivatives + ECustom :: x t -> STy a -> STy b + -> Expr x '[b, a] t -- ^ regular operation + -> Expr x '[DN b, a] (DN t) -- ^ dual-numbers forward derivative + -> Expr x '[D2 t, D1 b, D1 a] (D2 b) -- ^ CHAD reverse derivative + -> Expr x env a -> Expr x env b + -> Expr x env t + -- accumulation effect EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil @@ -202,6 +211,8 @@ typeOf = \case EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) EOp _ op _ -> opt2 op + ECustom _ _ _ e _ _ _ _ -> typeOf e + EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ -> STNil @@ -274,6 +285,7 @@ subst' f w = \case EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) + ECustom x s t a b c e1 e2 -> ECustom x s t a b c (subst' f w e1) (subst' f w e2) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) EZero t -> EZero t diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 71b38b1..d365218 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -125,6 +125,7 @@ occCountGeneral onehot unpush alter many = go WId EIdx _ a b -> re a <> re b EShape _ e -> re e EOp _ _ e -> re e + ECustom _ _ _ _ _ _ a b -> re a <> re b EWith a b -> re a <> re1 b EAccum _ a b e -> re a <> re b <> re e EZero _ -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 76424fe..bac267d 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -14,7 +14,9 @@ import Data.Functor.Const import AST import AST.Count +import CHAD.Types import Data +import ForwardAD.DualNumbers.Types type SVal = SList (Const String) @@ -175,6 +177,21 @@ ppExpr' d val = \case (Prefix, s) -> s return $ showParen (d > 10) $ showString (ops ++ " ") . e' + ECustom _ t1 t2 a b c e1 e2 -> do + pn1 <- genNameIfUsedIn t1 (IS IZ) a + pn2 <- genNameIfUsedIn t2 IZ a + fn1 <- genNameIfUsedIn t1 (IS IZ) b + fn2 <- genNameIfUsedIn (dn t2) IZ b + rn1 <- genNameIfUsedIn (d1 t1) (IS (IS IZ)) c + rn2 <- genNameIfUsedIn (d1 t2) (IS IZ) c + rn3 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c + a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a + b' <- ppExpr' 11 (Const fn2 `SCons` Const fn1 `SCons` SNil) b + c' <- ppExpr' 11 (Const rn3 `SCons` Const rn2 `SCons` Const rn1 `SCons` SNil) c + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ showParen (d > 10) $ showString "custom " . a' . showString " " . b' . showString " " . c' . showString " " . e1' . showString " " . e2' + EWith e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index ecd7bc9..dbb37f7 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -42,7 +42,7 @@ data env :> env' where WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env') WPop :: (t : env) :> env' -> env :> env' WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 - WClosed :: SList (Const ()) env -> '[] :> env + WClosed :: '[] :> env WIdx :: Idx env t -> (t : env) :> env WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env' -> Append pre (t : env) :> t : Append pre env' @@ -62,7 +62,7 @@ WCopy _ @> IZ = IZ WCopy w @> IS i = IS (w @> i) WPop w @> i = w @> IS i WThen w1 w2 @> i = w2 @> w1 @> i -WClosed _ @> i = case i of {} +WClosed @> i = case i of {} WIdx j @> IZ = j WIdx _ @> IS i = i WPick SNil w @> i = WCopy w @> i @@ -115,5 +115,5 @@ wCopies bs w = in WStack bs' bs' WId w wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env -wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) +wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 8555516..6752c24 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -118,11 +118,6 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -linLayoutEnv :: SSegments segments -> LinLayout withPre segments env -> SList (Const ()) env -linLayoutEnv _ LinEnd = SNil -linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) -linLayoutEnv segs (LinAppPreW name1 _ _ lin) = sappend (segmentLookup segs name1) (linLayoutEnv segs lin) - lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env lineariseLayout (LSeg name :: Layout _ _ seg) | Refl <- lemAppendNil @seg @@ -171,7 +166,7 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and -- wCopies the name2 on top of that. wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) -sortLinLayouts segs LinEnd lin2@LinApp{} = WClosed (linLayoutEnv segs lin2) +sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" autoWeak :: forall segments env1 env2. diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 8e84378..3587378 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -3,6 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -19,46 +20,9 @@ module ForwardAD.DualNumbers ( import AST import Data +import ForwardAD.DualNumbers.Types --- | Dual-numbers transformation -type family DN t where - DN TNil = TNil - DN (TPair a b) = TPair (DN a) (DN b) - DN (TEither a b) = TEither (DN a) (DN b) - DN (TMaybe t) = TMaybe (DN t) - DN (TArr n t) = TArr n (DN t) - DN (TScal t) = DNS t - -type family DNS t where - DNS TF32 = TPair (TScal TF32) (TScal TF32) - DNS TF64 = TPair (TScal TF64) (TScal TF64) - DNS TI32 = TScal TI32 - DNS TI64 = TScal TI64 - DNS TBool = TScal TBool - -type family DNE env where - DNE '[] = '[] - DNE (t : ts) = DN t : DNE ts - -dn :: STy t -> STy (DN t) -dn STNil = STNil -dn (STPair a b) = STPair (dn a) (dn b) -dn (STEither a b) = STEither (dn a) (dn b) -dn (STMaybe t) = STMaybe (dn t) -dn (STArr n t) = STArr n (dn t) -dn (STScal t) = case t of - STF32 -> STPair (STScal STF32) (STScal STF32) - STF64 -> STPair (STScal STF64) (STScal STF64) - STI32 -> STScal STI32 - STI64 -> STScal STI64 - STBool -> STScal STBool -dn STAccum{} = error "Accum in source program" - -dne :: SList STy env -> SList STy (DNE env) -dne SNil = SNil -dne (t `SCons` env) = dn t `SCons` dne env - dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) dnPreservesTupIx SZ = Refl dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl @@ -177,8 +141,14 @@ dfwdDN = \case , Refl <- dnPreservesTupIx n -> EIdx ext (dfwdDN a) (dfwdDN b) EShape _ e - | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) + | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) + -> EShape ext (dfwdDN e) EOp _ op e -> dop op (dfwdDN e) + ECustom _ s t _ du _ e1 e2 -> + -- TODO: we need a bit of codegen here that projects the primals out from the dual number result of e1. Note that a non-differentiating code transformation does not eliminate the need for this, because then the need just shifts to free variable adaptor code. + ELet ext (_ e1) $ + ELet ext (weakenExpr WSink (dfwdDN e2)) $ + weakenExpr (WCopy (WCopy WClosed)) du EError t s -> EError (dn t) s EWith{} -> err_accum diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs new file mode 100644 index 0000000..fba92d0 --- /dev/null +++ b/src/ForwardAD/DualNumbers/Types.hs @@ -0,0 +1,46 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module ForwardAD.DualNumbers.Types where + +import AST.Types +import Data + + +-- | Dual-numbers transformation +type family DN t where + DN TNil = TNil + DN (TPair a b) = TPair (DN a) (DN b) + DN (TEither a b) = TEither (DN a) (DN b) + DN (TMaybe t) = TMaybe (DN t) + DN (TArr n t) = TArr n (DN t) + DN (TScal t) = DNS t + +type family DNS t where + DNS TF32 = TPair (TScal TF32) (TScal TF32) + DNS TF64 = TPair (TScal TF64) (TScal TF64) + DNS TI32 = TScal TI32 + DNS TI64 = TScal TI64 + DNS TBool = TScal TBool + +type family DNE env where + DNE '[] = '[] + DNE (t : ts) = DN t : DNE ts + +dn :: STy t -> STy (DN t) +dn STNil = STNil +dn (STPair a b) = STPair (dn a) (dn b) +dn (STEither a b) = STEither (dn a) (dn b) +dn (STMaybe t) = STMaybe (dn t) +dn (STArr n t) = STArr n (dn t) +dn (STScal t) = case t of + STF32 -> STPair (STScal STF32) (STScal STF32) + STF64 -> STPair (STScal STF64) (STScal STF64) + STI32 -> STScal STI32 + STI64 -> STScal STI64 + STBool -> STScal STBool +dn STAccum{} = error "Accum in source program" + +dne :: SList STy env -> SList STy (DNE env) +dne SNil = SNil +dne (t `SCons` env) = dn t `SCons` dne env diff --git a/src/Simplify.hs b/src/Simplify.hs index 66a4004..d3ee03c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -121,6 +121,11 @@ simplify' = \case EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b EShape _ e -> EShape ext <$> simplify' e EOp _ op e -> EOp ext op <$> simplify' e + ECustom _ s t a b c e1 e2 -> + ECustom ext s t <$> (let ?accumInScope = False in simplify' a) + <*> (let ?accumInScope = False in simplify' b) + <*> (let ?accumInScope = False in simplify' c) + <*> simplify' e1 <*> simplify' e2 EWith e1 e2 -> EWith <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) EAccum i e1 e2 e3 -> EAccum i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 EZero t -> pure $ EZero t @@ -160,6 +165,7 @@ hasAdds = \case ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e EReplicate1Inner _ a b -> hasAdds a || hasAdds b + ECustom _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b |