{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Main where

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
import Data.Int (Int64)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Framework

import Array
import AST hiding ((.>))
import AST.Pretty
import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import CHAD.Types.ToTan
import Compile
import qualified Example
import qualified Example.GMM as Example
import Example.Types
import ForwardAD
import ForwardAD.DualNumbers
import Interpreter
import Interpreter.Rep
import Language
import Simplify


-- | How much simplification to apply: a fixed number of passes
-- ('SimplIters') or the 'simplifyFix' variant ('SimplFix').
data SimplIters = SimplIters Int | SimplFix
  deriving (Show)

-- | Select the simplifier matching the requested 'SimplIters'. The environment
-- singleton is only used to obtain the 'Dict' evidence from 'envKnown'.
simplifyIters :: SimplIters -> SList STy env -> Ex env t -> Ex env t
simplifyIters iters env = case envKnown env of
  Dict -> case iters of
    SimplFix -> simplifyFix
    SimplIters n -> simplifyN n

-- In addition to the gradient, also returns the pretty-printed differentiated term.
gradientByCHAD :: forall env. SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (D2E env)))
gradientByCHAD simplIters env term input =
  (ppExpr env dterm, (out, unTup vUnpair (d2e env) (Value grad)))
  where
    -- The CHAD-transformed term, with the constant 1.0 let-bound around it,
    -- simplified as requested.
    dterm = simplifyIters simplIters env (ELet ext (EConst ext STF64 1.0) (chad' defaultConfig env term))
    -- Interpreting the differentiated term yields the primal output together
    -- with the (still tupled) gradient.
    (out, grad) = interpretOpen False env input dterm

-- In addition to the gradient, also returns the pretty-printed differentiated term.
-- The gradient is additionally passed through 'toTanE' (together with the
-- input values), so it is returned as @SList Value (TanE env)@.
gradientByCHAD' :: SimplIters -> SList STy env -> Ex env R -> SList Value env -> (String, (Double, SList Value (TanE env)))
gradientByCHAD' simplIters env term input =
  second (second (toTanE env input)) $
    gradientByCHAD simplIters env term input

-- | Gradient obtained from a forward-AD artifact via 'drevByFwd', called with
-- the constant 1.0 (presumably the seed for the scalar output -- confirm
-- against 'drevByFwd').
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0

-- | Extend a value to its dual-number ('DN') representation. Every F32/F64
-- scalar is paired with a randomly generated second component from [-1, 1];
-- integers and booleans are returned unchanged. Accumulators are rejected
-- with 'error'.
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
extendDN STNil () = pure ()
extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y
extendDN (STEither a _) (Left x) = Left <$> extendDN a x
extendDN (STEither _ b) (Right y) = Right <$> extendDN b y
extendDN (STMaybe _) Nothing = pure Nothing
extendDN (STMaybe t) (Just x) = Just <$> extendDN t x
extendDN (STArr _ t) arr = traverse (extendDN t) arr
extendDN (STScal sty) x = case sty of
  STF32 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
  STF64 -> Gen.realFloat (Range.linearFracFrom 0 (-1) 1) >>= \d -> pure (x, d)
  STI32 -> pure x
  STI64 -> pure x
  STBool -> pure x
extendDN (STAccum _) _ = error "Accumulators not supported in input program"

-- | 'extendDN' applied to every value of an environment.
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
extendDNE (t `SCons` env) (Value x `SCons` val) = SCons <$> (Value <$> extendDN t x) <*> extendDNE env val

-- | Approximate equality with tolerance @h@: either the absolute difference is
-- below @h@, or the smaller magnitude of the two exceeds @10*h@ and the
-- difference relative to that magnitude is below @h@.
closeIsh' :: Double -> Double -> Double -> Bool
closeIsh' h a b =
  abs (a - b) < h || (let scale = min (abs a) (abs b)
                      in scale > 10*h && abs (a - b) / scale < h)

-- | 'closeIsh'' with tolerance 1e-5.
closeIsh :: Double -> Double -> Bool
closeIsh = closeIsh' 1e-5

-- | Structural approximate equality on values of type @t@. Floating-point
-- scalars are compared with 'closeIsh'' (F32 after conversion to 'Double');
-- integers and booleans must be exactly equal; sum/maybe values with
-- different constructors are unequal; arrays must have equal shapes.
-- Accumulators cannot be compared and raise an 'error'.
closeIshT' :: Double -> STy t -> Rep t -> Rep t -> Bool
closeIshT' _ STNil () () = True
closeIshT' h (STPair a b) (x, y) (x', y') = closeIshT' h a x x' && closeIshT' h b y y'
closeIshT' h (STEither a _) (Left x) (Left x') = closeIshT' h a x x'
closeIshT' h (STEither _ b) (Right x) (Right x') = closeIshT' h b x x'
closeIshT' _ STEither{} _ _ = False
closeIshT' _ (STMaybe _) Nothing Nothing = True
closeIshT' h (STMaybe a) (Just x) (Just x') = closeIshT' h a x x'
closeIshT' _ STMaybe{} _ _ = False
closeIshT' h (STArr _ a) arr1 arr2 =
  arrayShape arr1 == arrayShape arr2 &&
  and (zipWith (closeIshT' h a) (arrayToList arr1) (arrayToList arr2))
closeIshT' _ (STScal STI32) i j = i == j
closeIshT' _ (STScal STI64) i j = i == j
closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
closeIshT' h (STScal STF64) x y = closeIsh' h x y
closeIshT' _ (STScal STBool) x y = x == y
closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"

-- | 'closeIshT'' with tolerance 1e-5.
closeIshT :: STy t -> Rep t -> Rep t -> Bool
closeIshT = closeIshT' 1e-5

-- | Left-nested pair used to list the per-dimension constraints of an array
-- template (see 'DimNames').
data a :$ b = a :$ b deriving (Show) ; infixl :$

-- | The type index is just a marker that helps typed holes show what (type of)
-- argument this template constraint belongs to.
data TplConstr a
  = C String  -- ^ name; @""@ means anonymous
      Int     -- ^ minimum value to generate
  | NC  -- ^ no constraints

-- | Template for the dimensions of a rank-@n@ array: nothing for rank 0, a
-- single constraint for rank 1, and otherwise the template of the first
-- @n-1@ dimensions followed (via ':$') by the constraint of the last one.
type family DimNames n where
  DimNames Z = ()
  DimNames (S Z) = TplConstr (S Z)
  DimNames (S n) = DimNames n :$ TplConstr (S n)

-- | Generation template for a value of type @t@.
type family Tpl t where
  Tpl (TArr n t) = DimNames n
  Tpl (TPair a b) = (Tpl a, Tpl b)
  Tpl (TScal TI32) = TplConstr TI32
  Tpl (TScal TI64) = TplConstr TI64
  -- If you add equations here, don't forget to update genValue! It currently
  -- just emptyTpl's things out.
  Tpl _ = ()

-- | Left-nested pair used to list the templates of an environment (see
-- 'TemplateE').
data a :& b = a :& b deriving (Show) ; infixl :&

-- | Template for a whole environment. For environments with more than one
-- entry, the template of the head of the environment is the rightmost
-- component.
type family TemplateE env where
  TemplateE '[] = ()
  TemplateE '[t] = Tpl t
  TemplateE (t : ts) = TemplateE ts :& Tpl t

-- | Dimension template that puts no constraint ('NC') on any dimension.
emptyDimNames :: SNat n -> DimNames n
emptyDimNames SZ = ()
emptyDimNames (SS SZ) = NC
emptyDimNames (SS n@SS{}) = emptyDimNames n :$ NC

-- | Template without constraints for arrays, pairs and scalars. All other
-- types are not implemented and raise an 'error'.
emptyTpl :: STy t -> Tpl t
emptyTpl (STArr n _) = emptyDimNames n
emptyTpl (STPair a b) = (emptyTpl a, emptyTpl b)
emptyTpl (STScal STI32) = NC
emptyTpl (STScal STI64) = NC
emptyTpl (STScal STF32) = ()
emptyTpl (STScal STF64) = ()
emptyTpl (STScal STBool) = ()
emptyTpl _ = error "too lazy"

-- | Environment template built from 'emptyTpl' for every entry.
emptyTemplateE :: SList STy env -> TemplateE env
emptyTemplateE SNil = ()
emptyTemplateE (t `SCons` SNil) = emptyTpl t
emptyTemplateE (t `SCons` ts@SCons{}) = emptyTemplateE ts :& emptyTpl t

-- | Generate an array shape according to a dimension template. The state maps
-- constraint names to the size that was generated for them, so that a name
-- occurring several times yields the same size.
--
-- After naive generation, every dimension is divided by
-- @shapeSize sh `div` 100 + 1@ to limit the total size; for constrained
-- dimensions the minimum is re-applied after the division.
--
-- NOTE(review): the state map records the size from before this division, so
-- a named dimension that gets scaled down here no longer equals the size
-- handed out for the same name elsewhere. Confirm whether that is acceptable.
genShape :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
genShape = \n tpl -> do
  sh <- genShapeNaive n tpl
  let sz = shapeSize sh
      factor = sz `div` 100 + 1
  return (shapeDiv sh tpl factor)
  where
    genShapeNaive :: SNat n -> DimNames n -> StateT (Map String Int) Gen (Shape n)
    genShapeNaive SZ () = return ShNil
    genShapeNaive (SS SZ) name = ShCons ShNil <$> genNamedDim name
    genShapeNaive (SS n@SS{}) (tpl :$ name) = ShCons <$> genShapeNaive n tpl <*> genNamedDim name

    -- Anonymous constraints (name "") never touch the name map.
    genNamedDim :: TplConstr a -> StateT (Map String Int) Gen Int
    genNamedDim NC = genDim 0
    genNamedDim (C "" lo) = genDim lo
    genNamedDim (C name lo) = gets (Map.lookup name) >>= \case
      Nothing -> do
        dim <- genDim lo
        modify (Map.insert name dim)
        return dim
      Just dim -> return dim

    genDim :: Int -> StateT (Map String Int) Gen Int
    genDim lo = Gen.integral (Range.linear lo (lo+10))

    shapeDiv :: Shape n -> DimNames n -> Int -> Shape n
    shapeDiv ShNil _ _ = ShNil
    shapeDiv (ShNil `ShCons` n) (C _ lo) f = ShNil `ShCons` (max lo (n `div` f))
    shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ C _ lo) f = shapeDiv sh tpl f `ShCons` (max lo (n `div` f))
    shapeDiv (ShNil `ShCons` n) NC f = ShNil `ShCons` (n `div` f)
    shapeDiv (sh@ShCons{} `ShCons` n) (tpl :$ NC) f = shapeDiv sh tpl f `ShCons` (n `div` f)

-- | Generate an array of the given shape. Every element is generated
-- independently with an unconstrained template and a fresh (empty) name map.
genArray :: STy a -> Shape n -> Gen (Value (TArr n a))
genArray t sh =
  Value <$> arrayGenerateLinM sh (\_ ->
    unValue <$> evalStateT (genValue t (emptyTpl t)) mempty)

-- | Generate a value of type @t@ following the template. Only arrays, pairs
-- and integer scalars look at the template; sums and maybes generate their
-- contents with 'emptyTpl'. Floating-point scalars are drawn from [-10, 10].
-- Accumulators cannot be generated and raise an 'error'.
genValue :: forall t. STy t -> Tpl t -> StateT (Map String Int) Gen (Value t)
genValue topty tpl = case topty of
  STNil -> return (Value ())
  STPair a b -> liftV2 (,) <$> genValue a (fst tpl) <*> genValue b (snd tpl)
  STEither a b -> Gen.choice [liftV Left <$> genValue a (emptyTpl a)
                             ,liftV Right <$> genValue b (emptyTpl b)]
  STMaybe t -> Gen.choice [return (Value Nothing)
                          ,liftV Just <$> genValue t (emptyTpl t)]
  STArr n t -> genShape n tpl >>= lift . genArray t
  STScal sty -> case sty of
    STF32 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STF64 -> Value <$> Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
    STI32 -> genInt
    STI64 -> genInt
    STBool -> Gen.choice [return (Value False), return (Value True)]
  STAccum{} -> error "Cannot generate inputs for accumulators"
  where
    -- Integer generation honouring the template constraint. Named constraints
    -- share their value through the name map (the same map that 'genShape'
    -- uses for dimension names).
    genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
    genInt = do
      -- 'Gen.integral' shrinks towards the origin of its range. The origin is
      -- therefore kept at 0 only when 0 is not below the lower bound;
      -- otherwise shrinking could produce values under the requested minimum.
      let gen lo = Gen.integral (Range.linearFrom (max 0 lo) lo (max 10 (lo + 10)))
      val <- case tpl of
        NC -> gen (-10)
        -- The empty name means "anonymous" (see 'TplConstr'): respect the
        -- minimum, but neither read nor write the name map, so that several
        -- anonymous constraints stay independent.
        C "" lo -> fromIntegral @Int @(Rep t) <$> gen lo
        C name lo -> gets (Map.lookup name) >>= \case
          Nothing -> do
            val <- fromIntegral @Int @(Rep t) <$> gen lo
            modify (Map.insert name (fromIntegral @(Rep t) @Int val))
            return val
          Just val -> return (fromIntegral @Int @(Rep t) val)
      return (Value val)

-- | Generate values for a whole environment. The head of the environment is
-- generated first; the name map is threaded through to the remaining entries.
genEnv :: SList STy env -> TemplateE env -> StateT (Map String Int) Gen (SList Value env)
genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl

-- | A value bundled with its type singleton, so that it can be shown with
-- 'showValue'.
data TypedValue t = TypedValue (STy t) (Rep t)

instance Show (TypedValue t) where
  showsPrec d (TypedValue t x) = showValue d t x

-- | 'compileTestTp' with an unconstrained input template.
compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr

-- | 'compileTestGen' with inputs generated by 'genEnv' from the given
-- template, starting with an empty name map.
compileTestTp :: KnownEnv env => TestName -> TemplateE env -> Ex env t -> TestTree
compileTestTp name tmpl expr = compileTestGen name expr (evalStateT (genEnv knownEnv tmpl) mempty)

-- | Property test: on generated inputs, the interpreter result and the result
-- of the compiled function must agree up to 'closeIshT'' with tolerance 1e-8.
compileTestGen :: KnownEnv env => TestName -> Ex env t -> Gen (SList Value env) -> TestTree
compileTestGen name expr envGenerator =
  let env = knownEnv
      t = typeOf expr
  in withCompiled env expr $ \fun ->
       testProperty name $ property $ do
         input <- forAllWith (showEnv env) envGenerator
         let resI = interpretOpen False env input expr
         resC <- evalIO $ fun input
         let cmp (TypedValue _ x) (TypedValue _ y) = closeIshT' 1e-8 t x y
         diff (TypedValue t resI) cmp (TypedValue t resC)

-- | 'adTestCon' without an input constraint.
adTest :: forall env. KnownEnv env => TestName -> Ex env R -> TestTree
adTest name = adTestCon name (const True)

-- | 'adTestGen' on unconstrained-template inputs that are additionally
-- filtered by the given predicate.
adTestCon :: forall env. KnownEnv env => TestName -> (SList Value env -> Bool) -> Ex env R -> TestTree
adTestCon name constr term =
  let env = knownEnv
  in adTestGen name term (Gen.filter constr (evalStateT (genEnv env (emptyTemplateE env)) mempty))

-- | 'adTestGen' with inputs generated by 'genEnv' from the given template.
adTestTp :: forall env. KnownEnv env
         => TestName -> TemplateE env -> Ex env R -> TestTree
adTestTp name tmpl term = adTestGen name term (evalStateT (genEnv knownEnv tmpl) mempty)

-- | Full AD test for a term: compiles the term and its simplified form, and
-- groups the primal, forward-AD and CHAD sub-tests under one name.
adTestGen :: forall env. KnownEnv env
          => TestName -> Ex env R -> Gen (SList Value env) -> TestTree
adTestGen name expr envGenerator =
  let env = knownEnv @env
      exprS = simplifyFix expr
  in withCompiled env expr $ \primalfun ->
     -- Reuse the simplified term instead of simplifying a second time.
     withCompiled env exprS $ \primalSfun ->
       testGroupCollapse name
         [adTestGenPrimal env envGenerator expr exprS primalfun primalSfun
         ,adTestGenFwd env envGenerator exprS
         ,adTestGenChad env envGenerator expr exprS primalSfun]

-- | Checks that interpreter and compiled code agree (tolerance 1e-8) on the
-- primal result, both for the original and for the simplified term.
adTestGenPrimal :: SList STy env -> Gen (SList Value env)
                -> Ex env R -> Ex env R
                -> (SList Value env -> IO Double) -> (SList Value env -> IO Double)
                -> TestTree
adTestGenPrimal env envGenerator expr exprS primalfun primalSfun =
  testProperty "compile primal" $ property $ do
    input <- forAllWith (showEnv env) envGenerator

    let outPrimalI = interpretOpen False env input expr
    outPrimalC <- evalIO $ primalfun input
    diff outPrimalI (closeIsh' 1e-8) outPrimalC

    let outPrimalSI = interpretOpen False env input exprS
    outPrimalSC <- evalIO $ primalSfun input
    diff outPrimalSI (closeIsh' 1e-8) outPrimalSC

-- | Checks that interpreter and compiled code agree (tolerance 1e-8) on both
-- components of the dual-number forward derivative ('dfwdDN') of the
-- simplified term, on inputs extended with 'extendDNE'.
adTestGenFwd :: SList STy env -> Gen (SList Value env) -> Ex env R -> TestTree
adTestGenFwd env envGenerator exprS =
  withCompiled (dne env) (dfwdDN exprS) $ \dnfun ->
  testProperty "compile fwdAD" $ property $ do
    input <- forAllWith (showEnv env) envGenerator
    dinput <- forAllWith (showEnv (dne env)) $ extendDNE env input
    let (outDNI1, outDNI2) = interpretOpen False (dne env) dinput (dfwdDN exprS)
    (outDNC1, outDNC2) <- evalIO $ dnfun dinput
    diff outDNI1 (closeIsh' 1e-8) outDNC1
    diff outDNI2 (closeIsh' 1e-8) outDNC2

-- | CHAD checks. Four differentiated terms are considered: CHAD of the
-- original and of the simplified term, each unsimplified (suffix @0@) and
-- simplified with 'simplifyFix' (suffix @S@). The property verifies that
--
-- * 'simplifyFix' and @'simplifyN' 20@ give the same pretty-printed result;
-- * the primal outputs of all four terms (interpreted) and of the compiled
--   @unMonoid@ variant are close to the compiled primal of the simplified
--   term;
-- * the gradients of all of these are close to the forward-AD gradient.
adTestGenChad :: forall env. SList STy env -> Gen (SList Value env)
              -> Ex env R -> Ex env R
              -> (SList Value env -> IO Double)
              -> TestTree
adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
  let dtermChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env expr
      dtermChadS = simplifyFix dtermChad0
      dtermSChad0 = ELet ext (EConst ext STF64 1.0) $ chad' defaultConfig env exprS
      dtermSChadS = simplifyFix dtermSChad0
  in
  withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
  withCompiled env (simplifyFix (unMonoid dtermSChadS)) $ \dcompSChadS ->
  testProperty "chad" $ property $ do
    annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))

    -- pack Text for less GC pressure (these values are retained for some reason)
    diff (T.pack (ppExpr env dtermChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermChad0)))
    diff (T.pack (ppExpr env dtermSChadS)) (==) (T.pack (ppExpr env (simplifyN 20 dtermSChad0)))

    input <- forAllWith (showEnv env) envGenerator
    outPrimal <- evalIO $ primalSfun input

    let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
        unpackGrad = unTup vUnpair (d2e env) . Value

    let scFwd = tanEScalars env $ gradientByForward fwdartifactC input

    let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
        (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
        (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
        (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
        scChad = tanEScalars env $ toTanE env input gradChad0
        scChadS = tanEScalars env $ toTanE env input gradChadS
        scSChad = tanEScalars env $ toTanE env input gradSChad0
        scSChadS = tanEScalars env $ toTanE env input gradSChadS

    (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input)
    let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS

    -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
    -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
    -- annotate (ppExpr knownEnv expr)
    -- annotate (ppExpr env dtermChad0)
    -- annotate (ppExpr env dtermChadS)
    annotate (ppExpr env (simplifyFix (unMonoid dtermSChadS)))
    diff outChad0 closeIsh outPrimal
    diff outChadS closeIsh outPrimal
    diff outSChad0 closeIsh outPrimal
    diff outSChadS closeIsh outPrimal
    diff outCompSChadS closeIsh outPrimal
    -- TODO: use closeIshT
    -- NOTE(review): 'zipWith' stops at the shorter list, so a difference in
    -- the number of scalars would go unnoticed by this comparison.
    let closeIshList x y = and (zipWith closeIsh x y)
    diff scChad closeIshList scFwd
    diff scChadS closeIshList scFwd
    diff scSChad closeIshList scFwd
    diff scSChadS closeIshList scFwd
    diff scCompSChadS closeIshList scFwd

-- | Compile a term as a test resource (with a no-op release action) and hand
-- the compiled function to the test-building continuation.
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())

-- | Input generator for the GMM objective tests: sizes @kN@, @kD@, @kK@ are
-- drawn from 1..8 and all arrays are generated with mutually consistent
-- shapes derived from them.
gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64])
gen_gmm = do
  -- The input ranges here are completely arbitrary.
  let tR = STScal STF64
  kN <- Gen.integral (Range.linear 1 8)
  kD <- Gen.integral (Range.linear 1 8)
  kK <- Gen.integral (Range.linear 1 8)
  let i2i64 = fromIntegral @Int @Int64
  valpha <- genArray tR (ShNil `ShCons` kK)
  vM <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
  vQ <- genArray tR (ShNil `ShCons` kK `ShCons` kD)
  vL <- genArray tR (ShNil `ShCons` kK `ShCons` (kD * (kD - 1) `div` 2))
  vX <- genArray tR (ShNil `ShCons` kN `ShCons` kD)
  vgamma <- Gen.realFloat (Range.linearFracFrom 0 (-10) 10)
  vm <- Gen.integral (Range.linear 0 5)
  let k1 = 0.5 * fromIntegral (kN * kD) * log (2 * pi)
      k2 = 0.5 * vgamma * vgamma
      k3 = 0.42  -- don't feel like multigammaing today
  return (Value k3 `SCons` Value k2 `SCons` Value k1 `SCons`
          Value vm `SCons` vX `SCons`
          vL `SCons` vQ `SCons` vM `SCons` valpha `SCons`
          Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
          SNil)

-- | Input generator for the neural-network example: an input vector, two
-- (matrix, vector) layers and a final vector, with layer sizes from 1..10
-- chosen so that consecutive shapes match.
gen_neural :: Gen (SList Value [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)])
gen_neural = do
  let tR = STScal STF64
  let genLayer nin nout =
        liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
                   <*> genArray tR (ShNil `ShCons` nout)
  nin <- Gen.integral (Range.linear 1 10)
  n1 <- Gen.integral (Range.linear 1 10)
  n2 <- Gen.integral (Range.linear 1 10)
  input <- genArray tR (ShNil `ShCons` nin)
  lay1 <- genLayer nin n1
  lay2 <- genLayer n1 n2
  lay3 <- genArray tR (ShNil `ShCons` n2)
  return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)

