diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-10-01 23:18:15 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-10-01 23:18:15 +0200 | 
| commit | 948cae3ca7279040627db393e4372a668f8a22f7 (patch) | |
| tree | 89eae02aeba1f0bdc30a938c82dc3dfef06cd4af | |
| parent | 1f13bc80915a26473e0622c4afa65c8276b396ff (diff) | |
Reverse-by-forward, and checking neural (it's wrong)
| -rw-r--r-- | chad-fast.cabal | 4 | ||||
| -rw-r--r-- | src/AST/Bindings.hs | 64 | ||||
| -rw-r--r-- | src/Array.hs | 3 | ||||
| -rw-r--r-- | src/CHAD.hs | 61 | ||||
| -rw-r--r-- | src/Example.hs | 32 | ||||
| -rw-r--r-- | src/Example/Format.hs | 9 | ||||
| -rw-r--r-- | src/ForwardAD.hs | 347 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 206 | ||||
| -rw-r--r-- | src/Interpreter.hs | 16 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 3 | 
10 files changed, 485 insertions, 260 deletions
| diff --git a/chad-fast.cabal b/chad-fast.cabal index 6873efd..ef9f642 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -12,6 +12,7 @@ library    exposed-modules:      Array      AST +    AST.Bindings      AST.Count      AST.Env      AST.Pretty @@ -23,7 +24,9 @@ library      -- Compile      Data      Example +    Example.Format      ForwardAD +    ForwardAD.DualNumbers      Interpreter      -- Interpreter.AccumOld      Interpreter.Rep @@ -37,6 +40,7 @@ library      base >= 4.19 && < 4.21,      containers,      -- template-haskell, +    process,      transformers,      vector,    hs-source-dirs: src diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs new file mode 100644 index 0000000..2e63b42 --- /dev/null +++ b/src/AST/Bindings.hs @@ -0,0 +1,64 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module AST.Bindings where + +import AST +import Data +import Lemmas + + +-- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. +data Bindings f env binds where +  BTop :: Bindings f env '[] +  BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds) +deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') +infixl `BPush` + +mapBindings :: (forall env' t'. f env' t' -> g env' t') +            -> Bindings f env binds -> Bindings g env binds +mapBindings _ BTop = BTop +mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) + +weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) +               -> env1 :> env2 +               -> Bindings f env1 binds +               -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) +weakenBindings _ w BTop = (BTop, w) +weakenBindings wf w (BPush b (t, x)) = +  let (b', w') = weakenBindings wf w b +  in (BPush b' (t, wf w' x), WCopy w') + +weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' +weakenOver SNil w = w +weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) + +sinkWithBindings :: Bindings f env binds -> env' :> Append binds env' +sinkWithBindings BTop = WId +sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b + +bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1) +bconcat b1 BTop = b1 +bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x)) +  | Refl <- lemAppendAssoc @binds2C @binds1 @env +  = BPush (bconcat b1 b2) (t, x) + +bindingsBinds :: Bindings f env binds -> SList STy binds +bindingsBinds BTop = SNil +bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) + +letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t +letBinds BTop = id +letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs diff --git a/src/Array.hs b/src/Array.hs index c939419..8507544 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -45,6 +45,9 @@ emptyShape :: SNat n -> Shape n  emptyShape SZ = ShNil  emptyShape (SS m) = emptyShape m `ShCons` 0 +enumShape :: Shape n -> [Index n] +enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] +  -- | TODO: this Vector is a boxed vector, which is horrendously inefficient.  data Array (n :: Nat) t = Array (Shape n) (Vector t) diff --git a/src/CHAD.hs b/src/CHAD.hs index 12d28e2..bcc1485 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -34,6 +34,7 @@ import GHC.Stack (HasCallStack)  import GHC.TypeLits (Symbol)  import AST +import AST.Bindings  import AST.Count  import AST.Env  import AST.Weaken.Auto @@ -42,66 +43,10 @@ import Data  import Lemmas --- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. -data Bindings f env binds where -  BTop :: Bindings f env '[] -  BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds) -deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') -infixl `BPush` - -mapBindings :: (forall env' t'. f env' t' -> g env' t') -            -> Bindings f env binds -> Bindings g env binds -mapBindings _ BTop = BTop -mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) - -weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) -               -> env1 :> env2 -               -> Bindings f env1 binds -               -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) -weakenBindings _ w BTop = (BTop, w) -weakenBindings wf w (BPush b (t, x)) = -  let (b', w') = weakenBindings wf w b -  in (BPush b' (t, wf w' x), WCopy w') - -weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' -weakenOver SNil w = w -weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) - -sinkWithBindings :: Bindings f env binds -> env' :> Append binds env' -sinkWithBindings BTop = WId -sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b - -bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1) -bconcat b1 BTop = b1 -bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x)) -  | Refl <- lemAppendAssoc @binds2C @binds1 @env -  = BPush (bconcat b1 b2) (t, x) - --- bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) ---          -> Bindings f env env1 -> Bindings f env env2 ---          -> (forall env12. Bindings f env env12 -> r) -> r --- bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2') - --- bsnoc :: (forall env1 env2 t'. env1 :> env2 -> f env1 t' -> f env2 t') ---       -> STy t -> f env t -> Bindings f env binds -> (Bindings f env (Snoc binds t), Append binds env :> Append (Snoc binds t) env) --- bsnoc _ t x BTop = (BPush BTop (t, x), WSink) --- bsnoc wf t x (BPush b (t', y)) = ---   let (b', w) = bsnoc wf t x b ---   in (BPush b' (t', wf w y), WCopy w) -  type family Tape binds where    Tape '[] = TNil    Tape (t : ts) = TPair t (Tape ts) --- data TupBinds f env binds = ---   TupBinds (SList STy binds) ---            (forall env2. Append binds env :> env2 -> Ex env2 (Tape binds)) ---            (forall env2. Idx env2 (Tape binds) -> Bindings f env2 binds) - -bindingsBinds :: Bindings f env binds -> SList STy binds -bindingsBinds BTop = SNil -bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) -  tapeTy :: SList STy binds -> STy (Tape binds)  tapeTy SNil = STNil  tapeTy (SCons t ts) = STPair t (tapeTy ts) @@ -240,10 +185,6 @@ reconstructBindings binds tape =               (bconcat (mapBindings fromUnfExpr unf) build)       ,sreverse (stapeUnfoldings binds)) -letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t -letBinds BTop = id -letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs -  type family Vectorise n list where    Vectorise _ '[] = '[]    Vectorise n (t : ts) = TArr n t : Vectorise n ts diff --git a/src/Example.hs b/src/Example.hs index e2f1be9..6701e38 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -11,10 +11,14 @@ import AST  import AST.Pretty  import CHAD  import Data +import ForwardAD  import Interpreter  import Language  import Simplify +import Debug.Trace +import Example.Format +  -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) @@ -175,18 +179,26 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #       let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $       #x3 ! nil -neuralGo :: (Float -            ,(((((), Either () (Array N2 Float, Array N1 Float)) -               ,Either () (Array N2 Float, Array N1 Float)) -              ,Array N1 Float) -             ,Array N1 Float)) +type NeuralGrad = ((Array N2 Float, Array N1 Float) +                  ,(Array N2 Float, Array N1 Float) +                  ,Array N1 Float +                  ,Array N1 Float) + +neuralGo :: (Float  -- primal +            ,NeuralGrad  -- gradient using CHAD +            ,NeuralGrad)  -- gradient using dual-numbers forward AD  neuralGo =    let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])        lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])        lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]        input = arrayFromList (ShNil `ShCons` 2) [1,1] -  in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $ -     simplifyN 20 $ -     freezeRet mergeDescr -       (drev mergeDescr neural) -       (EConst ext STF32 1.0) +      argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) +      revderiv = +        simplifyN 20 $ +        freezeRet mergeDescr +          (drev mergeDescr neural) +          (EConst ext STF32 1.0) +      (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv +      (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 +  in trace (formatter (ppExpr knownEnv revderiv)) $ +     (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) diff --git a/src/Example/Format.hs b/src/Example/Format.hs new file mode 100644 index 0000000..994f431 --- /dev/null +++ b/src/Example/Format.hs @@ -0,0 +1,9 @@ +module Example.Format where + +import System.IO.Unsafe +import System.Process + + +{-# NOINLINE formatter #-} +formatter :: String -> String +formatter str = unsafePerformIO $ readProcess "hindent" ["--line-length", "100"] str diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 0a9e12c..63244a8 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -1,202 +1,189 @@  {-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} +module ForwardAD where --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD ( -  dfwd, -  FD, FDS, FDE, fd, -) where +import Data.Bifunctor (bimap) +-- import Data.Foldable (toList) +import Array  import AST +-- import AST.Bindings  import Data +import ForwardAD.DualNumbers +import Interpreter +import Interpreter.Rep --- | Dual-numbers transformation -type family FD t where -  FD TNil = TNil -  FD (TPair a b) = TPair (FD a) (FD b) -  FD (TEither a b) = TEither (FD a) (FD b) -  FD (TMaybe t) = TMaybe (FD t) -  FD (TArr n t) = TArr n (FD t) -  FD (TScal t) = FDS t +-- | Tangent along a type (coincides with cotangent for these types) +type family Tan t where +  Tan TNil = TNil +  Tan (TPair a b) = TPair (Tan a) (Tan b) +  Tan (TEither a b) = TEither (Tan a) (Tan b) +  Tan (TMaybe t) = TMaybe (Tan t) +  Tan (TArr n t) = TArr n (Tan t) +  Tan (TScal t) = TanS t -type family FDS t where -  FDS TF32 = TPair (TScal TF32) (TScal TF32) -  FDS TF64 = TPair (TScal TF64) (TScal TF64) -  FDS TI32 = TScal TI32 -  FDS TI64 = TScal TI64 -  FDS TBool = TScal TBool +type family TanS t where +  TanS TI32 = TNil +  TanS TI64 = TNil +  TanS TF32 = TScal TF32 +  TanS TF64 = TScal TF64 +  TanS TBool = TNil -type family FDE env where -  FDE '[] = '[] -  FDE (t : ts) = FD t : FDE ts +type family TanE env where +  TanE '[] = '[] +  TanE (t : env) = Tan t : TanE env -fd :: STy t -> STy (FD t) -fd STNil = STNil -fd (STPair a b) = STPair (fd a) (fd b) -fd (STEither a b) = STEither (fd a) (fd b) -fd (STMaybe t) = STMaybe (fd t) -fd (STArr n t) = STArr n (fd t) -fd (STScal t) = case t of -  STF32 -> STPair (STScal STF32) (STScal STF32) -  STF64 -> STPair (STScal STF64) (STScal STF64) -  STI32 -> STScal STI32 -  STI64 -> STScal STI64 -  STBool -> STScal STBool -fd STAccum{} = error "Accum in source program" +tanty :: STy t -> STy (Tan t) +tanty STNil = STNil +tanty (STPair a b) = STPair (tanty a) (tanty b) +tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STMaybe t) = STMaybe (tanty t) +tanty (STArr n t) = STArr n (tanty t) +tanty (STScal t) = case t of +  STI32 -> STNil +  STI64 -> STNil +  STF32 -> STScal STF32 +  STF64 -> STScal STF64 +  STBool -> STNil +tanty STAccum{} = error "Accumulators not allowed in input program" -fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) -fdPreservesTupIx SZ = Refl -fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl +unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) +unzipDN STNil _ = ((), ()) +unzipDN (STPair a b) (d1, d2) = +  let (x, dx) = unzipDN a d1 +      (y, dy) = unzipDN b d2 +  in ((x, y), (dx, dy)) +unzipDN (STEither a b) d = case d of +  Left d1 -> bimap Left Left (unzipDN a d1) +  Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STMaybe t) d = case d of +  Nothing -> (Nothing, Nothing) +  Just d' -> bimap Just Just (unzipDN t d') +unzipDN (STArr _ t) d =  +  let pairs = arrayMap (unzipDN t) d +  in (arrayMap fst pairs, arrayMap snd pairs) +unzipDN (STScal ty) d = case ty of +  STI32 -> (d, ()) +  STI64 -> (d, ()) +  STF32 -> d +  STF64 -> d +  STBool -> (d, ()) +unzipDN STAccum{} _ = error "Accumulators not allowed in input program" -convIdx :: Idx env t -> Idx (FDE env) (FD t) -convIdx IZ = IZ -convIdx (IS i) = IS (convIdx i) +dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double +dotprodTan STNil _ _ = 0.0 +dotprodTan (STPair a b) (x, y) (x', y') = +  dotprodTan a x x' + dotprodTan b y y' +dotprodTan (STEither a b) x y = case (x, y) of +  (Left x', Left y') -> dotprodTan a x' y' +  (Right x', Right y') -> dotprodTan b x' y' +  _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STMaybe t) x y = case (x, y) of +  (Nothing, Nothing) -> 0.0 +  (Just x', Just y') -> dotprodTan t x' y' +  _ -> error "dotprodTan: incompatible Maybe alternatives" +dotprodTan (STArr _ t) x y =  +  let sh1 = arrayShape x +      sh2 = arrayShape y +  in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0 +        | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1] +        | otherwise -> error "dotprodTan: incompatible array shapes" +dotprodTan (STScal ty) x y = case ty of +  STI32 -> 0.0 +  STI64 -> 0.0 +  STF32 -> realToFrac @Float @Double (x * y) +  STF64 -> x * y +  STBool -> 0.0 +dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" -scalTyCase :: SScalTy t -           -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r) -           -> (FD (TScal t) ~ TScal t => r) -           -> r -scalTyCase STF32  k1 _ = k1 -scalTyCase STF64  k1 _ = k1 -scalTyCase STI32  _ k2 = k2 -scalTyCase STI64  _ k2 = k2 -scalTyCase STBool _ k2 = k2 +-- -- Primal expression must be duplicable +-- dnConstE :: STy t -> Ex env t -> Ex env (DN t) +-- dnConstE STNil _ = ENil ext +-- dnConstE (STPair t1 t2) e = +--   -- This creates fst/snd stacks of unbounded size, but let's not care here +--   EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e)) +-- dnConstE (STEither t1 t2) e = +--   ECase ext e +--     (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ))) +--     (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ))) +-- dnConstE (STMaybe t) e = +--   EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e +-- dnConstE (STArr n t) e = +--   EBuild ext n (EShape ext e) +--     (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ))) +-- dnConstE (STScal t) e = case t of +--   STI32 -> e +--   STI64 -> e +--   STF32 -> EPair ext e (EConst ext STF32 0.0) +--   STF64 -> EPair ext e (EConst ext STF64 0.0) +--   STBool -> e +-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program" --- | Argument does not need to be duplicable. -dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b) -dop = \case -  OAdd t -> scalTyCase t -    (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) -    (EOp ext (OAdd t)) -  OMul t -> scalTyCase t -    (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) -    (EOp ext (OMul t)) -  ONeg t -> scalTyCase t -    (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) -    (EOp ext (ONeg t)) -  OLt t -> scalTyCase t -    (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) -    (EOp ext (OLt t)) -  OLe t -> scalTyCase t -    (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) -    (EOp ext (OLe t)) -  OEq t -> scalTyCase t -    (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) -    (EOp ext (OEq t)) -  ONot -> EOp ext ONot -  OIf -> EOp ext OIf -  where -    add :: ScalIsNumeric t ~ True -        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) -    add t a b = EOp ext (OAdd t) (EPair ext a b) +dnConst :: STy t -> Rep t -> Rep (DN t) +dnConst STNil = const () +dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STMaybe t) = fmap (dnConst t) +dnConst (STArr _ t) = arrayMap (dnConst t) +dnConst (STScal t) = case t of +  STI32 -> id +  STI64 -> id +  STF32 -> (,0.0) +  STF64 -> (,0.0) +  STBool -> id +dnConst STAccum{} = error "Accumulators not allowed in input program" -    mul :: ScalIsNumeric t ~ True -        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) -    mul t a b = EOp ext (OMul t) (EPair ext a b) +-- | Given a function that computes the forward derivative for a particular +-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this +-- @t@ input. +type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t) -    neg :: ScalIsNumeric t ~ True -        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -    neg t = EOp ext (ONeg t) +dnOnehots :: STy t -> Rep t -> RevByFwd t +dnOnehots STNil _ = \_ -> () +dnOnehots (STPair t1 t2) (x, y) = +  \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,))) +dnOnehots (STEither t1 t2) e = +  case e of +    Left x -> \f -> Left (dnOnehots t1 x (f . Left)) +    Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STMaybe t) m = +  case m of +    Nothing -> \_ -> Nothing +    Just x -> \f -> Just (dnOnehots t x (f . Just)) +dnOnehots (STArr _ t) a = +  \f -> +    arrayGenerate (arrayShape a) $ \idx -> +      dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i -> +                                                    if i == idx then oh else dnConst t (arrayIndex a i))) +dnOnehots (STScal t) x = case t of +  STI32 -> \_ -> () +  STI64 -> \_ -> () +  STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0) +  STF64 -> \f -> f (x, 1.0) +  STBool -> \_ -> () +dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" -    unFloat :: FD a ~ TPair a a -            => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b)) -            -> Ex env (FD a) -> Ex env (FD b) -    unFloat f e = -      ELet ext e $ -        let var = EVar ext (typeOf e) IZ -        in f (EFst ext var, ESnd ext var) +dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) +dnConstEnv SNil SNil = SNil +dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val -    binFloat :: (a ~ TPair s s, FD s ~ TPair s s) -             => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b)) -             -> Ex env (FD a) -> Ex env (FD b) -    binFloat f e = -      ELet ext e $ -        let var = EVar ext (typeOf e) IZ -        in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) -             (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) +type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env) -dfwd :: Ex env t -> Ex (FDE env) (FD t) -dfwd = \case -  EVar _ t i -> EVar ext (fd t) (convIdx i) -  ELet _ a b -> ELet ext (dfwd a) (dfwd b) -  EPair _ a b -> EPair ext (dfwd a) (dfwd b) -  EFst _ e -> EFst ext (dfwd e) -  ESnd _ e -> ESnd ext (dfwd e) -  ENil _ -> ENil ext -  EInl _ t e -> EInl ext (fd t) (dfwd e) -  EInr _ t e -> EInr ext (fd t) (dfwd e) -  ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b) -  ENothing _ t -> ENothing ext (fd t) -  EJust _ e -> EJust ext (dfwd e) -  EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b) -  EConstArr _ n t x -> scalTyCase t -    (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) -          (EConstArr ext n t x)) -    (EConstArr ext n t x) -  EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b) -  EBuild _ n a b -    | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b) -  EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b) -  ESum1Inner _ e -> -    let STArr n (STScal t) = typeOf e -        pairty = (STPair (STScal t) (STScal t)) -    in scalTyCase t -         (ELet ext (dfwd e) $ -            ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) -                                       (EVar ext (STArr n pairty) IZ))) -                 (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) -                                       (EVar ext (STArr n pairty) IZ)))) -         (ESum1Inner ext (dfwd e)) -  EUnit _ e -> EUnit ext (dfwd e) -  EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b) -  EConst _ t x -> scalTyCase t -    (EPair ext (EConst ext t x) (EConst ext t 0.0)) -    (EConst ext t x) -  EIdx0 _ e -> EIdx0 ext (dfwd e) -  EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b) -  EIdx _ n a b -    | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b) -  EShape _ e -    | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e) -  EOp _ op e -> dop op (dfwd e) -  EError t s -> EError (fd t) s +dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env +dnOnehotEnvs SNil SNil = \_ -> SNil +dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = +  \f -> +    Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) +    `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) -  EWith{} -> err_accum -  EAccum{} -> err_accum -  EZero{} -> err_monoid -  EPlus{} -> err_monoid -  where -    err_accum = error "Accumulator operations unsupported in the source program" -    err_monoid = error "Monoid operations unsupported in the source program" - -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = -  let STArr n t = typeOf arr -  in ELet ext arr $ -       EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ -         ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) -                              (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -           weakenExpr (WCopy (WSink .> WSink)) f - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip a b = -  let STArr n t1 = typeOf a -      STArr _ t2 = typeOf b -  in ELet ext a $ -     ELet ext (weakenExpr WSink b) $ -       EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ -         EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) -                               (EVar ext (tTup (sreplicate n tIx)) IZ)) -                   (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) -                               (EVar ext (tTup (sreplicate n tIx)) IZ)) +drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwd env expr input dres = +  let outty = typeOf expr +  in dnOnehotEnvs env input $ \dnInput -> +       let (_, outtan) = unzipDN outty (interpretOpen dnInput (dfwdDN expr)) +       in dotprodTan outty outtan dres diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs new file mode 100644 index 0000000..f9239e9 --- /dev/null +++ b/src/ForwardAD/DualNumbers.hs @@ -0,0 +1,206 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module ForwardAD.DualNumbers ( +  dfwdDN, +  DN, DNS, DNE, dn, dne, +) where + +import AST +import Data + + +-- | Dual-numbers transformation +type family DN t where +  DN TNil = TNil +  DN (TPair a b) = TPair (DN a) (DN b) +  DN (TEither a b) = TEither (DN a) (DN b) +  DN (TMaybe t) = TMaybe (DN t) +  DN (TArr n t) = TArr n (DN t) +  DN (TScal t) = DNS t + +type family DNS t where +  DNS TF32 = TPair (TScal TF32) (TScal TF32) +  DNS TF64 = TPair (TScal TF64) (TScal TF64) +  DNS TI32 = TScal TI32 +  DNS TI64 = TScal TI64 +  DNS TBool = TScal TBool + +type family DNE env where +  DNE '[] = '[] +  DNE (t : ts) = DN t : DNE ts + +dn :: STy t -> STy (DN t) +dn STNil = STNil +dn (STPair a b) = STPair (dn a) (dn b) +dn (STEither a b) = STEither (dn a) (dn b) +dn (STMaybe t) = STMaybe (dn t) +dn (STArr n t) = STArr n (dn t) +dn (STScal t) = case t of +  STF32 -> STPair (STScal STF32) (STScal STF32) +  STF64 -> STPair (STScal STF64) (STScal STF64) +  STI32 -> STScal STI32 +  STI64 -> STScal STI64 +  STBool -> STScal STBool +dn STAccum{} = error "Accum in source program" + +dne :: SList STy env -> SList STy (DNE env) +dne SNil = SNil +dne (t `SCons` env) = dn t `SCons` dne env + +dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +dnPreservesTupIx SZ = Refl +dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl + +convIdx :: Idx env t -> Idx (DNE env) (DN t) +convIdx IZ = IZ +convIdx (IS i) = IS (convIdx i) + +scalTyCase :: SScalTy t +           -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) +           -> (DN (TScal t) ~ TScal t => r) +           -> r +scalTyCase STF32  k1 _ = k1 +scalTyCase STF64  k1 _ = k1 +scalTyCase STI32  _ k2 = k2 +scalTyCase STI64  _ k2 = k2 +scalTyCase STBool _ k2 = k2 + +-- | Argument does not need to be duplicable. +dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) +dop = \case +  OAdd t -> scalTyCase t +    (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) +    (EOp ext (OAdd t)) +  OMul t -> scalTyCase t +    (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) +    (EOp ext (OMul t)) +  ONeg t -> scalTyCase t +    (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) +    (EOp ext (ONeg t)) +  OLt t -> scalTyCase t +    (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) +    (EOp ext (OLt t)) +  OLe t -> scalTyCase t +    (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) +    (EOp ext (OLe t)) +  OEq t -> scalTyCase t +    (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) +    (EOp ext (OEq t)) +  ONot -> EOp ext ONot +  OIf -> EOp ext OIf +  where +    add :: ScalIsNumeric t ~ True +        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) +    add t a b = EOp ext (OAdd t) (EPair ext a b) + +    mul :: ScalIsNumeric t ~ True +        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) +    mul t a b = EOp ext (OMul t) (EPair ext a b) + +    neg :: ScalIsNumeric t ~ True +        => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) +    neg t = EOp ext (ONeg t) + +    unFloat :: DN a ~ TPair a a +            => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) +            -> Ex env (DN a) -> Ex env (DN b) +    unFloat f e = +      ELet ext e $ +        let var = EVar ext (typeOf e) IZ +        in f (EFst ext var, ESnd ext var) + +    binFloat :: (a ~ TPair s s, DN s ~ TPair s s) +             => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b)) +             -> Ex env (DN a) -> Ex env (DN b) +    binFloat f e = +      ELet ext e $ +        let var = EVar ext (typeOf e) IZ +        in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) +             (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) + +dfwdDN :: Ex env t -> Ex (DNE env) (DN t) +dfwdDN = \case +  EVar _ t i -> EVar ext (dn t) (convIdx i) +  ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b) +  EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b) +  EFst _ e -> EFst ext (dfwdDN e) +  ESnd _ e -> ESnd ext (dfwdDN e) +  ENil _ -> ENil ext +  EInl _ t e -> EInl ext (dn t) (dfwdDN e) +  EInr _ t e -> EInr ext (dn t) (dfwdDN e) +  ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) +  ENothing _ t -> ENothing ext (dn t) +  EJust _ e -> EJust ext (dfwdDN e) +  EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) +  EConstArr _ n t x -> scalTyCase t +    (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) +          (EConstArr ext n t x)) +    (EConstArr ext n t x) +  EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b) +  EBuild _ n a b +    | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) +  EFold1Inner _ a b -> EFold1Inner ext (dfwdDN a) (dfwdDN b) +  ESum1Inner _ e -> +    let STArr n (STScal t) = typeOf e +        pairty = (STPair (STScal t) (STScal t)) +    in scalTyCase t +         (ELet ext (dfwdDN e) $ +            ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) +                                       (EVar ext (STArr n pairty) IZ))) +                 (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) +                                       (EVar ext (STArr n pairty) IZ)))) +         (ESum1Inner ext (dfwdDN e)) +  EUnit _ e -> EUnit ext (dfwdDN e) +  EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) +  EConst _ t x -> scalTyCase t +    (EPair ext (EConst ext t x) (EConst ext t 0.0)) +    (EConst ext t x) +  EIdx0 _ e -> EIdx0 ext (dfwdDN e) +  EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) +  EIdx _ n a b +    | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b) +  EShape _ e +    | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) +  EOp _ op e -> dop op (dfwdDN e) +  EError t s -> EError (dn t) s + +  EWith{} -> err_accum +  EAccum{} -> err_accum +  EZero{} -> err_monoid +  EPlus{} -> err_monoid +  where +    err_accum = error "Accumulator operations unsupported in the source program" +    err_monoid = error "Monoid operations unsupported in the source program" + +emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr = +  let STArr n t = typeOf arr +  in ELet ext arr $ +       EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ +         ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) +                              (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +           weakenExpr (WCopy (WSink .> WSink)) f + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip a b = +  let STArr n t1 = typeOf a +      STArr _ t2 = typeOf b +  in ELet ext a $ +     ELet ext (weakenExpr WSink b) $ +       EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ +         EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) +                               (EVar ext (tTup (sreplicate n tIx)) IZ)) +                   (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) +                               (EVar ext (tTup (sreplicate n tIx)) IZ)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 316a423..4d1358f 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -19,17 +19,21 @@ module Interpreter (  ) where  import Control.Monad (foldM, join) +import Data.Char (isSpace)  import Data.Kind (Type)  import Data.Int (Int64)  import Data.IORef  import System.IO.Unsafe (unsafePerformIO) +import Debug.Trace +  import Array  import AST  import CHAD.Types  import Data  import Interpreter.Rep  import Data.Bifunctor (bimap) +import GHC.Stack (HasCallStack)  newtype AcM s a = AcM { unAcM :: IO a } @@ -41,17 +45,16 @@ runAcM (AcM m) = unsafePerformIO m  interpret :: Ex '[] t -> Rep t  interpret = interpretOpen SNil -newtype Value t = Value (Rep t) -  interpretOpen :: SList Value env -> Ex env t -> Rep t  interpretOpen env e = runAcM (interpret' env e) -interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t) +interpret' :: forall env t s. HasCallStack => SList Value env -> Ex env t -> AcM s (Rep t)  interpret' env = \case    EVar _ _ i -> case slistIdx env i of Value x -> return x    ELet _ a b -> do      x <- interpret' env a      interpret' (Value x `SCons` env) b +  expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined    EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b    EFst _ e -> fst <$> interpret' env e    ESnd _ e -> snd <$> interpret' env e @@ -232,13 +235,6 @@ instance Shapey Shape where    shapeyCase ShNil k0 _ = k0    shapeyCase (ShCons sh n) _ k1 = k1 sh n -enumInvShape :: InvShape n -> [InvIndex n] -enumInvShape IShNil = [IIxNil] -enumInvShape (n `IShCons` sh) = [i `IIxCons` ix | i <- [0 .. n - 1], ix <- enumInvShape sh] - -enumShape :: Shape n -> [Index n] -enumShape = map uninvert . enumInvShape . invert -  invert :: forall f n. Shapey f => f n -> Inverted f n  invert | Refl <- lemPlusZero @n = flip go InvNil    where diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index c0c38b2..adb4eba 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -38,3 +38,6 @@ type family RepAcDense t where    -- RepAcDense (TArr n t) = Array n (RepAcSparse t)    -- RepAcDense (TScal sty) = ScalRep sty    -- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") + +newtype Value t = Value (Rep t) + | 
