From 3fd8d35cca2a23c137934a170c67e8ce310edf13 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Tue, 29 Apr 2025 15:54:12 +0200 Subject: Complete monoidal accumulator rewrite --- bench/Main.hs | 21 ----- src/CHAD/Types/ToTan.hs | 5 ++ src/Compile.hs | 205 ++++++++++++++++++++++++++++++------------------ src/Example/GMM.hs | 4 +- src/ForwardAD.hs | 28 +++++++ src/Interpreter.hs | 100 ++++++++++------------- src/Interpreter/Rep.hs | 40 ++++++++-- src/Simplify.hs | 2 +- test/Main.hs | 88 +++++++++++++-------- 9 files changed, 294 insertions(+), 199 deletions(-) diff --git a/bench/Main.hs b/bench/Main.hs index af83ef7..358ba31 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -12,7 +12,6 @@ module Main where import Control.DeepSeq import Criterion.Main -import Data.Coerce import Data.Int (Int64) import Data.Kind (Constraint) import GHC.Exts (withDict) @@ -38,26 +37,6 @@ gradCHAD config term = simplifyFix $ unMonoid $ simplifyFix $ ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term -instance KnownTy t => NFData (Value t) where - rnf = \(Value x) -> go (knownTy @t) x - where - go :: STy t' -> Rep t' -> () - go STNil () = () - go (STPair a b) (x, y) = go a x `seq` go b y - go (STEither a _) (Left x) = go a x - go (STEither _ b) (Right y) = go b y - go (STMaybe _) Nothing = () - go (STMaybe t) (Just x) = go t x - go (STArr (_ :: SNat n) (t :: STy t2)) arr = - withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) - go (STScal t) x = case t of - STI32 -> rnf x - STI64 -> rnf x - STF32 -> rnf x - STF64 -> rnf x - STBool -> rnf x - go STAccum{} _ = error "Cannot rnf accumulators" - type AllNFDataRep :: [Ty] -> Constraint type family AllNFDataRep env where AllNFDataRep '[] = () diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs index f843206..87c01cb 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Types/ToTan.hs @@ -40,3 +40,8 @@ toTan typ primal der = case typ of STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" + STLEither t1 t2 -> case (primal, der) of + (_, Nothing) -> Nothing + (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) + (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) + _ -> error "Primal and cotangent disagree on LEither alternative" diff --git a/src/Compile.hs b/src/Compile.hs index 503c342..6ba3a39 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1001,16 +1001,6 @@ compile' env = \case return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] EAccum _ t prj eidx eval eacc -> do - nameidx <- compileAssign "acidx" env eidx - nameval <- compileAssign "acval" env eval - - -- Generate the variable manually because this one has to be non-const. - -- TODO: old code: - -- eacc' <- compile' env eacc - -- nameacc <- genName' "acac" - -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' - nameacc <- compileAssign "acac" env eacc - let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a -- full zero array with the given zero info (for the type SMTArr n t1). initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () @@ -1041,77 +1031,66 @@ compile' env = \case initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi initZeroFromMemset SMTScal{} = Nothing - let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) - initZero :: SMTy a -> String -> String -> CompM () - initZero SMTNil _ _ = return () - initZero (SMTPair t1 t2) v vzi = do - initZero t1 (v++".a") (vzi++".a") - initZero t2 (v++".b") (vzi++".b") - initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") - initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi - initZero (SMTScal sty) v _ = case sty of + let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) + initZeroZI :: SMTy a -> String -> String -> CompM () + initZeroZI SMTNil _ _ = return () + initZeroZI (SMTPair t1 t2) v vzi = do + initZeroZI t1 (v++".a") (vzi++".a") + initZeroZI t2 (v++".b") (vzi++".b") + initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") + initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi + initZeroZI (SMTScal sty) v _ = case sty of STI32 -> emit $ SAsg v (CELit "0") STI64 -> emit $ SAsg v (CELit "0l") STF32 -> emit $ SAsg v (CELit "0.0f") STF64 -> emit $ SAsg v (CELit "0.0") - let -- | Dereference an accumulation value. Sparse components encountered - -- along the way are initialised before proceeding downwards. At the - -- point where we have the projected accumulator position available, - -- the handler will be invoked with a variable name pointing to the - -- projected position. - -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler) - accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM () - accumRef _ SAPHere v _ k = k v - - accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k - accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k - - accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do - ((), stmtsInit1) <- scope $ initZero ta (v++".l") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef ta prj' (v++".l") i k - accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do - ((), stmtsInit2) <- scope $ initZero tb (v++".r") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty - accumRef tb prj' (v++".r") i k - - accumRef (SMTMaybe tj) (SAPJust prj') v i k = do - ((), stmtsInit1) <- scope $ initZero tj (v++".j") i - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty - accumRef tj prj' (v++".j") i k - - accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do - when emitChecks $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do - let a .||. b = CEBinop a "||" b - emit $ SIf (CEBinop ixcomp "<" (CELit "0") - .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ - v ++ ".buf" ++ - concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ - "); " ++ - "return false;") - mempty - - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k + let -- Initialise an uninitialised accumulation value, potentially already + -- with the addend, potentially to zero depending on the nature of the + -- projection. + -- 1. If the projection indexes only through dense monoids before + -- reaching SAPHere, the thing cannot be initialised to zero with + -- only an AcIdx; it would need to model a zero after the addend, + -- which is stupid and redundant. In this case, we return Left: + -- (accumulation value) (AcIdx value) (addend value). + -- The addend is copied, not consumed. (We can't reliably _always_ + -- consume it, so it's not worth trying to do it sometimes.) + -- 2. Otherwise, a sparse monoid is found along the way, and we can + -- initalise the dense prefix of the path to zero by setting the + -- indexed-through sparse value to a sparse zero. Afterwards, the + -- main recursion can proceed further. In this case, we return + -- Right: (accumulation value) (AcIdx value) + -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) + initZeroChunk :: SMTy a -> SAcPrj p a b + -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend + (String -> String -> CompM ()) -- zero initialisation of sparse chunk + initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of + -- reached target before the first sparse constructor + (t1 , SAPHere ) -> Left $ \v _ addend -> do + incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend + emit $ SAsg v (CELit addend) + -- sparse types + (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") + (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") + -- dense types + (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do + f (v++".a") (i++".a") + initZeroZI t2 (v++".b") (i++".b") + (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do + initZeroZI t1 (v++".a") (i++".a") + f (v++".b") (i++".b") + (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do + initZeroArray n t1 v (i++".a.b") + linidxvar <- genName' "li" + emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) + f (v++".buf->xs["++linidxvar++"]") (i++".b") + where + applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i + applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i let -- Add a value (s) into an existing accumulation value (d). If a sparse - -- component of d is encountered, s is simply written there. + -- component of d is encountered, s is copied there. add :: SMTy a -> String -> String -> CompM () add SMTNil _ _ = return () add (SMTPair t1 t2) d s = do @@ -1165,16 +1144,88 @@ compile' env = \case mempty shsizename <- genName' "acshsz" - emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")) + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) ivar <- genName' "i" ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) stmts1 add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + let -- | Dereference an accumulation value and add a given value to that + -- position. Sparse components encountered along the way are + -- initialised before proceeding downwards. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () + accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend + + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend + + accumRef (SMTLEither ta tb) prj0 v i addend = do + let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj' + SAPRight prj' -> initZeroChunk tb prj' + subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") + tagval = case prj0 of SAPLeft{} -> "1" + SAPRight{} -> "2" + ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend + SAPRight prj' -> accumRef tb prj' subv i addend + case chunkres of + Left densef -> do + ((), stmtsSet) <- scope $ densef subv i addend + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) + stmtsAdd -- TODO: emit check for consistency of tags? + Right sparsef -> do + ((), stmtsInit) <- scope $ sparsef subv i + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty + forM_ stmtsAdd emit + + accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do + case initZeroChunk tj prj' of + Left densef -> do + ((), stmtsSet1) <- scope $ densef (v++".j") i addend + ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) + stmtsAdd1 + Right sparsef -> do + ((), stmtsInit1) <- scope $ sparsef (v++".j") i + emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) + (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty + accumRef tj prj' (v++".j") i addend + + accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ (zip3 [0::Int ..] + (indexTupleComponents n (i++".a.a")) + (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + let a .||. b = CEBinop a "||" b + emit $ SIf (CEBinop ixcomp "<" (CELit "0") + .||. + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + .||. + CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ + "); " ++ + "return false;") + mempty + + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + + nameidx <- compileAssign "acidx" env eidx + nameval <- compileAssign "acval" env eval + nameacc <- compileAssign "acac" env eacc + emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" - accumRef t prj (nameacc++".buf->ac") nameidx $ \dest -> - add (acPrjTy prj t) dest nameval + accumRef t prj (nameacc++".buf->ac") nameidx nameval emit $ SVerbatim $ "// compile EAccum end" incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs index 12bbd98..206e534 100644 --- a/src/Example/GMM.hs +++ b/src/Example/GMM.hs @@ -31,10 +31,10 @@ import Language -- -- -- The 'wrong' argument, when set to True, changes the objective function to --- one with a bug that makes a certain `build` result unused. This triggers +-- one with a bug that makes a certain `build` result unused. This -- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's -- dense, even though it may be a zero (i.e. empty). The "unused" test in --- test/Main.hs tries to isolate this test, but the wrong version of +-- test/Main.hs tries to isolate this case, but the wrong version of -- gmmObjective is here to check (after that bug is fixed) whether it really -- fixes the original bug. gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b7036dd..5756f96 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -29,6 +29,7 @@ type family Tan t where Tan (TMaybe t) = TMaybe (Tan t) Tan (TArr n t) = TArr n (Tan t) Tan (TScal t) = TanS t + Tan (TLEither a b) = TLEither (Tan a) (Tan b) type family TanS t where TanS TI32 = TNil @@ -54,6 +55,11 @@ tanty (STScal t) = case t of STF64 -> STScal STF64 STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +tanty (STLEither a b) = STLEither (tanty a) (tanty b) + +tanenv :: SList STy env -> SList STy (TanE env) +tanenv SNil = SNil +tanenv (t `SCons` env) = tanty t `SCons` tanenv env zeroTan :: STy t -> Rep t -> Rep (Tan t) zeroTan STNil () = () @@ -69,6 +75,9 @@ zeroTan (STScal STF32) _ = 0.0 zeroTan (STScal STF64) _ = 0.0 zeroTan (STScal STBool) _ = () zeroTan STAccum{} _ = error "Accumulators not allowed in input program" +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) tanScalars :: STy t -> Rep (Tan t) -> [Double] tanScalars STNil () = [] @@ -84,6 +93,9 @@ tanScalars (STScal STF32) x = [realToFrac x] tanScalars (STScal STF64) x = [x] tanScalars (STScal STBool) _ = [] tanScalars STAccum{} _ = error "Accumulators not allowed in input program" +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] tanEScalars SNil SNil = [] @@ -111,6 +123,10 @@ unzipDN (STScal ty) d = case ty of STF64 -> d STBool -> (d, ()) unzipDN STAccum{} _ = error "Accumulators not allowed in input program" +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double dotprodTan STNil _ _ = 0.0 @@ -137,6 +153,12 @@ dotprodTan (STScal ty) x y = case ty of STF64 -> x * y STBool -> 0.0 dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" -- -- Primal expression must be duplicable -- dnConstE :: STy t -> Ex env t -> Ex env (DN t) @@ -174,6 +196,7 @@ dnConst (STScal t) = case t of STF64 -> (,0.0) STBool -> id dnConst STAccum{} = error "Accumulators not allowed in input program" +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) -- | Given a function that computes the forward derivative for a particular -- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this @@ -204,6 +227,11 @@ dnOnehots (STScal t) x = case t of STF64 -> \f -> f (x, 1.0) STBool -> \_ -> () dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) dnConstEnv SNil SNil = SNil diff --git a/src/Interpreter.hs b/src/Interpreter.hs index af11de8..d7916d8 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -35,7 +35,6 @@ import Debug.Trace import Array import AST import AST.Pretty -import CHAD.Types import Data import Interpreter.Rep @@ -253,7 +252,7 @@ withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Re withAccum t _ initval f = AcM $ do accum <- newAcDense t initval out <- unAcM $ f accum - val <- readAcSparse t accum + val <- readAc t accum return (out, val) newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t) @@ -300,81 +299,62 @@ onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = !linindex = toLinearIndex arrsh arrindex in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) -readAcSparse :: SMTy t -> RepAc t -> IO (Rep t) -readAcSparse typ val = case typ of +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of SMTNil -> return () - SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val - SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val - SMTArr _ t -> traverse (readAcSparse t) val + SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val + SMTMaybe t -> traverse (readAc t) =<< readIORef val + SMTArr _ t -> traverse (readAc t) val SMTScal _ -> readIORef val -accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - - (STPair t1 t2, SAPHere) -> - case val of - Nothing -> return () - Just (val1, val2) -> - realiseMaybeSparse ref ((,) <$> newAcDense t1 val1 - <*> newAcDense t2 val2) - (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 - accumAddSparse t2 SAPHere ac2 () val2) - (STPair t1 t2, SAPFst prj') -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) - (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) - (STPair t1 t2, SAPSnd prj') -> - realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) - (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - - (STEither{}, SAPHere) -> +accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s () +accumAddDense typ ref val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> do + accumAddDense t1 (fst ref) (fst val) + accumAddDense t2 (snd ref) (snd val) + SMTLEither{} -> case val of Nothing -> return () Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - (STEither t1 _, SAPLeft prj') -> + SMTMaybe{} -> + case val of + Nothing -> return () + Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' + SMTArr _ t1 -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i) + SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s () +accumAddSparse typ prj ref idx val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val + + (SMTLEither t1 _, SAPLeft prj') -> realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - (STEither _ t2, SAPRight prj') -> + (SMTLEither _ t2, SAPRight prj') -> realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - (STMaybe{}, SAPHere) -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - (STMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (newAcSparse t1 prj' idx val) + (\ac -> accumAddSparse t1 prj' ac idx val) - (STArr _ t1, SAPHere) -> - case val of - Nothing -> return () - Just val' -> - realiseMaybeSparse ref - (arrayMapM (newAcDense t1) val') - (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> - accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) - (STArr n t1, SAPArrIdx prj') -> - let ((arrindex', arrsh'), idx') = idx + (SMTArr n t1, SAPArrIdx prj') -> + let ((arrindex', ziarr), idx') = idx arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' + arrsh = arrayShape ziarr linindex = toLinearIndex arrsh arrindex - in realiseMaybeSparse ref - (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) - (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) - - (STScal sty, SAPHere) -> AcM $ case sty of - STI32 -> return () - STI64 -> return () - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> return () - - (STAccum{}, _) -> error "Accumulators not allowed in source program" + in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val + realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () realiseMaybeSparse ref makeval modifyval = diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 1226b0c..070ba4c 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -1,11 +1,16 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module Interpreter.Rep where +import Control.DeepSeq +import Data.Coerce (coerce) import Data.List (intersperse, intercalate) import Data.Foldable (toList) import Data.IORef +import GHC.Exts (withDict) import Array import AST @@ -58,12 +63,12 @@ showValue d (STArr _ t) arr = showParen (d > 10) $ . showString " [" . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) . showString "]" -showValue _ (STScal sty) x = case sty of - STF32 -> shows x - STF64 -> shows x - STI32 -> shows x - STI64 -> shows x - STBool -> shows x +showValue d (STScal sty) x = case sty of + STF32 -> showsPrec d x + STF64 -> showsPrec d x + STI32 -> showsPrec d x + STI64 -> showsPrec d x + STBool -> showsPrec d x showValue _ (STAccum t) _ = showString $ "" showValue _ (STLEither _ _) Nothing = showString "LNil" showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x @@ -75,3 +80,26 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" showEntries :: SList STy env -> SList Value env -> [String] showEntries SNil SNil = [] showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs + +rnfRep :: STy t -> Rep t -> () +rnfRep STNil () = () +rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y +rnfRep (STEither a _) (Left x) = rnfRep a x +rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STMaybe _) Nothing = () +rnfRep (STMaybe t) (Just x) = rnfRep t x +rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = + withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) +rnfRep (STScal t) x = case t of + STI32 -> rnf x + STI64 -> rnf x + STF32 -> rnf x + STF64 -> rnf x + STBool -> rnf x +rnfRep STAccum{} _ = error "Cannot rnf accumulators" +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y + +instance KnownTy t => NFData (Value t) where + rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/Simplify.hs b/src/Simplify.hs index 228f265..2a1d3b6 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -10,7 +10,7 @@ {-# LANGUAGE TypeOperators #-} module Simplify ( simplifyN, simplifyFix, - SimplifyConfig(..), simplifyWith, simplifyFixWith, + SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, ) where import Data.Function (fix) diff --git a/test/Main.hs b/test/Main.hs index afbd79b..f5e4a3c 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -42,6 +42,18 @@ import Language import Simplify +data TypedValue t = TypedValue (STy t) (Rep t) +instance Show (TypedValue t) where + showsPrec d (TypedValue t x) = showValue d t x + +data TypedEnv env = TypedEnv (SList STy env) (SList Value env) +instance Show (TypedEnv env) where + show (TypedEnv env xs) = showEnv env xs + +unTypedEnv :: TypedEnv env -> SList Value env +unTypedEnv (TypedEnv _ xs) = xs + + data SimplIters = SimplIters Int | SimplFix deriving (Show) @@ -67,6 +79,7 @@ gradientByCHAD' simplIters env term input = gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env) gradientByForward art input = drevByFwd art input 1.0 +-- | Generate input tangents for this primal extendDN :: STy t -> Rep t -> Gen (Rep (DN t)) extendDN STNil () = pure () extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y @@ -82,6 +95,9 @@ extendDN (STScal sty) x = case sty of STI64 -> pure x STBool -> pure x extendDN (STAccum _) _ = error "Accumulators not supported in input program" +extendDN (STLEither _ _) Nothing = pure Nothing +extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x +extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env)) extendDNE SNil SNil = pure SNil @@ -112,10 +128,19 @@ closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y) closeIshT' h (STScal STF64) x y = closeIsh' h x y closeIshT' _ (STScal STBool) x y = x == y closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators" +closeIshT' _ (STLEither _ _) Nothing Nothing = True +closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x' +closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y' +closeIshT' _ STLEither{} _ _ = False closeIshT :: STy t -> Rep t -> Rep t -> Bool closeIshT = closeIshT' 1e-5 +closeIshE :: SList STy t -> SList Value t -> SList Value t -> Bool +closeIshE SNil SNil SNil = True +closeIshE (t `SCons` env) (Value x `SCons` xs) (Value y `SCons` ys) = + closeIshT t x y && closeIshE env xs ys + data a :$ b = a :$ b deriving (Show) ; infixl :$ -- | The type index is just a marker that helps typed holes show what (type of) @@ -218,6 +243,9 @@ genValue topty tpl = case topty of STI64 -> genInt STBool -> Gen.choice [return (Value False), return (Value True)] STAccum{} -> error "Cannot generate inputs for accumulators" + STLEither a b -> Gen.frequency [(1, pure (Value Nothing)) + ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a)) + ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))] where genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t) genInt = do @@ -237,10 +265,6 @@ genEnv SNil () = return SNil genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl -data TypedValue t = TypedValue (STy t) (Rep t) -instance Show (TypedValue t) where - showsPrec d (TypedValue t x) = showValue d t x - compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr @@ -337,22 +361,22 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env) unpackGrad = unTup vUnpair (d2e env) . Value - let scFwd = tanEScalars env $ gradientByForward fwdartifactC input + let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0 (outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS (outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0 (outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS - scChad = tanEScalars env $ toTanE env input gradChad0 - scChadS = tanEScalars env $ toTanE env input gradChadS - scSChad = tanEScalars env $ toTanE env input gradSChad0 - scSChadS = tanEScalars env $ toTanE env input gradSChadS + tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0 + tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS + tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0 + tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS (outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input) - let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS + let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0)) - -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS)) + -- annotate (showEnv (d2e env) gradChad0) + -- annotate (showEnv (d2e env) gradChadS) -- annotate (ppExpr knownEnv expr) -- annotate (ppExpr env dtermChad0) -- annotate (ppExpr env dtermChadS) @@ -362,13 +386,12 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env = diff outSChad0 closeIsh outPrimal diff outSChadS closeIsh outPrimal diff outCompSChadS closeIsh outPrimal - -- TODO: use closeIshT - let closeIshList x y = and (zipWith closeIsh x y) - diff scChad closeIshList scFwd - diff scChadS closeIshList scFwd - diff scSChad closeIshList scFwd - diff scSChadS closeIshList scFwd - diff scCompSChadS closeIshList scFwd + let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2) + diff tansChad closeIshE' tansFwd + diff tansChadS closeIshE' tansFwd + diff tansSChad closeIshE' tansFwd + diff tansSChadS closeIshE' tansFwd + diff tansCompSChadS closeIshE' tansFwd withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) @@ -478,18 +501,25 @@ tests_Compile = testGroup "Compile" nil ,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $ - with @(TPair R R) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $ + with @(TPair R R) (pair 0.0 0.0) $ #ac :-> + let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $ + let_ #_ (accum SAPHere nil #x #ac) $ + let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $ + nil + + ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $ + with @(TMaybe (TPair R R)) nothing $ #ac :-> + let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $ let_ #_ (accum SAPHere nil #x #ac) $ - let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $ + let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $ nil - ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $ + ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $ let_ #len (snd_ (shape #x)) $ - with @(TVec R) nothing $ #ac :-> - let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac) + with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :-> + let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac) nil) $ - let_ #_ (accum SAPHere nil (just #x) #ac) $ + let_ #_ (accum SAPHere nil #x #ac) $ nil ] @@ -567,8 +597,6 @@ tests_AD = testGroup "AD" ,adTestGen "neural" Example.neural gen_neural - ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural - ,adTestTp "logsumexp" (C "" 1) $ fromNamed $ lambda @(TVec _) #vec $ body $ let_ #m (maximum1i #vec) $ @@ -578,11 +606,7 @@ tests_AD = testGroup "AD" ,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm - ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) gen_gmm - ,adTestGen "gmm" (Example.gmmObjective False) gen_gmm - - ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) gen_gmm ] main :: IO () -- cgit v1.2.3-70-g09d2