summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal9
-rw-r--r--example/Main.hs (renamed from test/example/Main.hs)0
-rw-r--r--src/AST.hs7
-rw-r--r--src/CHAD.hs5
-rw-r--r--src/CHAD/Types.hs5
-rw-r--r--src/ForwardAD.hs15
-rw-r--r--test/Main.hs99
7 files changed, 135 insertions, 5 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ef9f642..ad949c6 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -49,7 +49,14 @@ library
test-suite example
type: exitcode-stdio-1.0
- main-is: test/example/Main.hs
+ main-is: example/Main.hs
+ build-depends: base, chad-fast
+ default-language: Haskell2010
+ ghc-options: -Wall -threaded
+
+test-suite test
+ type: exitcode-stdio-1.0
+ main-is: test/Main.hs
build-depends: base, chad-fast
default-language: Haskell2010
ghc-options: -Wall -threaded
diff --git a/test/example/Main.hs b/example/Main.hs
index 6c36857..6c36857 100644
--- a/test/example/Main.hs
+++ b/example/Main.hs
diff --git a/src/AST.hs b/src/AST.hs
index 8dfea68..5dab62f 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -129,6 +129,13 @@ tTup = mkTup STNil STPair
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)
+unTup :: (forall a b. c (TPair a b) -> (c a, c b))
+ -> SList f list -> c (Tup list) -> SList c list
+unTup _ SNil _ = SNil
+unTup unpack (_ `SCons` list) tup =
+ let (xs, x) = unpack tup
+ in x `SCons` unTup unpack list xs
+
type family InvTup core env where
InvTup core '[] = core
InvTup core (t : ts) = InvTup (TPair core t) ts
diff --git a/src/CHAD.hs b/src/CHAD.hs
index bcc1485..55d94b1 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -25,6 +25,7 @@ module CHAD (
freezeRet,
Storage(..),
Descr(..),
+ Select,
) where
import Data.Bifunctor (first, second)
@@ -630,10 +631,6 @@ sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
-d2e :: SList STy env -> SList STy (D2E env)
-d2e SNil = SNil
-d2e (SCons t ts) = SCons (d2 t) (d2e ts)
-
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 0b32393..0b73a3a 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -4,6 +4,7 @@
module CHAD.Types where
import AST.Types
+import Data
type family D1 t where
@@ -63,3 +64,7 @@ d2 (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
d2 STAccum{} = error "Accumulators not allowed in input program"
+
+d2e :: SList STy env -> SList STy (D2E env)
+d2e SNil = SNil
+d2e (SCons t ts) = SCons (d2 t) (d2e ts)
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index 63244a8..6d53b48 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -52,6 +52,21 @@ tanty (STScal t) = case t of
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+zeroTan :: STy t -> Rep t -> Rep (Tan t)
+zeroTan STNil () = ()
+zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y)
+zeroTan (STEither a _) (Left x) = Left (zeroTan a x)
+zeroTan (STEither _ b) (Right y) = Right (zeroTan b y)
+zeroTan (STMaybe _) Nothing = Nothing
+zeroTan (STMaybe t) (Just x) = Just (zeroTan t x)
+zeroTan (STArr _ t) x = fmap (zeroTan t) x
+zeroTan (STScal STI32) _ = ()
+zeroTan (STScal STI64) _ = ()
+zeroTan (STScal STF32) _ = 0.0
+zeroTan (STScal STF64) _ = 0.0
+zeroTan (STScal STBool) _ = ()
+zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+
unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t))
unzipDN STNil _ = ((), ())
unzipDN (STPair a b) (d1, d2) =
diff --git a/test/Main.hs b/test/Main.hs
new file mode 100644
index 0000000..39415bb
--- /dev/null
+++ b/test/Main.hs
@@ -0,0 +1,99 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE LambdaCase #-}
+module Main where
+
+import Data.Bifunctor
+
+import Array
+import AST
+import CHAD
+import CHAD.Types
+import Data
+import ForwardAD
+import Interpreter
+import Interpreter.Rep
+
+
+type family MapMerge env where
+ MapMerge '[] = '[]
+ MapMerge (t : ts) = "merge" : MapMerge ts
+
+mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
+mapMergeNoAccum SNil = Refl
+mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl
+
+mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
+mapMergeOnlyMerge SNil = Refl
+mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl
+
+gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
+gradientByCHAD = \env term input ->
+ case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
+ (Refl, Refl) ->
+ let descr = makeMergeDescr env
+ dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
+ input1 = toPrimalE env input
+ (_out, grad) = interpretOpen input1 dterm
+ in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
+ where
+ makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
+ makeMergeDescr SNil = DTop
+ makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)
+
+ toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
+ toPrimalE SNil SNil = SNil
+ toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp
+
+ toPrimal :: STy t -> Rep t -> Rep (D1 t)
+ toPrimal = \case
+ STNil -> id
+ STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
+ STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
+ STMaybe t -> fmap (toPrimal t)
+ STArr _ t -> fmap (toPrimal t)
+ STScal _ -> id
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
+ where
+ toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
+ toTanE SNil SNil SNil = SNil
+ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
+ Value (toTan t p x) `SCons` toTanE env primal inp
+
+ toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
+ toTan typ primal der = case typ of
+ STNil -> der
+ STPair t1 t2 -> case der of
+ Left () -> bimap (zeroTan t1) (zeroTan t2) primal
+ Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STEither t1 t2 -> case der of
+ Left () -> bimap (zeroTan t1) (zeroTan t2) primal
+ Right d -> case (primal, d) of
+ (Left p, Left d') -> Left (toTan t1 p d')
+ (Right p, Right d') -> Right (toTan t2 p d')
+ _ -> error "Primal and cotangent disagree on Either alternative"
+ STMaybe t -> liftA2 (toTan t) primal der
+ STArr _ t
+ | shapeSize (arrayShape der) == 0 ->
+ arrayMap (zeroTan t) primal
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
+ STScal sty -> case sty of
+ STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByForward env term input = drevByFwd env term input 1.0
+
+main :: IO ()
+main = return ()