summaryrefslogtreecommitdiff
path: root/test/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'test/Main.hs')
-rw-r--r--test/Main.hs99
1 files changed, 99 insertions, 0 deletions
diff --git a/test/Main.hs b/test/Main.hs
new file mode 100644
index 0000000..39415bb
--- /dev/null
+++ b/test/Main.hs
@@ -0,0 +1,99 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE LambdaCase #-}
+module Main where
+
+import Data.Bifunctor
+
+import Array
+import AST
+import CHAD
+import CHAD.Types
+import Data
+import ForwardAD
+import Interpreter
+import Interpreter.Rep
+
+
+type family MapMerge env where
+ MapMerge '[] = '[]
+ MapMerge (t : ts) = "merge" : MapMerge ts
+
+mapMergeNoAccum :: SList f env -> Select env (MapMerge env) "accum" :~: '[]
+mapMergeNoAccum SNil = Refl
+mapMergeNoAccum (_ `SCons` env) | Refl <- mapMergeNoAccum env = Refl
+
+mapMergeOnlyMerge :: SList f env -> Select env (MapMerge env) "merge" :~: env
+mapMergeOnlyMerge SNil = Refl
+mapMergeOnlyMerge (_ `SCons` env) | Refl <- mapMergeOnlyMerge env = Refl
+
+gradientByCHAD :: forall env. SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (D2E env)
+gradientByCHAD = \env term input ->
+ case (mapMergeNoAccum env, mapMergeOnlyMerge env) of
+ (Refl, Refl) ->
+ let descr = makeMergeDescr env
+ dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
+ input1 = toPrimalE env input
+ (_out, grad) = interpretOpen input1 dterm
+ in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
+ where
+ makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
+ makeMergeDescr SNil = DTop
+ makeMergeDescr (t `SCons` env) = makeMergeDescr env `DPush` (t, SMerge)
+
+ toPrimalE :: SList STy env' -> SList Value env' -> SList Value (D1E env')
+ toPrimalE SNil SNil = SNil
+ toPrimalE (t `SCons` env) (Value x `SCons` inp) = Value (toPrimal t x) `SCons` toPrimalE env inp
+
+ toPrimal :: STy t -> Rep t -> Rep (D1 t)
+ toPrimal = \case
+ STNil -> id
+ STPair t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
+ STEither t1 t2 -> bimap (toPrimal t1) (toPrimal t2)
+ STMaybe t -> fmap (toPrimal t)
+ STArr _ t -> fmap (toPrimal t)
+ STScal _ -> id
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+gradientByCHAD' :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByCHAD' = \env term input -> toTanE env input (gradientByCHAD env term input)
+ where
+ toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env)
+ toTanE SNil SNil SNil = SNil
+ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) =
+ Value (toTan t p x) `SCons` toTanE env primal inp
+
+ toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t)
+ toTan typ primal der = case typ of
+ STNil -> der
+ STPair t1 t2 -> case der of
+ Left () -> bimap (zeroTan t1) (zeroTan t2) primal
+ Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal
+ STEither t1 t2 -> case der of
+ Left () -> bimap (zeroTan t1) (zeroTan t2) primal
+ Right d -> case (primal, d) of
+ (Left p, Left d') -> Left (toTan t1 p d')
+ (Right p, Right d') -> Right (toTan t2 p d')
+ _ -> error "Primal and cotangent disagree on Either alternative"
+ STMaybe t -> liftA2 (toTan t) primal der
+ STArr _ t
+ | shapeSize (arrayShape der) == 0 ->
+ arrayMap (zeroTan t) primal
+ | arrayShape primal == arrayShape der ->
+ arrayGenerateLin (arrayShape primal) $ \i ->
+ toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i)
+ | otherwise ->
+ error "Primal and cotangent disagree on array shape"
+ STScal sty -> case sty of
+ STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+gradientByForward :: SList STy env -> Ex env (TScal TF64) -> SList Value env -> SList Value (TanE env)
+gradientByForward env term input = drevByFwd env term input 1.0
+
+main :: IO ()
+main = return ()