diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/Main.hs | 99 | ||||
-rw-r--r-- | test/example/Main.hs | 7 |
2 files changed, 99 insertions, 7 deletions
diff --git a/test/Main.hs b/test/Main.hs new file mode 100644 index 0000000..39415bb --- /dev/null +++ b/test/Main.hs @@ -0,0 +1,99 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE LambdaCase #-} +module Main where + +import Data.Bifunctor + +import Array +import AST +import CHAD +import CHAD.Types +import Data +import ForwardAD +import Interpreter +import Interpreter.Rep + + +type family MapMerge env where + MapMerge '[] = '[] + MapMerge (t : ts) = "merge" : MapMerge ts + +mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[] +mapMergeNoAccum SNil = Refl +mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl + +mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env +mapMergeOnlyMerge SNil = Refl +mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl + +gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env) +gradientByCHAD = \env term input -> + case (mapMergeNoAccum env, mapMergeOnlyMerge env) of + (Refl, Refl) -> + let descr = makeMergeDescr env + dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0) + input1 = toPrimalE env input + (_out, grad) = interpretOpen input1 dterm + in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad) + where + makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') + makeMergeDescr SNil = DTop + makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge) + + toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env') + toPrimalE SNil SNil = SNil + toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp + + toPrimal :: STy t -> Rep t -> Rep (D1 t) + toPrimal = \case + STNil -> id + STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2) + STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2) + STMaybe t -> fmap (toPrimal t) + STArr _ t -> fmap (toPrimal t) + STScal _ -> id + STAccum{} -> error "Accumulators not allowed in input program" + +gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input) + where + toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) + toTanE SNil SNil SNil = SNil + toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = + Value (toTan t p x) `SCons` toTanE env primal inp + + toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) + toTan typ primal der = case typ of + STNil -> der + STPair t1 t2 -> case der of + Left () -> bimap (zeroTan t1) (zeroTan t2) primal + Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STEither t1 t2 -> case der of + Left () -> bimap (zeroTan t1) (zeroTan t2) primal + Right d -> case (primal, d) of + (Left p, Left d') -> Left (toTan t1 p d') + (Right p, Right d') -> Right (toTan t2 p d') + _ -> error "Primal and cotangent disagree on Either alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | shapeSize (arrayShape der) == 0 -> + arrayMap (zeroTan t) primal + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" + +gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env) +gradientByForward env term input = drevByFwd env term input 1.0 + +main :: IO () +main = return () diff --git a/test/example/Main.hs b/test/example/Main.hs deleted file mode 100644 index 6c36857..0000000 --- a/test/example/Main.hs +++ /dev/null @@ -1,7 +0,0 @@ -module Main where - -import Example - - -main :: IO () -main = print neuralGo |