diff options
-rw-r--r-- | src/AST.hs | 4 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 2 | ||||
-rw-r--r-- | src/CHAD.hs | 4 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 2 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 13 | ||||
-rw-r--r-- | src/Language.hs | 3 | ||||
-rw-r--r-- | test/Main.hs | 122 |
8 files changed, 110 insertions, 42 deletions
@@ -150,6 +150,8 @@ data SOp a t where OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) ONot :: SOp (TScal TBool) (TScal TBool) OIf :: SOp (TScal TBool) (TEither TNil TNil) + ORound64 :: SOp (TScal TF64) (TScal TI64) + OToFl64 :: SOp (TScal TI64) (TScal TF64) deriving instance Show (SOp a t) opt2 :: SOp a t -> STy t @@ -162,6 +164,8 @@ opt2 = \case OEq _ -> STScal STBool ONot -> STScal STBool OIf -> STEither STNil STNil + ORound64 -> STScal STI64 + OToFl64 -> STScal STF64 typeOf :: Expr x env t -> STy t typeOf = \case diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 7f60db1..8f1fe67 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -243,3 +243,5 @@ operator OLe{} = (Infix, "<=") operator OEq{} = (Infix, "==") operator ONot = (Prefix, "not") operator OIf = (Prefix, "ifB") +operator ORound64 = (Prefix, "round") +operator OToFl64 = (Prefix, "toFl64") diff --git a/src/CHAD.hs b/src/CHAD.hs index 55d94b1..d05e77f 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -546,6 +546,8 @@ d1op (OLe t) e = EOp ext (OLe t) e d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e -- | Both primal and dual must be duplicable expressions data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) @@ -563,6 +565,8 @@ d2op op = case op of OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) ONot -> Linear $ \_ -> ENil ext OIf -> Linear $ \_ -> ENil ext + ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 + OToFl64 -> Linear $ \_ -> ENil ext where d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index f9239e9..3e45ce7 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -100,6 +100,8 @@ dop = \case (EOp ext (OEq t)) ONot -> EOp ext ONot OIf -> EOp ext OIf + ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) + OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) where add :: ScalIsNumeric t ~ True => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 4d1358f..8ce1b0e 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -125,6 +125,8 @@ interpretOp op arg = case op of OEq st -> numericIsNum st $ uncurry (==) arg ONot -> not arg OIf -> if arg then Left () else Right () + ORound64 -> round arg + OToFl64 -> fromIntegral arg zeroD2 :: STy t -> Rep (D2 t) zeroD2 typ = case typ of diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index adb4eba..ed307c0 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -39,5 +39,16 @@ type family RepAcDense t where -- RepAcDense (TScal sty) = ScalRep sty -- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators") -newtype Value t = Value (Rep t) +newtype Value t = Value { unValue :: Rep t } +liftV :: (Rep a -> Rep b) -> Value a -> Value b +liftV f (Value x) = Value (f x) + +liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c +liftV2 f (Value x) (Value y) = Value (f x y) + +vPair :: Value a -> Value b -> Value (TPair a b) +vPair = liftV2 (,) + +vUnpair :: Value (TPair a b) -> (Value a, Value b) +vUnpair (Value (x, y)) = (Value x, Value y) diff --git a/src/Language.hs b/src/Language.hs index cdc6d6b..80de713 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeOperators #-} module Language ( @@ -22,7 +23,7 @@ infixr 0 :-> body :: NExpr env t -> NFun env env t body = NBody -lambda :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t lambda = NLam diff --git a/test/Main.hs b/test/Main.hs index 986c8a0..d90d9cd 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,20 +1,21 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} +-- {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} +-- {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Main where import Data.Bifunctor -import qualified Data.Dependent.Map as DMap -import Data.Dependent.Map (DMap) -import Data.List (intercalate) +-- import qualified Data.Dependent.Map as DMap +-- import Data.Dependent.Map (DMap) +import Data.Foldable (toList) +import Data.List (intercalate, intersperse) import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range @@ -52,7 +53,7 @@ gradientByCHAD = \env term input -> dterm = freezeRet descr (drev descr term) (EConst ext STF64 1.0) input1 = toPrimalE env input (_out, grad) = interpretOpen input1 dterm - in unTup (\(Value (x, y)) -> (Value x, Value y)) (d2e env) (Value grad) + in unTup vUnpair (d2e env) (Value grad) where makeMergeDescr :: SList STy env' -> Descr env' (MapMerge env') makeMergeDescr SNil = DTop @@ -127,17 +128,18 @@ genShape = \n -> do shapeDiv ShNil _ = ShNil shapeDiv (sh `ShCons` n) f = shapeDiv sh f `ShCons` (n `div` f) +genArray :: STy a -> Shape n -> Gen (Value (TArr n a)) +genArray t sh = Value <$> arrayGenerateLinM sh (\_ -> unValue <$> genValue t) + genValue :: STy a -> Gen (Value a) genValue = \case STNil -> return (Value ()) - STPair a b -> lv2 (,) <$> genValue a <*> genValue b - STEither a b -> Gen.choice [lv1 Left <$> genValue a - ,lv1 Right <$> genValue b] + STPair a b -> liftV2 (,) <$> genValue a <*> genValue b + STEither a b -> Gen.choice [liftV Left <$> genValue a + ,liftV Right <$> genValue b] STMaybe t -> Gen.choice [return (Value Nothing) - ,lv1 Just <$> genValue t] - STArr n t -> do - sh <- genShape n - Value <$> arrayGenerateLinM sh (\_ -> (\(Value x) -> x) <$> genValue t) + ,liftV Just <$> genValue t] + STArr n t -> genShape n >>= genArray t STScal sty -> case sty of STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10) @@ -145,39 +147,33 @@ genValue = \case STI64 -> Value <$> Gen.integral (Range.linearFrom 0 (-10) 10) STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" - where - lv1 :: (Rep a -> Rep b) -> Value a -> Value b - lv1 f (Value x) = Value (f x) - - lv2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c - lv2 f (Value x) (Value y) = Value (f x y) genEnv :: SList STy env -> Gen (SList Value env) genEnv SNil = return SNil genEnv (t `SCons` env) = SCons <$> genValue t <*> genEnv env -data TemplateVar n = TemplateVar (SNat n) String - deriving (Show) +-- data TemplateVar n = TemplateVar (SNat n) String +-- deriving (Show) -data Template t where - TpShape :: TemplateVar n -> STy t -> Template (TArr n t) - TpAny :: STy t -> Template t - TpPair :: Template a -> Template b -> Template (TPair a b) -deriving instance Show (Template t) +-- data Template t where +-- TpShape :: TemplateVar n -> STy t -> Template (TArr n t) +-- TpAny :: STy t -> Template t +-- TpPair :: Template a -> Template b -> Template (TPair a b) +-- deriving instance Show (Template t) -data ShapeConstraint n = ShapeAtLeast (Shape n) - deriving (Show) +-- data ShapeConstraint n = ShapeAtLeast (Shape n) +-- deriving (Show) -genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) -genTemplate = _ +-- genTemplate :: DMap TemplateVar Shape -> Template t -> Gen (Value t) +-- genTemplate = _ -genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) -genEnvTemplateExact shapes env = _ +-- genEnvTemplateExact :: DMap TemplateVar Shape -> SList Template env -> Gen (SList Value env) +-- genEnvTemplateExact shapes env = _ -genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) -genEnvTemplate constrs env = do - shapes <- DMap.traverseWithKey _ constrs - genEnvTemplateExact shapes env +-- genEnvTemplate :: DMap TemplateVar ShapeConstraint -> SList Template env -> Gen (SList Value env) +-- genEnvTemplate constrs env = do +-- shapes <- DMap.traverseWithKey _ constrs +-- genEnvTemplateExact shapes env showValue :: Int -> STy t -> Rep t -> ShowS showValue _ STNil () = showString "()" @@ -186,7 +182,11 @@ showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y showValue _ (STMaybe _) Nothing = showString "Nothing" showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x -showValue d (STArr _ t) arr = showsPrec d (fmap (\x -> showValue 0 t x "") arr) -- TODO: improve +showValue d (STArr _ t) arr = showParen (d > 10) $ + showString "arrayFromList " . showsPrec 11 (arrayShape arr) + . showString " [" + . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) + . showString "]" showValue _ (STScal sty) x = case sty of STF32 -> shows x STF64 -> shows x @@ -203,9 +203,18 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs adTest :: forall env. KnownEnv env => Ex env (TScal TF64) -> Property -adTest expr = property $ do +adTest = flip adTestGen (genEnv (knownEnv @env)) + +-- adTestTp :: forall env. KnownEnv env +-- => DMap TemplateVar ShapeConstraint -> SList Template env +-- -> Ex env (TScal TF64) -> Property +-- adTestTp envConstrs envTp = adTestGen (genEnvTemplate envConstrs envTp) + +adTestGen :: forall env. KnownEnv env + => Ex env (TScal TF64) -> Gen (SList Value env) -> Property +adTestGen expr envGenerator = property $ do let env = knownEnv @env - input <- forAllWith (showEnv env) $ genEnv env + input <- forAllWith (showEnv env) envGenerator let gradFwd = gradientByForward knownEnv expr input gradCHAD = gradientByCHAD' knownEnv expr input scFwd = envScalars env gradFwd @@ -219,7 +228,40 @@ adTest expr = property $ do tests :: IO Bool tests = checkParallel $ Group "AD" [("id", adTest $ fromNamed $ lambda #x $ body $ #x) - ,("neural", adTest Example.neural)] + + ,("sum-vec", adTest $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)) + + ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $ + idx0 $ + build SZ (shape #x) $ #idx :-> #x ! #idx) + + ,("build1-sum", adTest $ fromNamed $ lambda @(TArr N1 _) #x $ body $ + idx0 $ sum1i $ + build (SS SZ) (shape #x) $ #idx :-> #x ! #idx) + + ,("build2-sum", adTest $ fromNamed $ lambda @(TArr N2 _) #x $ body $ + idx0 $ sum1i . sum1i $ + build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx) + + -- ,("build-sum", adTest $ fromNamed $ lambda #x $ body $ + -- idx0 $ sum1i . sum1i $ + -- build (SS (SS SZ)) (pair (pair nil 2) 3) $ #idx :-> + -- oper OToFl64 $ snd_ (fst_ #idx) + snd_ #idx) + + -- ,("neural", adTestGen Example.neural $ do + -- let tR = STScal STF64 + -- let genLayer nin nout = + -- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) + -- <*> genArray tR (ShNil `ShCons` nout) + -- nin <- Gen.integral (Range.linear 1 10) + -- n1 <- Gen.integral (Range.linear 1 10) + -- n2 <- Gen.integral (Range.linear 1 10) + -- input <- genArray tR (ShNil `ShCons` nin) + -- lay1 <- genLayer nin n1 + -- lay2 <- genLayer n1 n2 + -- lay3 <- genArray tR (ShNil `ShCons` n2) + -- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + ] main :: IO () main = defaultMain [tests] |