diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:57 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-14 19:27:57 +0100 | 
| commit | b8c162ce9cb1faeec621b751fff9aff46e022417 (patch) | |
| tree | 9c31700f34f9a1f1a67e0a73c880938130e87ee6 /src | |
| parent | bb84f6930702a02ba982795e2bb95a64d61f672b (diff) | |
Configuration for CHAD
Diffstat (limited to 'src')
| -rw-r--r-- | src/CHAD.hs | 65 | ||||
| -rw-r--r-- | src/CHAD/Top.hs | 13 | ||||
| -rw-r--r-- | src/Example.hs | 2 | 
3 files changed, 72 insertions, 8 deletions
| diff --git a/src/CHAD.hs b/src/CHAD.hs index 45fcc82..b35836a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -2,6 +2,7 @@  {-# LANGUAGE EmptyCase #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-}  {-# LANGUAGE OverloadedLabels #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -23,6 +24,8 @@  module CHAD (    drev,    freezeRet, +  CHADConfig(..), +  defaultConfig,    Storage(..),    Descr(..),    Select, @@ -724,10 +727,27 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =            expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) +--------------------------------- CONFIGURATION -------------------------------- + +data CHADConfig = CHADConfig +  { -- | D[let] will bind variables containing arrays in accumulator mode. +    chcLetArrayAccum :: Bool +  , -- | D[case] will bind variables containing arrays in accumulator mode. +    chcCaseArrayAccum :: Bool +  } + +defaultConfig :: CHADConfig +defaultConfig = CHADConfig +  { chcLetArrayAccum = False +  , chcCaseArrayAccum = False +  } + +  ---------------------------- THE CHAD TRANSFORMATION ---------------------------  drev :: forall env sto t. -        Descr env sto +        (?config :: CHADConfig) +     => Descr env sto       -> Ex env t -> Ret env sto t  drev des = \case    EVar _ t i -> @@ -753,6 +773,38 @@ drev des = \case              (subenvNone (select SMerge des))              (ENil ext) +  ELet _ (rhs :: Ex _ a) body +    | chcLetArrayAccum ?config && hasArrays (typeOf rhs) +    , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 +        <- drev des rhs +    , Ret (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 +        <- drev (des `DPush` (typeOf rhs, SAccum)) body +    , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 +    , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) +    , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> +    subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> +    let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in +    Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') +        (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) +        (weakenExpr wbody0' body1) +        subBoth +        (ELet ext +           (EWith (EZero (typeOf rhs)) $ +              weakenExpr (autoWeak (#d (auto1 @(D2 t)) +                                    &. #body (subList (bindingsBinds body0) subtapeBody) +                                    &. #ac (auto1 @(TAccum (D2 a))) +                                    &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) +                                    &. #tl (d2ace (select SAccum des))) +                                   (#d :++: #body :++: #ac :++: #tl) +                                   (#ac :++: #d :++: (#body :++: #rhs) :++: #tl)) +                         body2) $ +         ELet ext +           (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ +             weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ +         plus_RHS_Body +           (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) +           (EFst ext (EVar ext bodyResType (IS IZ)))) +    ELet _ rhs body      | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2          <- drev des rhs @@ -848,6 +900,8 @@ drev des = \case                (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")                (weakenExpr (WCopy (wSinks' @[_,_])) e2))) +  ECase{} | chcCaseArrayAccum ?config -> error "chcCaseArrayAccum unsupported" +    ECase _ e (a :: Ex _ t) b      | STEither t1 t2 <- typeOf e      , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des e @@ -1187,3 +1241,12 @@ drev des = \case                           (EIdx ext (EVar ext (d2 at') (IS (IS IZ))) (EFst ext (EVar ext tIxN (IS IZ))))                           (EZero t)) $              weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + +    hasArrays :: STy t' -> Bool +    hasArrays STNil = False +    hasArrays (STPair a b) = hasArrays a || hasArrays b +    hasArrays (STEither a b) = hasArrays a || hasArrays b +    hasArrays (STMaybe t) = hasArrays t +    hasArrays STArr{} = True +    hasArrays STScal{} = False +    hasArrays STAccum{} = error "Accumulators not allowed in source program" diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 9df5412..4b2a844 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -1,5 +1,6 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} @@ -41,13 +42,13 @@ d1eIdentity :: SList STy env -> D1E env :~: env  d1eIdentity SNil = Refl  d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl -chad :: SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) -chad env term +chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) +chad config env term    | Refl <- mergeEnvNoAccum env    , Refl <- mergeEnvOnlyMerge env -  = freezeRet (mergeDescr env) (drev (mergeDescr env) term) +  = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) term) -chad' :: SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) -chad' env term +chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +chad' config env term    | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term) -  = chad env term +  = chad config env term diff --git a/src/Example.hs b/src/Example.hs index 390031e..94a6934 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -182,7 +182,7 @@ neuralGo =        revderiv =          simplifyN 20 $            ELet ext (EConst ext STF64 1.0) $ -            chad knownEnv neural +            chad defaultConfig knownEnv neural        (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False argument revderiv of          (primal', (((((), Right dlay1_1'), Right dlay2_1'), dlay3_1'), dinput_1')) -> (primal', dlay1_1', dlay2_1', dlay3_1', dinput_1')          _ -> undefined | 
