diff options
-rw-r--r-- | chad-fast.cabal | 4 | ||||
-rw-r--r-- | src/AST/Bindings.hs | 64 | ||||
-rw-r--r-- | src/Array.hs | 3 | ||||
-rw-r--r-- | src/CHAD.hs | 61 | ||||
-rw-r--r-- | src/Example.hs | 32 | ||||
-rw-r--r-- | src/Example/Format.hs | 9 | ||||
-rw-r--r-- | src/ForwardAD.hs | 377 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 206 | ||||
-rw-r--r-- | src/Interpreter.hs | 16 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 3 |
10 files changed, 500 insertions, 275 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 6873efd..ef9f642 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -12,6 +12,7 @@ library exposed-modules: Array AST + AST.Bindings AST.Count AST.Env AST.Pretty @@ -23,7 +24,9 @@ library -- Compile Data Example + Example.Format ForwardAD + ForwardAD.DualNumbers Interpreter -- Interpreter.AccumOld Interpreter.Rep @@ -37,6 +40,7 @@ library base >= 4.19 && < 4.21, containers, -- template-haskell, + process, transformers, vector, hs-source-dirs: src diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs new file mode 100644 index 0000000..2e63b42 --- /dev/null +++ b/src/AST/Bindings.hs @@ -0,0 +1,64 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module AST.Bindings where + +import AST +import Data +import Lemmas + + +-- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. +data Bindings f env binds where + BTop :: Bindings f env '[] + BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds) +deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') +infixl `BPush` + +mapBindings :: (forall env' t'. f env' t' -> g env' t') + -> Bindings f env binds -> Bindings g env binds +mapBindings _ BTop = BTop +mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) + +weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) + -> env1 :> env2 + -> Bindings f env1 binds + -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) +weakenBindings _ w BTop = (BTop, w) +weakenBindings wf w (BPush b (t, x)) = + let (b', w') = weakenBindings wf w b + in (BPush b' (t, wf w' x), WCopy w') + +weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' +weakenOver SNil w = w +weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) + +sinkWithBindings :: Bindings f env binds -> env' :> Append binds env' +sinkWithBindings BTop = WId +sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b + +bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1) +bconcat b1 BTop = b1 +bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x)) + | Refl <- lemAppendAssoc @binds2C @binds1 @env + = BPush (bconcat b1 b2) (t, x) + +bindingsBinds :: Bindings f env binds -> SList STy binds +bindingsBinds BTop = SNil +bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) + +letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t +letBinds BTop = id +letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs diff --git a/src/Array.hs b/src/Array.hs index c939419..8507544 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -45,6 +45,9 @@ emptyShape :: SNat n -> Shape n emptyShape SZ = ShNil emptyShape (SS m) = emptyShape m `ShCons` 0 +enumShape :: Shape n -> [Index n] +enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] + -- | TODO: this Vector is a boxed vector, which is horrendously inefficient. data Array (n :: Nat) t = Array (Shape n) (Vector t) diff --git a/src/CHAD.hs b/src/CHAD.hs index 12d28e2..bcc1485 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -34,6 +34,7 @@ import GHC.Stack (HasCallStack) import GHC.TypeLits (Symbol) import AST +import AST.Bindings import AST.Count import AST.Env import AST.Weaken.Auto @@ -42,66 +43,10 @@ import Data import Lemmas --- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. -data Bindings f env binds where - BTop :: Bindings f env '[] - BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds) -deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') -infixl `BPush` - -mapBindings :: (forall env' t'. f env' t' -> g env' t') - -> Bindings f env binds -> Bindings g env binds -mapBindings _ BTop = BTop -mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) - -weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) - -> env1 :> env2 - -> Bindings f env1 binds - -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) -weakenBindings _ w BTop = (BTop, w) -weakenBindings wf w (BPush b (t, x)) = - let (b', w') = weakenBindings wf w b - in (BPush b' (t, wf w' x), WCopy w') - -weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' -weakenOver SNil w = w -weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) - -sinkWithBindings :: Bindings f env binds -> env' :> Append binds env' -sinkWithBindings BTop = WId -sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b - -bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1) -bconcat b1 BTop = b1 -bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x)) - | Refl <- lemAppendAssoc @binds2C @binds1 @env - = BPush (bconcat b1 b2) (t, x) - --- bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) --- -> Bindings f env env1 -> Bindings f env env2 --- -> (forall env12. Bindings f env env12 -> r) -> r --- bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2') - --- bsnoc :: (forall env1 env2 t'. env1 :> env2 -> f env1 t' -> f env2 t') --- -> STy t -> f env t -> Bindings f env binds -> (Bindings f env (Snoc binds t), Append binds env :> Append (Snoc binds t) env) --- bsnoc _ t x BTop = (BPush BTop (t, x), WSink) --- bsnoc wf t x (BPush b (t', y)) = --- let (b', w) = bsnoc wf t x b --- in (BPush b' (t', wf w y), WCopy w) - type family Tape binds where Tape '[] = TNil Tape (t : ts) = TPair t (Tape ts) --- data TupBinds f env binds = --- TupBinds (SList STy binds) --- (forall env2. Append binds env :> env2 -> Ex env2 (Tape binds)) --- (forall env2. Idx env2 (Tape binds) -> Bindings f env2 binds) - -bindingsBinds :: Bindings f env binds -> SList STy binds -bindingsBinds BTop = SNil -bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) - tapeTy :: SList STy binds -> STy (Tape binds) tapeTy SNil = STNil tapeTy (SCons t ts) = STPair t (tapeTy ts) @@ -240,10 +185,6 @@ reconstructBindings binds tape = (bconcat (mapBindings fromUnfExpr unf) build) ,sreverse (stapeUnfoldings binds)) -letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t -letBinds BTop = id -letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs - type family Vectorise n list where Vectorise _ '[] = '[] Vectorise n (t : ts) = TArr n t : Vectorise n ts diff --git a/src/Example.hs b/src/Example.hs index e2f1be9..6701e38 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -11,10 +11,14 @@ import AST import AST.Pretty import CHAD import Data +import ForwardAD import Interpreter import Language import Simplify +import Debug.Trace +import Example.Format + -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) @@ -175,18 +179,26 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda # let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ #x3 ! nil -neuralGo :: (Float - ,(((((), Either () (Array N2 Float, Array N1 Float)) - ,Either () (Array N2 Float, Array N1 Float)) - ,Array N1 Float) - ,Array N1 Float)) +type NeuralGrad = ((Array N2 Float, Array N1 Float) + ,(Array N2 Float, Array N1 Float) + ,Array N1 Float + ,Array N1 Float) + +neuralGo :: (Float -- primal + ,NeuralGrad -- gradient using CHAD + ,NeuralGrad) -- gradient using dual-numbers forward AD neuralGo = let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) lay3 = arrayFromList (ShNil `ShCons` 2) [1,1] input = arrayFromList (ShNil `ShCons` 2) [1,1] - in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $ - simplifyN 20 $ - freezeRet mergeDescr - (drev mergeDescr neural) - (EConst ext STF32 1.0) + argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) + revderiv = + simplifyN 20 $ + freezeRet mergeDescr + (drev mergeDescr neural) + (EConst ext STF32 1.0) + (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv + (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0 + in trace (formatter (ppExpr knownEnv revderiv)) $ + (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) diff --git a/src/Example/Format.hs b/src/Example/Format.hs new file mode 100644 index 0000000..994f431 --- /dev/null +++ b/src/Example/Format.hs @@ -0,0 +1,9 @@ +module Example.Format where + +import System.IO.Unsafe +import System.Process + + +{-# NOINLINE formatter #-} +formatter :: String -> String +formatter str = unsafePerformIO $ readProcess "hindent" ["--line-length", "100"] str diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 0a9e12c..63244a8 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -1,202 +1,189 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +module ForwardAD where --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD ( - dfwd, - FD, FDS, FDE, fd, -) where +import Data.Bifunctor (bimap) +-- import Data.Foldable (toList) +import Array import AST +-- import AST.Bindings import Data - - --- | Dual-numbers transformation -type family FD t where - FD TNil = TNil - FD (TPair a b) = TPair (FD a) (FD b) - FD (TEither a b) = TEither (FD a) (FD b) - FD (TMaybe t) = TMaybe (FD t) - FD (TArr n t) = TArr n (FD t) - FD (TScal t) = FDS t - -type family FDS t where - FDS TF32 = TPair (TScal TF32) (TScal TF32) - FDS TF64 = TPair (TScal TF64) (TScal TF64) - FDS TI32 = TScal TI32 - FDS TI64 = TScal TI64 - FDS TBool = TScal TBool - -type family FDE env where - FDE '[] = '[] - FDE (t : ts) = FD t : FDE ts - -fd :: STy t -> STy (FD t) -fd STNil = STNil -fd (STPair a b) = STPair (fd a) (fd b) -fd (STEither a b) = STEither (fd a) (fd b) -fd (STMaybe t) = STMaybe (fd t) -fd (STArr n t) = STArr n (fd t) -fd (STScal t) = case t of - STF32 -> STPair (STScal STF32) (STScal STF32) - STF64 -> STPair (STScal STF64) (STScal STF64) - STI32 -> STScal STI32 - STI64 -> STScal STI64 - STBool -> STScal STBool -fd STAccum{} = error "Accum in source program" - -fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) -fdPreservesTupIx SZ = Refl -fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl - -convIdx :: Idx env t -> Idx (FDE env) (FD t) -convIdx IZ = IZ -convIdx (IS i) = IS (convIdx i) - -scalTyCase :: SScalTy t - -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r) - -> (FD (TScal t) ~ TScal t => r) - -> r -scalTyCase STF32 k1 _ = k1 -scalTyCase STF64 k1 _ = k1 -scalTyCase STI32 _ k2 = k2 -scalTyCase STI64 _ k2 = k2 -scalTyCase STBool _ k2 = k2 - --- | Argument does not need to be duplicable. -dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b) -dop = \case - OAdd t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) - (EOp ext (OAdd t)) - OMul t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) - (EOp ext (OMul t)) - ONeg t -> scalTyCase t - (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) - (EOp ext (ONeg t)) - OLt t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) - (EOp ext (OLt t)) - OLe t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) - (EOp ext (OLe t)) - OEq t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) - (EOp ext (OEq t)) - ONot -> EOp ext ONot - OIf -> EOp ext OIf - where - add :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - add t a b = EOp ext (OAdd t) (EPair ext a b) - - mul :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - mul t a b = EOp ext (OMul t) (EPair ext a b) - - neg :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) - neg t = EOp ext (ONeg t) - - unFloat :: FD a ~ TPair a a - => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b)) - -> Ex env (FD a) -> Ex env (FD b) - unFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext var, ESnd ext var) - - binFloat :: (a ~ TPair s s, FD s ~ TPair s s) - => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b)) - -> Ex env (FD a) -> Ex env (FD b) - binFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) - (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) - -dfwd :: Ex env t -> Ex (FDE env) (FD t) -dfwd = \case - EVar _ t i -> EVar ext (fd t) (convIdx i) - ELet _ a b -> ELet ext (dfwd a) (dfwd b) - EPair _ a b -> EPair ext (dfwd a) (dfwd b) - EFst _ e -> EFst ext (dfwd e) - ESnd _ e -> ESnd ext (dfwd e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext (fd t) (dfwd e) - EInr _ t e -> EInr ext (fd t) (dfwd e) - ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b) - ENothing _ t -> ENothing ext (fd t) - EJust _ e -> EJust ext (dfwd e) - EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b) - EConstArr _ n t x -> scalTyCase t - (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) - (EConstArr ext n t x)) - (EConstArr ext n t x) - EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b) - EBuild _ n a b - | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b) - EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b) - ESum1Inner _ e -> - let STArr n (STScal t) = typeOf e - pairty = (STPair (STScal t) (STScal t)) - in scalTyCase t - (ELet ext (dfwd e) $ - ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ))) - (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ)))) - (ESum1Inner ext (dfwd e)) - EUnit _ e -> EUnit ext (dfwd e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b) - EConst _ t x -> scalTyCase t - (EPair ext (EConst ext t x) (EConst ext t 0.0)) - (EConst ext t x) - EIdx0 _ e -> EIdx0 ext (dfwd e) - EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b) - EIdx _ n a b - | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b) - EShape _ e - | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e) - EOp _ op e -> dop op (dfwd e) - EError t s -> EError (fd t) s - - EWith{} -> err_accum - EAccum{} -> err_accum - EZero{} -> err_monoid - EPlus{} -> err_monoid - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip a b = - let STArr n t1 = typeOf a - STArr _ t2 = typeOf b - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) +import ForwardAD.DualNumbers +import Interpreter +import Interpreter.Rep + + +-- | Tangent along a type (coincides with cotangent for these types) +type family Tan t where + Tan TNil = TNil + Tan (TPair a b) = TPair (Tan a) (Tan b) + Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TMaybe t) = TMaybe (Tan t) + Tan (TArr n t) = TArr n (Tan t) + Tan (TScal t) = TanS t + +type family TanS t where + TanS TI32 = TNil + TanS TI64 = TNil + TanS TF32 = TScal TF32 + TanS TF64 = TScal TF64 + TanS TBool = TNil + +type family TanE env where + TanE '[] = '[] + TanE (t : env) = Tan t : TanE env + +tanty :: STy t -> STy (Tan t) +tanty STNil = STNil +tanty (STPair a b) = STPair (tanty a) (tanty b) +tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STMaybe t) = STMaybe (tanty t) +tanty (STArr n t) = STArr n (tanty t) +tanty (STScal t) = case t of + STI32 -> STNil + STI64 -> STNil + STF32 -> STScal STF32 + STF64 -> STScal STF64 + STBool -> STNil +tanty STAccum{} = error "Accumulators not allowed in input program" + +unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) +unzipDN STNil _ = ((), ()) +unzipDN (STPair a b) (d1, d2) = + let (x, dx) = unzipDN a d1 + (y, dy) = unzipDN b d2 + in ((x, y), (dx, dy)) +unzipDN (STEither a b) d = case d of + Left d1 -> bimap Left Left (unzipDN a d1) + Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STMaybe t) d = case d of + Nothing -> (Nothing, Nothing) + Just d' -> bimap Just Just (unzipDN t d') +unzipDN (STArr _ t) d = + let pairs = arrayMap (unzipDN t) d + in (arrayMap fst pairs, arrayMap snd pairs) +unzipDN (STScal ty) d = case ty of + STI32 -> (d, ()) + STI64 -> (d, ()) + STF32 -> d + STF64 -> d + STBool -> (d, ()) +unzipDN STAccum{} _ = error "Accumulators not allowed in input program" + +dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double +dotprodTan STNil _ _ = 0.0 +dotprodTan (STPair a b) (x, y) (x', y') = + dotprodTan a x x' + dotprodTan b y y' +dotprodTan (STEither a b) x y = case (x, y) of + (Left x', Left y') -> dotprodTan a x' y' + (Right x', Right y') -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STMaybe t) x y = case (x, y) of + (Nothing, Nothing) -> 0.0 + (Just x', Just y') -> dotprodTan t x' y' + _ -> error "dotprodTan: incompatible Maybe alternatives" +dotprodTan (STArr _ t) x y = + let sh1 = arrayShape x + sh2 = arrayShape y + in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0 + | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1] + | otherwise -> error "dotprodTan: incompatible array shapes" +dotprodTan (STScal ty) x y = case ty of + STI32 -> 0.0 + STI64 -> 0.0 + STF32 -> realToFrac @Float @Double (x * y) + STF64 -> x * y + STBool -> 0.0 +dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" + +-- -- Primal expression must be duplicable +-- dnConstE :: STy t -> Ex env t -> Ex env (DN t) +-- dnConstE STNil _ = ENil ext +-- dnConstE (STPair t1 t2) e = +-- -- This creates fst/snd stacks of unbounded size, but let's not care here +-- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e)) +-- dnConstE (STEither t1 t2) e = +-- ECase ext e +-- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ))) +-- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ))) +-- dnConstE (STMaybe t) e = +-- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e +-- dnConstE (STArr n t) e = +-- EBuild ext n (EShape ext e) +-- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ))) +-- dnConstE (STScal t) e = case t of +-- STI32 -> e +-- STI64 -> e +-- STF32 -> EPair ext e (EConst ext STF32 0.0) +-- STF64 -> EPair ext e (EConst ext STF64 0.0) +-- STBool -> e +-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program" + +dnConst :: STy t -> Rep t -> Rep (DN t) +dnConst STNil = const () +dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STMaybe t) = fmap (dnConst t) +dnConst (STArr _ t) = arrayMap (dnConst t) +dnConst (STScal t) = case t of + STI32 -> id + STI64 -> id + STF32 -> (,0.0) + STF64 -> (,0.0) + STBool -> id +dnConst STAccum{} = error "Accumulators not allowed in input program" + +-- | Given a function that computes the forward derivative for a particular +-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this +-- @t@ input. +type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t) + +dnOnehots :: STy t -> Rep t -> RevByFwd t +dnOnehots STNil _ = \_ -> () +dnOnehots (STPair t1 t2) (x, y) = + \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,))) +dnOnehots (STEither t1 t2) e = + case e of + Left x -> \f -> Left (dnOnehots t1 x (f . Left)) + Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STMaybe t) m = + case m of + Nothing -> \_ -> Nothing + Just x -> \f -> Just (dnOnehots t x (f . Just)) +dnOnehots (STArr _ t) a = + \f -> + arrayGenerate (arrayShape a) $ \idx -> + dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i -> + if i == idx then oh else dnConst t (arrayIndex a i))) +dnOnehots (STScal t) x = case t of + STI32 -> \_ -> () + STI64 -> \_ -> () + STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0) + STF64 -> \f -> f (x, 1.0) + STBool -> \_ -> () +dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" + +dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) +dnConstEnv SNil SNil = SNil +dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val + +type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env) + +dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env +dnOnehotEnvs SNil SNil = \_ -> SNil +dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = + \f -> + Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) + `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) + +drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwd env expr input dres = + let outty = typeOf expr + in dnOnehotEnvs env input $ \dnInput -> + let (_, outtan) = unzipDN outty (interpretOpen dnInput (dfwdDN expr)) + in dotprodTan outty outtan dres diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs new file mode 100644 index 0000000..f9239e9 --- /dev/null +++ b/src/ForwardAD/DualNumbers.hs @@ -0,0 +1,206 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module ForwardAD.DualNumbers ( + dfwdDN, + DN, DNS, DNE, dn, dne, +) where + +import AST +import Data + + +-- | Dual-numbers transformation +type family DN t where + DN TNil = TNil + DN (TPair a b) = TPair (DN a) (DN b) + DN (TEither a b) = TEither (DN a) (DN b) + DN (TMaybe t) = TMaybe (DN t) + DN (TArr n t) = TArr n (DN t) + DN (TScal t) = DNS t + +type family DNS t where + DNS TF32 = TPair (TScal TF32) (TScal TF32) + DNS TF64 = TPair (TScal TF64) (TScal TF64) + DNS TI32 = TScal TI32 + DNS TI64 = TScal TI64 + DNS TBool = TScal TBool + +type family DNE env where + DNE '[] = '[] + DNE (t : ts) = DN t : DNE ts + +dn :: STy t -> STy (DN t) +dn STNil = STNil +dn (STPair a b) = STPair (dn a) (dn b) +dn (STEither a b) = STEither (dn a) (dn b) +dn (STMaybe t) = STMaybe (dn t) +dn (STArr n t) = STArr n (dn t) +dn (STScal t) = case t of + STF32 -> STPair (STScal STF32) (STScal STF32) + STF64 -> STPair (STScal STF64) (STScal STF64) + STI32 -> STScal STI32 + STI64 -> STScal STI64 + STBool -> STScal STBool +dn STAccum{} = error "Accum in source program" + +dne :: SList STy env -> SList STy (DNE env) +dne SNil = SNil +dne (t `SCons` env) = dn t `SCons` dne env + +dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +dnPreservesTupIx SZ = Refl +dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl + +convIdx :: Idx env t -> Idx (DNE env) (DN t) +convIdx IZ = IZ +convIdx (IS i) = IS (convIdx i) + +scalTyCase :: SScalTy t + -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) + -> (DN (TScal t) ~ TScal t => r) + -> r +scalTyCase STF32 k1 _ = k1 +scalTyCase STF64 k1 _ = k1 +scalTyCase STI32 _ k2 = k2 +scalTyCase STI64 _ k2 = k2 +scalTyCase STBool _ k2 = k2 + +-- | Argument does not need to be duplicable. +dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) +dop = \case + OAdd t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) + (EOp ext (OAdd t)) + OMul t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) + (EOp ext (OMul t)) + ONeg t -> scalTyCase t + (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) + (EOp ext (ONeg t)) + OLt t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) + (EOp ext (OLt t)) + OLe t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) + (EOp ext (OLe t)) + OEq t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) + (EOp ext (OEq t)) + ONot -> EOp ext ONot + OIf -> EOp ext OIf + where + add :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + add t a b = EOp ext (OAdd t) (EPair ext a b) + + mul :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + mul t a b = EOp ext (OMul t) (EPair ext a b) + + neg :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + neg t = EOp ext (ONeg t) + + unFloat :: DN a ~ TPair a a + => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + unFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext var, ESnd ext var) + + binFloat :: (a ~ TPair s s, DN s ~ TPair s s) + => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + binFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) + (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) + +dfwdDN :: Ex env t -> Ex (DNE env) (DN t) +dfwdDN = \case + EVar _ t i -> EVar ext (dn t) (convIdx i) + ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b) + EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b) + EFst _ e -> EFst ext (dfwdDN e) + ESnd _ e -> ESnd ext (dfwdDN e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (dn t) (dfwdDN e) + EInr _ t e -> EInr ext (dn t) (dfwdDN e) + ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ENothing _ t -> ENothing ext (dn t) + EJust _ e -> EJust ext (dfwdDN e) + EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + EConstArr _ n t x -> scalTyCase t + (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) + (EConstArr ext n t x)) + (EConstArr ext n t x) + EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b) + EBuild _ n a b + | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EFold1Inner _ a b -> EFold1Inner ext (dfwdDN a) (dfwdDN b) + ESum1Inner _ e -> + let STArr n (STScal t) = typeOf e + pairty = (STPair (STScal t) (STScal t)) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ))) + (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ)))) + (ESum1Inner ext (dfwdDN e)) + EUnit _ e -> EUnit ext (dfwdDN e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) + EConst _ t x -> scalTyCase t + (EPair ext (EConst ext t x) (EConst ext t 0.0)) + (EConst ext t x) + EIdx0 _ e -> EIdx0 ext (dfwdDN e) + EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) + EIdx _ n a b + | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b) + EShape _ e + | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) + EOp _ op e -> dop op (dfwdDN e) + EError t s -> EError (dn t) s + + EWith{} -> err_accum + EAccum{} -> err_accum + EZero{} -> err_monoid + EPlus{} -> err_monoid + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + +emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr = + let STArr n t = typeOf arr + in ELet ext arr $ + EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ + ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) f + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip a b = + let STArr n t1 = typeOf a + STArr _ t2 = typeOf b + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ + EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 316a423..4d1358f 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -19,17 +19,21 @@ module Interpreter ( ) where import Control.Monad (foldM, join) +import Data.Char (isSpace) import Data.Kind (Type) import Data.Int (Int64) import Data.IORef import System.IO.Unsafe (unsafePerformIO) +import Debug.Trace + import Array import AST import CHAD.Types import Data import Interpreter.Rep import Data.Bifunctor (bimap) +import GHC.Stack (HasCallStack) newtype AcM s a = AcM { unAcM :: IO a } @@ -41,17 +45,16 @@ runAcM (AcM m) = unsafePerformIO m interpret :: Ex '[] t -> Rep t interpret = interpretOpen SNil -newtype Value t = Value (Rep t) - interpretOpen :: SList Value env -> Ex env t -> Rep t interpretOpen env e = runAcM (interpret' env e) -interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t) +interpret' :: forall env t s. HasCallStack => SList Value env -> Ex env t -> AcM s (Rep t) interpret' env = \case EVar _ _ i -> case slistIdx env i of Value x -> return x ELet _ a b -> do x <- interpret' env a interpret' (Value x `SCons` env) b + expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b EFst _ e -> fst <$> interpret' env e ESnd _ e -> snd <$> interpret' env e @@ -232,13 +235,6 @@ instance Shapey Shape where shapeyCase ShNil k0 _ = k0 shapeyCase (ShCons sh n) _ k1 = k1 sh n -enumInvShape :: InvShape n -> [InvIndex n] -enumInvShape IShNil = [IIxNil] -enumInvShape (n `IShCons` sh) = [i `IIxCons` ix | i <- [0 .. n - 1], ix <- enumInvShape sh] - -enumShape :: Shape n -> [Index n] -enumShape = map uninvert . enumInvShape . invert - invert :: forall f n. Shapey f => f n -> Inverted f n invert | Refl <- lemPlusZero @n = flip go InvNil where diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index c0c38b2..adb4eba 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -38,3 +38,6 @@ type family RepAcDense t where -- RepAcDense (TArr n t) = Array n (RepAcSparse t) -- RepAcDense (TScal sty) = ScalRep sty -- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") + +newtype Value t = Value (Rep t) + |