summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs4
-rw-r--r--src/AST/Pretty.hs2
-rw-r--r--src/CHAD.hs4
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs2
-rw-r--r--src/Interpreter/Rep.hs13
-rw-r--r--src/Language.hs3
-rw-r--r--test/Main.hs122
8 files changed, 110 insertions, 42 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 5dab62f..94c8537 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -150,6 +150,8 @@ data SOp a t where
OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
ONot :: SOp (TScal TBool) (TScal TBool)
OIf :: SOp (TScal TBool) (TEither TNil TNil)
+ ORound64 :: SOp (TScal TF64) (TScal TI64)
+ OToFl64 :: SOp (TScal TI64) (TScal TF64)
deriving instance Show (SOp a t)
opt2 :: SOp a t -> STy t
@@ -162,6 +164,8 @@ opt2 = \case
OEq _ -> STScal STBool
ONot -> STScal STBool
OIf -> STEither STNil STNil
+ ORound64 -> STScal STI64
+ OToFl64 -> STScal STF64
typeOf :: Expr x env t -> STy t
typeOf = \case
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 7f60db1..8f1fe67 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -243,3 +243,5 @@ operator OLe{} = (Infix, "<=")
operator OEq{} = (Infix, "==")
operator ONot = (Prefix, "not")
operator OIf = (Prefix, "ifB")
+operator ORound64 = (Prefix, "round")
+operator OToFl64 = (Prefix, "toFl64")
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 55d94b1..d05e77f 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -546,6 +546,8 @@ d1op (OLe t) e = EOp ext (OLe t) e
d1op (OEq t) e = EOp ext (OEq t) e
d1op ONot e = EOp ext ONot e
d1op OIf e = EOp ext OIf e
+d1op ORound64 e = EOp ext ORound64 e
+d1op OToFl64 e = EOp ext OToFl64 e
-- | Both primal and dual must be duplicable expressions
data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
@@ -563,6 +565,8 @@ d2op op = case op of
OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
ONot -> Linear $ \_ -> ENil ext
OIf -> Linear $ \_ -> ENil ext
+ ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
+ OToFl64 -> Linear $ \_ -> ENil ext
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index f9239e9..3e45ce7 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -100,6 +100,8 @@ dop = \case
(EOp ext (OEq t))
ONot -> EOp ext ONot
OIf -> EOp ext OIf
+ ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg)
+ OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0)
where
add :: ScalIsNumeric t ~ True
=> SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 4d1358f..8ce1b0e 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -125,6 +125,8 @@ interpretOp op arg = case op of
OEq st -> numericIsNum st $ uncurry (==) arg
ONot -> not arg
OIf -> if arg then Left () else Right ()
+ ORound64 -> round arg
+ OToFl64 -> fromIntegral arg
zeroD2 :: STy t -> Rep (D2 t)
zeroD2 typ = case typ of
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index adb4eba..ed307c0 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -39,5 +39,16 @@ type family RepAcDense t where
-- RepAcDense (TScal sty) = ScalRep sty
-- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators")
-newtype Value t = Value (Rep t)
+newtype Value t = Value { unValue :: Rep t }
+liftV :: (Rep a -> Rep b) -> Value a -> Value b
+liftV f (Value x) = Value (f x)
+
+liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
+liftV2 f (Value x) (Value y) = Value (f x y)
+
+vPair :: Value a -> Value b -> Value (TPair a b)
+vPair = liftV2 (,)
+
+vUnpair :: Value (TPair a b) -> (Value a, Value b)
+vUnpair (Value (x, y)) = (Value x, Value y)
diff --git a/src/Language.hs b/src/Language.hs
index cdc6d6b..80de713 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE TypeOperators #-}
module Language (
@@ -22,7 +23,7 @@ infixr 0 :->
body :: NExpr env t -> NFun env env t
body = NBody
-lambda :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
+lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t
lambda = NLam
diff --git a/test/Main.hs b/test/Main.hs
index 986c8a0..d90d9cd 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -1,20 +1,21 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE GADTs #-}
+-- {-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
+-- {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Main where
import Data.Bifunctor
-import qualified Data.Dependent.Map as DMap
-import Data.Dependent.Map (DMap)
-import Data.List (intercalate)
+-- import qualified Data.Dependent.Map as DMap
+-- import Data.Dependent.Map (DMap)
+import Data.Foldable (toList)
+import Data.List (intercalate, intersperse)
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
@@ -52,7 +53,7 @@ gradientByCHAD = \env term input ->
dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0)
input1 = toPrimalE env input
(_out, grad) = interpretOpen input1 dterm
- in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad)
+ in unTup vUnpair (d2e env) (Value grad)
where
makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env')
makeMergeDescr SNil = DTop
@@ -127,17 +128,18 @@ genShape = \n -> do
shapeDiv ShNil _ = ShNil
shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f)
+genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
+genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t)
+
genValue :: STy a -> Gen (Value a)
genValue = \case
STNil -> return (Value ())
- STPair a b -> lv2 (,) <$> genValue a <*> genValue b
- STEither a b -> Gen.choice [lv1 Left <$> genValue a
- ,lv1 Right <$> genValue b]
+ STPair a b -> liftV2 (,) <$> genValue a <*> genValue b
+ STEither a b -> Gen.choice [liftV Left <$> genValue a
+ ,liftV Right <$> genValue b]
STMaybe t -> Gen.choice [return (Value Nothing)
- ,lv1 Just <$> genValue t]
- STArr n t -> do
- sh <- genShape n
- Value <$> arrayGenerateLinM sh (\_ -> (\(Value x) -> x) <$> genValue t)
+ ,liftV Just <$> genValue t]
+ STArr n t -> genShape n >>= genArray t
STScal sty -> case sty of
STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
@@ -145,39 +147,33 @@ genValue = \case
STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10)
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
- where
- lv1 :: (Rep a -> Rep b) -> Value a -> Value b
- lv1 f (Value x) = Value (f x)
-
- lv2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c
- lv2 f (Value x) (Value y) = Value (f x y)
genEnv :: SList STy env -> Gen (SList Value env)
genEnv SNil = return SNil
genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env
-data TemplateVar n = TemplateVar (SNat n) String
- deriving (Show)
+-- data TemplateVar n = TemplateVar (SNat n) String
+-- deriving (Show)
-data Template t where
- TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
- TpAny :: STy t -> Template t
- TpPair :: Template a -> Template b -> Template (TPair a b)
-deriving instance Show (Template t)
+-- data Template t where
+-- TpShape :: TemplateVar n -> STy t -> Template (TArr n t)
+-- TpAny :: STy t -> Template t
+-- TpPair :: Template a -> Template b -> Template (TPair a b)
+-- deriving instance Show (Template t)
-data ShapeConstraint n = ShapeAtLeast (Shape n)
- deriving (Show)
+-- data ShapeConstraint n = ShapeAtLeast (Shape n)
+-- deriving (Show)
-genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
-genTemplate = _
+-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t)
+-- genTemplate = _
-genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
-genEnvTemplateExact shapes env = _
+-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env)
+-- genEnvTemplateExact shapes env = _
-genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
-genEnvTemplate constrs env = do
- shapes <- DMap.traverseWithKey _ constrs
- genEnvTemplateExact shapes env
+-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env)
+-- genEnvTemplate constrs env = do
+-- shapes <- DMap.traverseWithKey _ constrs
+-- genEnvTemplateExact shapes env
showValue :: Int -> STy t -> Rep t -> ShowS
showValue _ STNil () = showString "()"
@@ -186,7 +182,11 @@ showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " .
showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y
showValue _ (STMaybe _) Nothing = showString "Nothing"
showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x
-showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr) -- TODO: improve
+showValue d (STArr _ t) arr = showParen (d > 10) $
+ showString "arrayFromList " . showsPrec 11 (arrayShape arr)
+ . showString " ["
+ . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
+ . showString "]"
showValue _ (STScal sty) x = case sty of
STF32 -> shows x
STF64 -> shows x
@@ -203,9 +203,18 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property
-adTest expr = property $ do
+adTest = flip adTestGen (genEnv (knownEnv @env))
+
+-- adTestTp :: forall env. KnownEnv env
+-- => DMap TemplateVar ShapeConstraint -> SList Template env
+-- -> Ex env (TScal TF64) -> Property
+-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp)
+
+adTestGen :: forall env. KnownEnv env
+ => Ex env (TScal TF64) -> Gen (SList Value env) -> Property
+adTestGen expr envGenerator = property $ do
let env = knownEnv @env
- input <- forAllWith (showEnv env) $ genEnv env
+ input <- forAllWith (showEnv env) envGenerator
let gradFwd = gradientByForward knownEnv expr input
gradCHAD = gradientByCHAD' knownEnv expr input
scFwd = envScalars env gradFwd
@@ -219,7 +228,40 @@ adTest expr = property $ do
tests :: IO Bool
tests = checkParallel $ Group "AD"
[("id", adTest $ fromNamed $ lambda #x $ body $ #x)
- ,("neural", adTest Example.neural)]
+
+ ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x))
+
+ ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
+ idx0 $
+ build SZ (shape #x) $ #idx :-> #x ! #idx)
+
+ ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $
+ idx0 $ sum1i $
+ build (SS SZ) (shape #x) $ #idx :-> #x ! #idx)
+
+ ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $
+ idx0 $ sum1i . sum1i $
+ build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx)
+
+ -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $
+ -- idx0 $ sum1i . sum1i $
+ -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :->
+ -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx)
+
+ -- ,("neural", adTestGen Example.neural $ do
+ -- let tR = STScal STF64
+ -- let genLayer nin nout =
+ -- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
+ -- <*> genArray tR (ShNil `ShCons` nout)
+ -- nin <- Gen.integral (Range.linear 1 10)
+ -- n1 <- Gen.integral (Range.linear 1 10)
+ -- n2 <- Gen.integral (Range.linear 1 10)
+ -- input <- genArray tR (ShNil `ShCons` nin)
+ -- lay1 <- genLayer nin n1
+ -- lay2 <- genLayer n1 n2
+ -- lay3 <- genArray tR (ShNil `ShCons` n2)
+ -- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
+ ]
main :: IO ()
main = defaultMain [tests]