-- | Sum of a vector that is first copied element-wise with a rank-1 build.
term_build1_sum :: Ex '[TVec R] R
term_build1_sum = fromNamed $ lambda #x $ body $
  idx0 $ sum1i $
    build (SS SZ) (shape #x) $ #idx :-> #x ! #idx

-- | Scalar arithmetic routed through nested pairs and projections.
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
  let_ #p (pair #x #y) $
  let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
    fst_ #q * #x + snd_ #q * fst_ #p

-- | Reads only the elements at indices 2, 3 and 4 of (a copy of) the input
-- vector, so the input needs at least 5 elements.
term_sparse :: Ex '[TVec R] R
term_sparse = fromNamed $ lambda #inp $ body $
  let_ #n (snd_ (shape #inp)) $
  let_ #arr (build1 #n (#i :-> #inp ! pair nil #i)) $
  let_ #a (build1 #n (#i :-> #arr ! pair nil 2)) $
  let_ #b (build1 #n (#i :-> #arr ! pair nil 3)) $
  let_ #c (build1 #n (#i :-> #arr ! pair nil 4)) $
    idx0 (sum1i #a) + idx0 (sum1i #b) + idx0 (sum1i #c)

-- | Term used by the "regression-simpl1" test; reads element 0 of the input,
-- so the input needs at least 1 element.
term_regression_simpl1 :: Ex '[TVec R] R
term_regression_simpl1 = fromNamed $ lambda #q $ body $
  idx0 $ sum1i $ build (SS SZ) (shape #q) $ #idx :->
    let_ #j (snd_ #idx) $
    if_ (#j .== 0)
      (#q ! pair nil 0)
      (if_ (#j .== #j) 1.0 2.0)

-- | Sum of the entries of a matrix-vector product. The vector is indexed up
-- to the matrix width, so it must be at least that long.
term_mulmatvec :: Ex [TVec R, TMat R] R
term_mulmatvec = fromNamed $ lambda #mat $ lambda #vec $ body $
  idx0 $ sum1i $
    let_ #hei (snd_ (fst_ (shape #mat))) $
    let_ #wid (snd_ (shape #mat)) $
      build1 #hei $ #i :->
        idx0 (sum1i (build1 #wid $ #j :->
                        #mat ! pair (pair nil #i) #j * #vec ! pair nil #j))

-- | Binds arrays under several names, through pairs and conditionals, before
-- summing them.
term_arr_rebind :: Ex '[I64, TVec R] R
term_arr_rebind = fromNamed $ lambda #a $ lambda #k $ body $
  let_ #n (if_ (#k .< length_ #a) #k (length_ #a)) $
  let_ #b (build1 #n (#i :-> #a ! pair nil #i)) $
  let_ #p (if_ (#n `mod_` 2 .== 1)
             (pair #a #b)
             (pair (map_ (#x :-> #x + 1) #a) #b)) $
  if_ (#n `mod_` 3 .== 1)
    (idx0 (sum1i (snd_ #p)))
    (let_ #b' (snd_ #p) $
     idx0 (sum1i #b') * idx0 (sum1i (map_ (#x :-> 2 * #x) #b')))

-- This simplifies away to a pointless test, but is helpful for debugging what
-- term_arr_rebind is supposed to test in a REPL
term_arr_rebind_simple :: Ex '[TVec R] R
term_arr_rebind_simple = fromNamed $ lambda #a $ body $
  let_ #b (build1 (length_ #a) (#i :-> 5 * (#a ! pair nil #i))) $
  let_ #c #b $
  let_ #d #c $
    idx0 (sum1i #d)

-- | Interpreter-versus-compiler tests for accumulator operations.
tests_Compile :: TestTree
tests_Compile = testGroup "Compile"
  [compileTest "accum f64" $ fromNamed $ lambda #b $ lambda #x $ body $
      with @R 0.0 $ #ac :->
        if_ #b (accum SAPHere nil #x #ac) nil

  ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
      with @(TPair R R) nothing $ #ac :->
        let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
        let_ #_ (accum SAPHere nil #x #ac) $
        let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
          nil

  ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $
      let_ #len (snd_ (shape #x)) $
      with @(TVec R) nothing $ #ac :->
        let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac)
                        nil) $
        let_ #_ (accum SAPHere nil (just #x) #ac) $
          nil
  ]

-- | The AD test suite: every entry runs the primal, forward-AD and CHAD
-- checks of 'adTestGen' on one term.
tests_AD :: TestTree
tests_AD = testGroup "AD"
  [adTest "id" $ fromNamed $ lambda #x $ body $ #x

  ,adTest "idx0" $ fromNamed $ lambda #x $ body $ idx0 #x

  ,adTest "operators" $ fromNamed $ lambda #x $ lambda #y $ body $
      let_ #i (round_ #x) $
      let_ #j (round_ #y) $
      let_ #a1 (#x + #y) $
      let_ #a2 (#x - #y) $
      let_ #a3 (#x * #y) $
      let_ #a4 (#x / (#y * #y + 1)) $
      let_ #b1 (#i + #j) $
      let_ #b2 (#i - #j) $
      let_ #b3 (#i * #j) $
      let_ #b4 (#i `idiv` (#j * #j + 1)) $
        #a1 + #a2 + #a3 + #a4 + toFloat_ (#b1 + #b2 + #b3 + #b4)

  ,adTest "order-of-operations" $ fromNamed $ body $
      toFloat_ (3 * (3 `idiv` 2))  -- Compile had a pretty-printing bug at some point

  ,adTest "sum-vec" $ fromNamed $ lambda #x $ body $ idx0 (sum1i #x)

  ,adTest "sum-replicate" $ fromNamed $ lambda #x $ body $
      idx0 $ sum1i $ replicate1i 10 #x

  ,adTest "pairs" term_pairs

  ,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
      idx0 $ build SZ nil $ #idx :-> const_ 0.0

  ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
      idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx

  ,adTest "build1-sum" term_build1_sum

  ,adTest "build2-sum" $ fromNamed $ lambda @(TMat _) #x $ body $
      idx0 $ sum1i . sum1i $
        build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx

  ,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
     fromNamed $ lambda @(TMat R) #x $ body $
      idx0 $ sum1i $ maximum1i #x

  ,adTestCon "minimum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
     fromNamed $ lambda @(TMat R) #x $ body $
      idx0 $ sum1i $ minimum1i #x

  ,adTest "unused" $ fromNamed $ lambda @(TVec R) #x $ body $
      let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
        42

  ,adTestTp "sparse" (C "" 5) term_sparse

  -- Regression test for a simplifier bug (89b78d4)
  ,adTestTp "regression-simpl1" (C "" 1) term_regression_simpl1

  -- Regression test for refcounts when indexing in nested arrays
  ,adTestTp "regression-idx1" (C "" 1 :$ C "" 1) $ fromNamed $ lambda @(TMat R) #L $ body $
      if_ (const_ @TI64 1 .> 0)
        (idx0 $ sum1i (build1 1 $ #_ :-> idx0 (sum1i (build1 1 $ #_ :->
           #L ! pair (pair nil 0) 0 * #L ! pair (pair nil 0) 0))))
        42

  ,adTest "arr-rebind-simple" term_arr_rebind_simple

  ,adTestTp "arr-rebind" (NC :& C "" 0) term_arr_rebind

  ,adTestGen "neural" Example.neural gen_neural

  ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural

  ,adTestTp "logsumexp" (C "" 1) $ fromNamed $ lambda @(TVec _) #vec $ body $
      let_ #m (maximum1i #vec) $
        log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m

  ,adTestTp "mulmatvec" ((NC :$ C "n" 0) :& C "n" 0) term_mulmatvec

  ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm

  ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) gen_gmm

  ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm

  ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) gen_gmm
  ]

-- | Runs the "Compile" and "AD" test groups.
main :: IO ()
main = defaultMain $ testGroup "All"
  [tests_Compile
  ,tests_AD]