summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal4
-rw-r--r--src/AST/Bindings.hs64
-rw-r--r--src/Array.hs3
-rw-r--r--src/CHAD.hs61
-rw-r--r--src/Example.hs32
-rw-r--r--src/Example/Format.hs9
-rw-r--r--src/ForwardAD.hs377
-rw-r--r--src/ForwardAD/DualNumbers.hs206
-rw-r--r--src/Interpreter.hs16
-rw-r--r--src/Interpreter/Rep.hs3
10 files changed, 500 insertions, 275 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 6873efd..ef9f642 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -12,6 +12,7 @@ library
exposed-modules:
Array
AST
+ AST.Bindings
AST.Count
AST.Env
AST.Pretty
@@ -23,7 +24,9 @@ library
-- Compile
Data
Example
+ Example.Format
ForwardAD
+ ForwardAD.DualNumbers
Interpreter
-- Interpreter.AccumOld
Interpreter.Rep
@@ -37,6 +40,7 @@ library
base >= 4.19 && < 4.21,
containers,
-- template-haskell,
+ process,
transformers,
vector,
hs-source-dirs: src
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
new file mode 100644
index 0000000..2e63b42
--- /dev/null
+++ b/src/AST/Bindings.hs
@@ -0,0 +1,64 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module AST.Bindings where
+
+import AST
+import Data
+import Lemmas
+
+
+-- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'.
+data Bindings f env binds where
+ BTop :: Bindings f env '[]
+ BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds)
+deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
+infixl `BPush`
+
+mapBindings :: (forall env' t'. f env' t' -> g env' t')
+ -> Bindings f env binds -> Bindings g env binds
+mapBindings _ BTop = BTop
+mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e)
+
+weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+ -> env1 :> env2
+ -> Bindings f env1 binds
+ -> (Bindings f env2 binds, Append binds env1 :> Append binds env2)
+weakenBindings _ w BTop = (BTop, w)
+weakenBindings wf w (BPush b (t, x)) =
+ let (b', w') = weakenBindings wf w b
+ in (BPush b' (t, wf w' x), WCopy w')
+
+weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
+weakenOver SNil w = w
+weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
+
+sinkWithBindings :: Bindings f env binds -> env' :> Append binds env'
+sinkWithBindings BTop = WId
+sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
+
+bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1)
+bconcat b1 BTop = b1
+bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x))
+ | Refl <- lemAppendAssoc @binds2C @binds1 @env
+ = BPush (bconcat b1 b2) (t, x)
+
+bindingsBinds :: Bindings f env binds -> SList STy binds
+bindingsBinds BTop = SNil
+bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
+
+letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
+letBinds BTop = id
+letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
diff --git a/src/Array.hs b/src/Array.hs
index c939419..8507544 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -45,6 +45,9 @@ emptyShape :: SNat n -> Shape n
emptyShape SZ = ShNil
emptyShape (SS m) = emptyShape m `ShCons` 0
+enumShape :: Shape n -> [Index n]
+enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1]
+
-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 12d28e2..bcc1485 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -34,6 +34,7 @@ import GHC.Stack (HasCallStack)
import GHC.TypeLits (Symbol)
import AST
+import AST.Bindings
import AST.Count
import AST.Env
import AST.Weaken.Auto
@@ -42,66 +43,10 @@ import Data
import Lemmas
--- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'.
-data Bindings f env binds where
- BTop :: Bindings f env '[]
- BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds)
-deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
-infixl `BPush`
-
-mapBindings :: (forall env' t'. f env' t' -> g env' t')
- -> Bindings f env binds -> Bindings g env binds
-mapBindings _ BTop = BTop
-mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e)
-
-weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
- -> env1 :> env2
- -> Bindings f env1 binds
- -> (Bindings f env2 binds, Append binds env1 :> Append binds env2)
-weakenBindings _ w BTop = (BTop, w)
-weakenBindings wf w (BPush b (t, x)) =
- let (b', w') = weakenBindings wf w b
- in (BPush b' (t, wf w' x), WCopy w')
-
-weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
-weakenOver SNil w = w
-weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
-
-sinkWithBindings :: Bindings f env binds -> env' :> Append binds env'
-sinkWithBindings BTop = WId
-sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
-
-bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1)
-bconcat b1 BTop = b1
-bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x))
- | Refl <- lemAppendAssoc @binds2C @binds1 @env
- = BPush (bconcat b1 b2) (t, x)
-
--- bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
--- -> Bindings f env env1 -> Bindings f env env2
--- -> (forall env12. Bindings f env env12 -> r) -> r
--- bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2')
-
--- bsnoc :: (forall env1 env2 t'. env1 :> env2 -> f env1 t' -> f env2 t')
--- -> STy t -> f env t -> Bindings f env binds -> (Bindings f env (Snoc binds t), Append binds env :> Append (Snoc binds t) env)
--- bsnoc _ t x BTop = (BPush BTop (t, x), WSink)
--- bsnoc wf t x (BPush b (t', y)) =
--- let (b', w) = bsnoc wf t x b
--- in (BPush b' (t', wf w y), WCopy w)
-
type family Tape binds where
Tape '[] = TNil
Tape (t : ts) = TPair t (Tape ts)
--- data TupBinds f env binds =
--- TupBinds (SList STy binds)
--- (forall env2. Append binds env :> env2 -> Ex env2 (Tape binds))
--- (forall env2. Idx env2 (Tape binds) -> Bindings f env2 binds)
-
-bindingsBinds :: Bindings f env binds -> SList STy binds
-bindingsBinds BTop = SNil
-bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
-
tapeTy :: SList STy binds -> STy (Tape binds)
tapeTy SNil = STNil
tapeTy (SCons t ts) = STPair t (tapeTy ts)
@@ -240,10 +185,6 @@ reconstructBindings binds tape =
(bconcat (mapBindings fromUnfExpr unf) build)
,sreverse (stapeUnfoldings binds))
-letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
-letBinds BTop = id
-letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
-
type family Vectorise n list where
Vectorise _ '[] = '[]
Vectorise n (t : ts) = TArr n t : Vectorise n ts
diff --git a/src/Example.hs b/src/Example.hs
index e2f1be9..6701e38 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -11,10 +11,14 @@ import AST
import AST.Pretty
import CHAD
import Data
+import ForwardAD
import Interpreter
import Language
import Simplify
+import Debug.Trace
+import Example.Format
+
-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
@@ -175,18 +179,26 @@ neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #
let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
#x3 ! nil
-neuralGo :: (Float
- ,(((((), Either () (Array N2 Float, Array N1 Float))
- ,Either () (Array N2 Float, Array N1 Float))
- ,Array N1 Float)
- ,Array N1 Float))
+type NeuralGrad = ((Array N2 Float, Array N1 Float)
+ ,(Array N2 Float, Array N1 Float)
+ ,Array N1 Float
+ ,Array N1 Float)
+
+neuralGo :: (Float -- primal
+ ,NeuralGrad -- gradient using CHAD
+ ,NeuralGrad) -- gradient using dual-numbers forward AD
neuralGo =
let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0])
lay3 = arrayFromList (ShNil `ShCons` 2) [1,1]
input = arrayFromList (ShNil `ShCons` 2) [1,1]
- in interpretOpen (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) $
- simplifyN 20 $
- freezeRet mergeDescr
- (drev mergeDescr neural)
- (EConst ext STF32 1.0)
+ argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil)
+ revderiv =
+ simplifyN 20 $
+ freezeRet mergeDescr
+ (drev mergeDescr neural)
+ (EConst ext STF32 1.0)
+ (primal, (((((), Right dlay1_1), Right dlay2_1), dlay3_1), dinput_1)) = interpretOpen argument revderiv
+ (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwd knownEnv neural argument 1.0
+ in trace (formatter (ppExpr knownEnv revderiv)) $
+ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
diff --git a/src/Example/Format.hs b/src/Example/Format.hs
new file mode 100644
index 0000000..994f431
--- /dev/null
+++ b/src/Example/Format.hs
@@ -0,0 +1,9 @@
+module Example.Format where
+
+import System.IO.Unsafe
+import System.Process
+
+
+{-# NOINLINE formatter #-}
+formatter :: String -> String
+formatter str = unsafePerformIO $ readProcess "hindent" ["--line-length", "100"] str
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 0a9e12c..63244a8 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -1,202 +1,189 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+module ForwardAD where
--- I want to bring various type variables in scope using type annotations in
--- patterns, but I don't want to have to mention all the other type parameters
--- of the types in question as well then. Partial type signatures (with '_') are
--- useful here.
-{-# LANGUAGE PartialTypeSignatures #-}
-{-# OPTIONS -Wno-partial-type-signatures #-}
-module ForwardAD (
- dfwd,
- FD, FDS, FDE, fd,
-) where
+import Data.Bifunctor (bimap)
+-- import Data.Foldable (toList)
+import Array
import AST
+-- import AST.Bindings
import Data
-
-
--- | Dual-numbers transformation
-type family FD t where
- FD TNil = TNil
- FD (TPair a b) = TPair (FD a) (FD b)
- FD (TEither a b) = TEither (FD a) (FD b)
- FD (TMaybe t) = TMaybe (FD t)
- FD (TArr n t) = TArr n (FD t)
- FD (TScal t) = FDS t
-
-type family FDS t where
- FDS TF32 = TPair (TScal TF32) (TScal TF32)
- FDS TF64 = TPair (TScal TF64) (TScal TF64)
- FDS TI32 = TScal TI32
- FDS TI64 = TScal TI64
- FDS TBool = TScal TBool
-
-type family FDE env where
- FDE '[] = '[]
- FDE (t : ts) = FD t : FDE ts
-
-fd :: STy t -> STy (FD t)
-fd STNil = STNil
-fd (STPair a b) = STPair (fd a) (fd b)
-fd (STEither a b) = STEither (fd a) (fd b)
-fd (STMaybe t) = STMaybe (fd t)
-fd (STArr n t) = STArr n (fd t)
-fd (STScal t) = case t of
- STF32 -> STPair (STScal STF32) (STScal STF32)
- STF64 -> STPair (STScal STF64) (STScal STF64)
- STI32 -> STScal STI32
- STI64 -> STScal STI64
- STBool -> STScal STBool
-fd STAccum{} = error "Accum in source program"
-
-fdPreservesTupIx :: SNat n -> FD (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
-fdPreservesTupIx SZ = Refl
-fdPreservesTupIx (SS n) | Refl <- fdPreservesTupIx n = Refl
-
-convIdx :: Idx env t -> Idx (FDE env) (FD t)
-convIdx IZ = IZ
-convIdx (IS i) = IS (convIdx i)
-
-scalTyCase :: SScalTy t
- -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), FD (TScal t) ~ TPair (TScal t) (TScal t)) => r)
- -> (FD (TScal t) ~ TScal t => r)
- -> r
-scalTyCase STF32 k1 _ = k1
-scalTyCase STF64 k1 _ = k1
-scalTyCase STI32 _ k2 = k2
-scalTyCase STI64 _ k2 = k2
-scalTyCase STBool _ k2 = k2
-
--- | Argument does not need to be duplicable.
-dop :: forall a b env. SOp a b -> Ex env (FD a) -> Ex env (FD b)
-dop = \case
- OAdd t -> scalTyCase t
- (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
- (EOp ext (OAdd t))
- OMul t -> scalTyCase t
- (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
- (EOp ext (OMul t))
- ONeg t -> scalTyCase t
- (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
- (EOp ext (ONeg t))
- OLt t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
- (EOp ext (OLt t))
- OLe t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
- (EOp ext (OLe t))
- OEq t -> scalTyCase t
- (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
- (EOp ext (OEq t))
- ONot -> EOp ext ONot
- OIf -> EOp ext OIf
- where
- add :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
- add t a b = EOp ext (OAdd t) (EPair ext a b)
-
- mul :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
- mul t a b = EOp ext (OMul t) (EPair ext a b)
-
- neg :: ScalIsNumeric t ~ True
- => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
- neg t = EOp ext (ONeg t)
-
- unFloat :: FD a ~ TPair a a
- => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (FD b))
- -> Ex env (FD a) -> Ex env (FD b)
- unFloat f e =
- ELet ext e $
- let var = EVar ext (typeOf e) IZ
- in f (EFst ext var, ESnd ext var)
-
- binFloat :: (a ~ TPair s s, FD s ~ TPair s s)
- => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (FD b))
- -> Ex env (FD a) -> Ex env (FD b)
- binFloat f e =
- ELet ext e $
- let var = EVar ext (typeOf e) IZ
- in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
- (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
-
-dfwd :: Ex env t -> Ex (FDE env) (FD t)
-dfwd = \case
- EVar _ t i -> EVar ext (fd t) (convIdx i)
- ELet _ a b -> ELet ext (dfwd a) (dfwd b)
- EPair _ a b -> EPair ext (dfwd a) (dfwd b)
- EFst _ e -> EFst ext (dfwd e)
- ESnd _ e -> ESnd ext (dfwd e)
- ENil _ -> ENil ext
- EInl _ t e -> EInl ext (fd t) (dfwd e)
- EInr _ t e -> EInr ext (fd t) (dfwd e)
- ECase _ e a b -> ECase ext (dfwd e) (dfwd a) (dfwd b)
- ENothing _ t -> ENothing ext (fd t)
- EJust _ e -> EJust ext (dfwd e)
- EMaybe _ e a b -> EMaybe ext (dfwd e) (dfwd a) (dfwd b)
- EConstArr _ n t x -> scalTyCase t
- (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
- (EConstArr ext n t x))
- (EConstArr ext n t x)
- EBuild1 _ a b -> EBuild1 ext (dfwd a) (dfwd b)
- EBuild _ n a b
- | Refl <- fdPreservesTupIx n -> EBuild ext n (dfwd a) (dfwd b)
- EFold1Inner _ a b -> EFold1Inner ext (dfwd a) (dfwd b)
- ESum1Inner _ e ->
- let STArr n (STScal t) = typeOf e
- pairty = (STPair (STScal t) (STScal t))
- in scalTyCase t
- (ELet ext (dfwd e) $
- ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
- (EVar ext (STArr n pairty) IZ)))
- (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
- (EVar ext (STArr n pairty) IZ))))
- (ESum1Inner ext (dfwd e))
- EUnit _ e -> EUnit ext (dfwd e)
- EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwd a) (dfwd b)
- EConst _ t x -> scalTyCase t
- (EPair ext (EConst ext t x) (EConst ext t 0.0))
- (EConst ext t x)
- EIdx0 _ e -> EIdx0 ext (dfwd e)
- EIdx1 _ a b -> EIdx1 ext (dfwd a) (dfwd b)
- EIdx _ n a b
- | Refl <- fdPreservesTupIx n -> EIdx ext n (dfwd a) (dfwd b)
- EShape _ e
- | Refl <- fdPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwd e)
- EOp _ op e -> dop op (dfwd e)
- EError t s -> EError (fd t) s
-
- EWith{} -> err_accum
- EAccum{} -> err_accum
- EZero{} -> err_monoid
- EPlus{} -> err_monoid
- where
- err_accum = error "Accumulator operations unsupported in the source program"
- err_monoid = error "Monoid operations unsupported in the source program"
-
-emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
-emap f arr =
- let STArr n t = typeOf arr
- in ELet ext arr $
- EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) f
-
-ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
-ezip a b =
- let STArr n t1 = typeOf a
- STArr _ t2 = typeOf b
- in ELet ext a $
- ELet ext (weakenExpr WSink b) $
- EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- (EIdx ext n (EVar ext (STArr n t2) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
+import ForwardAD.DualNumbers
+import Interpreter
+import Interpreter.Rep
+
+
+-- | Tangent along a type (coincides with cotangent for these types)
+type family Tan t where
+ Tan TNil = TNil
+ Tan (TPair a b) = TPair (Tan a) (Tan b)
+ Tan (TEither a b) = TEither (Tan a) (Tan b)
+ Tan (TMaybe t) = TMaybe (Tan t)
+ Tan (TArr n t) = TArr n (Tan t)
+ Tan (TScal t) = TanS t
+
+type family TanS t where
+ TanS TI32 = TNil
+ TanS TI64 = TNil
+ TanS TF32 = TScal TF32
+ TanS TF64 = TScal TF64
+ TanS TBool = TNil
+
+type family TanE env where
+ TanE '[] = '[]
+ TanE (t : env) = Tan t : TanE env
+
+tanty :: STy t -> STy (Tan t)
+tanty STNil = STNil
+tanty (STPair a b) = STPair (tanty a) (tanty b)
+tanty (STEither a b) = STEither (tanty a) (tanty b)
+tanty (STMaybe t) = STMaybe (tanty t)
+tanty (STArr n t) = STArr n (tanty t)
+tanty (STScal t) = case t of
+ STI32 -> STNil
+ STI64 -> STNil
+ STF32 -> STScal STF32
+ STF64 -> STScal STF64
+ STBool -> STNil
+tanty STAccum{} = error "Accumulators not allowed in input program"
+
+unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
+unzipDN STNil _ = ((), ())
+unzipDN (STPair a b) (d1, d2) =
+ let (x, dx) = unzipDN a d1
+ (y, dy) = unzipDN b d2
+ in ((x, y), (dx, dy))
+unzipDN (STEither a b) d = case d of
+ Left d1 -> bimap Left Left (unzipDN a d1)
+ Right d2 -> bimap Right Right (unzipDN b d2)
+unzipDN (STMaybe t) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just d' -> bimap Just Just (unzipDN t d')
+unzipDN (STArr _ t) d =
+ let pairs = arrayMap (unzipDN t) d
+ in (arrayMap fst pairs, arrayMap snd pairs)
+unzipDN (STScal ty) d = case ty of
+ STI32 -> (d, ())
+ STI64 -> (d, ())
+ STF32 -> d
+ STF64 -> d
+ STBool -> (d, ())
+unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+
+dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
+dotprodTan STNil _ _ = 0.0
+dotprodTan (STPair a b) (x, y) (x', y') =
+ dotprodTan a x x' + dotprodTan b y y'
+dotprodTan (STEither a b) x y = case (x, y) of
+ (Left x', Left y') -> dotprodTan a x' y'
+ (Right x', Right y') -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible Either alternatives"
+dotprodTan (STMaybe t) x y = case (x, y) of
+ (Nothing, Nothing) -> 0.0
+ (Just x', Just y') -> dotprodTan t x' y'
+ _ -> error "dotprodTan: incompatible Maybe alternatives"
+dotprodTan (STArr _ t) x y =
+ let sh1 = arrayShape x
+ sh2 = arrayShape y
+ in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0
+ | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1]
+ | otherwise -> error "dotprodTan: incompatible array shapes"
+dotprodTan (STScal ty) x y = case ty of
+ STI32 -> 0.0
+ STI64 -> 0.0
+ STF32 -> realToFrac @Float @Double (x * y)
+ STF64 -> x * y
+ STBool -> 0.0
+dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+
+-- -- Primal expression must be duplicable
+-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
+-- dnConstE STNil _ = ENil ext
+-- dnConstE (STPair t1 t2) e =
+-- -- This creates fst/snd stacks of unbounded size, but let's not care here
+-- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e))
+-- dnConstE (STEither t1 t2) e =
+-- ECase ext e
+-- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ)))
+-- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ)))
+-- dnConstE (STMaybe t) e =
+-- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e
+-- dnConstE (STArr n t) e =
+-- EBuild ext n (EShape ext e)
+-- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ)))
+-- dnConstE (STScal t) e = case t of
+-- STI32 -> e
+-- STI64 -> e
+-- STF32 -> EPair ext e (EConst ext STF32 0.0)
+-- STF64 -> EPair ext e (EConst ext STF64 0.0)
+-- STBool -> e
+-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConst :: STy t -> Rep t -> Rep (DN t)
+dnConst STNil = const ()
+dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2)
+dnConst (STMaybe t) = fmap (dnConst t)
+dnConst (STArr _ t) = arrayMap (dnConst t)
+dnConst (STScal t) = case t of
+ STI32 -> id
+ STI64 -> id
+ STF32 -> (,0.0)
+ STF64 -> (,0.0)
+ STBool -> id
+dnConst STAccum{} = error "Accumulators not allowed in input program"
+
+-- | Given a function that computes the forward derivative for a particular
+-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
+-- @t@ input.
+type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t)
+
+dnOnehots :: STy t -> Rep t -> RevByFwd t
+dnOnehots STNil _ = \_ -> ()
+dnOnehots (STPair t1 t2) (x, y) =
+ \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,)))
+dnOnehots (STEither t1 t2) e =
+ case e of
+ Left x -> \f -> Left (dnOnehots t1 x (f . Left))
+ Right y -> \f -> Right (dnOnehots t2 y (f . Right))
+dnOnehots (STMaybe t) m =
+ case m of
+ Nothing -> \_ -> Nothing
+ Just x -> \f -> Just (dnOnehots t x (f . Just))
+dnOnehots (STArr _ t) a =
+ \f ->
+ arrayGenerate (arrayShape a) $ \idx ->
+ dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i ->
+ if i == idx then oh else dnConst t (arrayIndex a i)))
+dnOnehots (STScal t) x = case t of
+ STI32 -> \_ -> ()
+ STI64 -> \_ -> ()
+ STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0)
+ STF64 -> \f -> f (x, 1.0)
+ STBool -> \_ -> ()
+dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+
+dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
+dnConstEnv SNil SNil = SNil
+dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val
+
+type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env)
+
+dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env
+dnOnehotEnvs SNil SNil = \_ -> SNil
+dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) =
+ \f ->
+ Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val)))
+ `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh))
+
+drevByFwd :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
+drevByFwd env expr input dres =
+ let outty = typeOf expr
+ in dnOnehotEnvs env input $ \dnInput ->
+ let (_, outtan) = unzipDN outty (interpretOpen dnInput (dfwdDN expr))
+ in dotprodTan outty outtan dres
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
new file mode 100644
index 0000000..f9239e9
--- /dev/null
+++ b/src/ForwardAD/DualNumbers.hs
@@ -0,0 +1,206 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module ForwardAD.DualNumbers (
+ dfwdDN,
+ DN, DNS, DNE, dn, dne,
+) where
+
+import AST
+import Data
+
+
+-- | Dual-numbers transformation
+type family DN t where
+ DN TNil = TNil
+ DN (TPair a b) = TPair (DN a) (DN b)
+ DN (TEither a b) = TEither (DN a) (DN b)
+ DN (TMaybe t) = TMaybe (DN t)
+ DN (TArr n t) = TArr n (DN t)
+ DN (TScal t) = DNS t
+
+type family DNS t where
+ DNS TF32 = TPair (TScal TF32) (TScal TF32)
+ DNS TF64 = TPair (TScal TF64) (TScal TF64)
+ DNS TI32 = TScal TI32
+ DNS TI64 = TScal TI64
+ DNS TBool = TScal TBool
+
+type family DNE env where
+ DNE '[] = '[]
+ DNE (t : ts) = DN t : DNE ts
+
+dn :: STy t -> STy (DN t)
+dn STNil = STNil
+dn (STPair a b) = STPair (dn a) (dn b)
+dn (STEither a b) = STEither (dn a) (dn b)
+dn (STMaybe t) = STMaybe (dn t)
+dn (STArr n t) = STArr n (dn t)
+dn (STScal t) = case t of
+ STF32 -> STPair (STScal STF32) (STScal STF32)
+ STF64 -> STPair (STScal STF64) (STScal STF64)
+ STI32 -> STScal STI32
+ STI64 -> STScal STI64
+ STBool -> STScal STBool
+dn STAccum{} = error "Accum in source program"
+
+dne :: SList STy env -> SList STy (DNE env)
+dne SNil = SNil
+dne (t `SCons` env) = dn t `SCons` dne env
+
+dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx)
+dnPreservesTupIx SZ = Refl
+dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl
+
+convIdx :: Idx env t -> Idx (DNE env) (DN t)
+convIdx IZ = IZ
+convIdx (IS i) = IS (convIdx i)
+
+scalTyCase :: SScalTy t
+ -> ((ScalIsNumeric t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r)
+ -> (DN (TScal t) ~ TScal t => r)
+ -> r
+scalTyCase STF32 k1 _ = k1
+scalTyCase STF64 k1 _ = k1
+scalTyCase STI32 _ k2 = k2
+scalTyCase STI64 _ k2 = k2
+scalTyCase STBool _ k2 = k2
+
+-- | Argument does not need to be duplicable.
+dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b)
+dop = \case
+ OAdd t -> scalTyCase t
+ (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy)))
+ (EOp ext (OAdd t))
+ OMul t -> scalTyCase t
+ (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x))))
+ (EOp ext (OMul t))
+ ONeg t -> scalTyCase t
+ (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx)))
+ (EOp ext (ONeg t))
+ OLt t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y)))
+ (EOp ext (OLt t))
+ OLe t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y)))
+ (EOp ext (OLe t))
+ OEq t -> scalTyCase t
+ (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y)))
+ (EOp ext (OEq t))
+ ONot -> EOp ext ONot
+ OIf -> EOp ext OIf
+ where
+ add :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
+ add t a b = EOp ext (OAdd t) (EPair ext a b)
+
+ mul :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
+ mul t a b = EOp ext (OMul t) (EPair ext a b)
+
+ neg :: ScalIsNumeric t ~ True
+ => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t)
+ neg t = EOp ext (ONeg t)
+
+ unFloat :: DN a ~ TPair a a
+ => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b))
+ -> Ex env (DN a) -> Ex env (DN b)
+ unFloat f e =
+ ELet ext e $
+ let var = EVar ext (typeOf e) IZ
+ in f (EFst ext var, ESnd ext var)
+
+ binFloat :: (a ~ TPair s s, DN s ~ TPair s s)
+ => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b))
+ -> Ex env (DN a) -> Ex env (DN b)
+ binFloat f e =
+ ELet ext e $
+ let var = EVar ext (typeOf e) IZ
+ in f (EFst ext (EFst ext var), ESnd ext (EFst ext var))
+ (EFst ext (ESnd ext var), ESnd ext (ESnd ext var))
+
+dfwdDN :: Ex env t -> Ex (DNE env) (DN t)
+dfwdDN = \case
+ EVar _ t i -> EVar ext (dn t) (convIdx i)
+ ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b)
+ EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b)
+ EFst _ e -> EFst ext (dfwdDN e)
+ ESnd _ e -> ESnd ext (dfwdDN e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext (dn t) (dfwdDN e)
+ EInr _ t e -> EInr ext (dn t) (dfwdDN e)
+ ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
+ ENothing _ t -> ENothing ext (dn t)
+ EJust _ e -> EJust ext (dfwdDN e)
+ EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b)
+ EConstArr _ n t x -> scalTyCase t
+ (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
+ (EConstArr ext n t x))
+ (EConstArr ext n t x)
+ EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b)
+ EBuild _ n a b
+ | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
+ EFold1Inner _ a b -> EFold1Inner ext (dfwdDN a) (dfwdDN b)
+ ESum1Inner _ e ->
+ let STArr n (STScal t) = typeOf e
+ pairty = (STPair (STScal t) (STScal t))
+ in scalTyCase t
+ (ELet ext (dfwdDN e) $
+ ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ))
+ (EVar ext (STArr n pairty) IZ)))
+ (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ))
+ (EVar ext (STArr n pairty) IZ))))
+ (ESum1Inner ext (dfwdDN e))
+ EUnit _ e -> EUnit ext (dfwdDN e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
+ EConst _ t x -> scalTyCase t
+ (EPair ext (EConst ext t x) (EConst ext t 0.0))
+ (EConst ext t x)
+ EIdx0 _ e -> EIdx0 ext (dfwdDN e)
+ EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
+ EIdx _ n a b
+ | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b)
+ EShape _ e
+ | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e)
+ EOp _ op e -> dop op (dfwdDN e)
+ EError t s -> EError (dn t) s
+
+ EWith{} -> err_accum
+ EAccum{} -> err_accum
+ EZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ where
+ err_accum = error "Accumulator operations unsupported in the source program"
+ err_monoid = error "Monoid operations unsupported in the source program"
+
+emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
+emap f arr =
+ let STArr n t = typeOf arr
+ in ELet ext arr $
+ EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
+ ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) f
+
+ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
+ezip a b =
+ let STArr n t1 = typeOf a
+ STArr _ t2 = typeOf b
+ in ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
+ EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (EIdx ext n (EVar ext (STArr n t2) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 316a423..4d1358f 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -19,17 +19,21 @@ module Interpreter (
) where
import Control.Monad (foldM, join)
+import Data.Char (isSpace)
import Data.Kind (Type)
import Data.Int (Int64)
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
+import Debug.Trace
+
import Array
import AST
import CHAD.Types
import Data
import Interpreter.Rep
import Data.Bifunctor (bimap)
+import GHC.Stack (HasCallStack)
newtype AcM s a = AcM { unAcM :: IO a }
@@ -41,17 +45,16 @@ runAcM (AcM m) = unsafePerformIO m
interpret :: Ex '[] t -> Rep t
interpret = interpretOpen SNil
-newtype Value t = Value (Rep t)
-
interpretOpen :: SList Value env -> Ex env t -> Rep t
interpretOpen env e = runAcM (interpret' env e)
-interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
+interpret' :: forall env t s. HasCallStack => SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
EVar _ _ i -> case slistIdx env i of Value x -> return x
ELet _ a b -> do
x <- interpret' env a
interpret' (Value x `SCons` env) b
+ expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined
EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
EFst _ e -> fst <$> interpret' env e
ESnd _ e -> snd <$> interpret' env e
@@ -232,13 +235,6 @@ instance Shapey Shape where
shapeyCase ShNil k0 _ = k0
shapeyCase (ShCons sh n) _ k1 = k1 sh n
-enumInvShape :: InvShape n -> [InvIndex n]
-enumInvShape IShNil = [IIxNil]
-enumInvShape (n `IShCons` sh) = [i `IIxCons` ix | i <- [0 .. n - 1], ix <- enumInvShape sh]
-
-enumShape :: Shape n -> [Index n]
-enumShape = map uninvert . enumInvShape . invert
-
invert :: forall f n. Shapey f => f n -> Inverted f n
invert | Refl <- lemPlusZero @n = flip go InvNil
where
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index c0c38b2..adb4eba 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -38,3 +38,6 @@ type family RepAcDense t where
-- RepAcDense (TArr n t) = Array n (RepAcSparse t)
-- RepAcDense (TScal sty) = ScalRep sty
-- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators")
+
+newtype Value t = Value (Rep t)
+