diff options
-rw-r--r-- | chad-fast.cabal | 9 | ||||
-rw-r--r-- | example/Main.hs (renamed from test/example/Main.hs) | 0 | ||||
-rw-r--r-- | src/AST.hs | 7 | ||||
-rw-r--r-- | src/CHAD.hs | 5 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 5 | ||||
-rw-r--r-- | src/ForwardAD.hs | 15 | ||||
-rw-r--r-- | test/Main.hs | 99 |
7 files changed, 135 insertions, 5 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index ef9f642..ad949c6 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -49,7 +49,14 @@ library test-suite example type: exitcode-stdio-1.0 - main-is: test/example/Main.hs + main-is: example/Main.hs + build-depends: base, chad-fast + default-language: Haskell2010 + ghc-options: -Wall -threaded + +test-suite test + type: exitcode-stdio-1.0 + main-is: test/Main.hs build-depends: base, chad-fast default-language: Haskell2010 ghc-options: -Wall -threaded diff --git a/test/example/Main.hs b/example/Main.hs index 6c36857..6c36857 100644 --- a/test/example/Main.hs +++ b/example/Main.hs @@ -129,6 +129,13 @@ tTup = mkTup STNil STPair eTup :: SList (Ex env) list -> Ex env (Tup list) eTup = mkTup (ENil ext) (EPair ext) +unTup :: (forall a b. c (TPair a b) -> (c a, c b)) + -> SList f list -> c (Tup list) -> SList c list +unTup _ SNil _ = SNil +unTup unpack (_ `SCons` list) tup = + let (xs, x) = unpack tup + in x `SCons` unTup unpack list xs + type family InvTup core env where InvTup core '[] = core InvTup core (t : ts) = InvTup (TPair core t) ts diff --git a/src/CHAD.hs b/src/CHAD.hs index bcc1485..55d94b1 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -25,6 +25,7 @@ module CHAD ( freezeRet, Storage(..), Descr(..), + Select, ) where import Data.Bifunctor (first, second) @@ -630,10 +631,6 @@ sD1eEnv :: Descr env sto -> SList STy (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) -d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (SCons t ts) = SCons (d2 t) (d2e ts) - d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index 0b32393..0b73a3a 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -4,6 +4,7 @@ module CHAD.Types where import AST.Types +import Data type family D1 t where @@ -63,3 +64,7 @@ d2 (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil d2 STAccum{} = error "Accumulators not allowed in input program" + +d2e :: SList STy env -> SList STy (D2E env) +d2e SNil = SNil +d2e (SCons t ts) = SCons (d2 t) (d2e ts) diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index 63244a8..6d53b48 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -52,6 +52,21 @@ tanty (STScal t) = case t of STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +zeroTan :: STy t -> Rep t -> Rep (Tan t) +zeroTan STNil () = () +zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) +zeroTan (STEither a _) (Left x) = Left (zeroTan a x) +zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STMaybe _) Nothing = Nothing +zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) +zeroTan (STArr _ t) x = fmap (zeroTan t) x +zeroTan (STScal STI32) _ = () +zeroTan (STScal STI64) _ = () +zeroTan (STScal STF32) _ = 0.0 +zeroTan (STScal STF64) _ = 0.0 +zeroTan (STScal STBool) _ = () +zeroTan STAccum{} _ = error "Accumulators not allowed in input program" + unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) unzipDN STNil _ = ((), ()) unzipDN (STPair a b) (d1, d2) = diff --git a/test/Main.hs b/test/Main.hs new file mode 100644 index 0000000..39415bb --- /dev/null +++ b/test/Main.hs @@ -0,0 +1,99 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE LambdaCase #-} +module Main where + +import Data.Bifunctor + +import Array +import AST +import CHAD +import CHAD.Types +import Data +import ForwardAD +import Interpreter +import Interpreter.Rep + + +type family MapMerge env where + MapMerge '[] = '[] + MapMerge (t : ts) = "merge" : MapMerge ts + +mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[] +mapMergeNoAccum SNil = Refl +mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl + +mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env +mapMergeOnlyMerge SNil = Refl +mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl + +gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env) +gradientByCHAD = \env term input -> + case (mapMergeNoAccum env, mapMergeOnlyMerge env) of + (Refl, Refl) -> + let descr = makeMergeDescr env + dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0) + input1 = toPrimalE env input + (_out, grad) = interpretOpen input1 dterm + in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad) + where + makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') + makeMergeDescr SNil = DTop + makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) + + toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') + toPrimalE SNil SNil = SNil + toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp + + toPrimal :: STy t -> Rep t -> Rep (D1 t) + toPrimal = \case + STNil -> id + STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2) + STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2) + STMaybe t -> fmap (toPrimal t) + STArr _ t -> fmap (toPrimal t) + STScal _ -> id + STAccum{} -> error "Accumulators not allowed in input program" + +gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input) + where + toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) + toTanE SNil SNil SNil = SNil + toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = + Value (toTan t p x) `SCons` toTanE env primal inp + + toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) + toTan typ primal der = case typ of + STNil -> der + STPair t1 t2 -> case der of + Left () -> bimap (zeroTan t1) (zeroTan t2) primal + Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STEither t1 t2 -> case der of + Left () -> bimap (zeroTan t1) (zeroTan t2) primal + Right d -> case (primal, d) of + (Left p, Left d') -> Left (toTan t1 p d') + (Right p, Right d') -> Right (toTan t2 p d') + _ -> error "Primal and cotangent disagree on Either alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | shapeSize (arrayShape der) == 0 -> + arrayMap (zeroTan t) primal + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" + +gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByForward env term input = drevByFwd env term input 1.0 + +main :: IO () +main = return () |