summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-04-29 15:54:12 +0200
committerTom Smeding <tom@tomsmeding.com>2025-04-29 15:54:12 +0200
commit3fd8d35cca2a23c137934a170c67e8ce310edf13 (patch)
tree429fb99f9c1395272f1f9a94bfbc0e003fa39b21
parent919a36f8eed21501357185a90e2b7a4d9eaf7f08 (diff)
Complete monoidal accumulator rewrite
-rw-r--r--bench/Main.hs21
-rw-r--r--src/CHAD/Types/ToTan.hs5
-rw-r--r--src/Compile.hs205
-rw-r--r--src/Example/GMM.hs4
-rw-r--r--src/ForwardAD.hs28
-rw-r--r--src/Interpreter.hs100
-rw-r--r--src/Interpreter/Rep.hs40
-rw-r--r--src/Simplify.hs2
-rw-r--r--test/Main.hs88
9 files changed, 294 insertions, 199 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index af83ef7..358ba31 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -12,7 +12,6 @@ module Main where
import Control.DeepSeq
import Criterion.Main
-import Data.Coerce
import Data.Int (Int64)
import Data.Kind (Constraint)
import GHC.Exts (withDict)
@@ -38,26 +37,6 @@ gradCHAD config term =
simplifyFix $ unMonoid $ simplifyFix $
ELet ext (EConst ext STF64 1.0) $ chad' config knownEnv term
-instance KnownTy t => NFData (Value t) where
- rnf = \(Value x) -> go (knownTy @t) x
- where
- go :: STy t' -> Rep t' -> ()
- go STNil () = ()
- go (STPair a b) (x, y) = go a x `seq` go b y
- go (STEither a _) (Left x) = go a x
- go (STEither _ b) (Right y) = go b y
- go (STMaybe _) Nothing = ()
- go (STMaybe t) (Just x) = go t x
- go (STArr (_ :: SNat n) (t :: STy t2)) arr =
- withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
- go (STScal t) x = case t of
- STI32 -> rnf x
- STI64 -> rnf x
- STF32 -> rnf x
- STF64 -> rnf x
- STBool -> rnf x
- go STAccum{} _ = error "Cannot rnf accumulators"
-
type AllNFDataRep :: [Ty] -> Constraint
type family AllNFDataRep env where
AllNFDataRep '[] = ()
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index f843206..87c01cb 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -40,3 +40,8 @@ toTan typ primal der = case typ of
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
+ STLEither t1 t2 -> case (primal, der) of
+ (_, Nothing) -> Nothing
+ (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d))
+ (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
+ _ -> error "Primal and cotangent disagree on LEither alternative"
diff --git a/src/Compile.hs b/src/Compile.hs
index 503c342..6ba3a39 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1001,16 +1001,6 @@ compile' env = \case
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
EAccum _ t prj eidx eval eacc -> do
- nameidx <- compileAssign "acidx" env eidx
- nameval <- compileAssign "acval" env eval
-
- -- Generate the variable manually because this one has to be non-const.
- -- TODO: old code:
- -- eacc' <- compile' env eacc
- -- nameacc <- genName' "acac"
- -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
- nameacc <- compileAssign "acac" env eacc
-
let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
-- full zero array with the given zero info (for the type SMTArr n t1).
initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
@@ -1041,77 +1031,66 @@ compile' env = \case
initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
initZeroFromMemset SMTScal{} = Nothing
- let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZero :: SMTy a -> String -> String -> CompM ()
- initZero SMTNil _ _ = return ()
- initZero (SMTPair t1 t2) v vzi = do
- initZero t1 (v++".a") (vzi++".a")
- initZero t2 (v++".b") (vzi++".b")
- initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZero (SMTScal sty) v _ = case sty of
+ let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
+ initZeroZI :: SMTy a -> String -> String -> CompM ()
+ initZeroZI SMTNil _ _ = return ()
+ initZeroZI (SMTPair t1 t2) v vzi = do
+ initZeroZI t1 (v++".a") (vzi++".a")
+ initZeroZI t2 (v++".b") (vzi++".b")
+ initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
+ initZeroZI (SMTScal sty) v _ = case sty of
STI32 -> emit $ SAsg v (CELit "0")
STI64 -> emit $ SAsg v (CELit "0l")
STF32 -> emit $ SAsg v (CELit "0.0f")
STF64 -> emit $ SAsg v (CELit "0.0")
- let -- | Dereference an accumulation value. Sparse components encountered
- -- along the way are initialised before proceeding downwards. At the
- -- point where we have the projected accumulator position available,
- -- the handler will be invoked with a variable name pointing to the
- -- projected position.
- -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler)
- accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM ()
- accumRef _ SAPHere v _ k = k v
-
- accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k
- accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k
-
- accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do
- ((), stmtsInit1) <- scope $ initZero ta (v++".l") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef ta prj' (v++".l") i k
- accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do
- ((), stmtsInit2) <- scope $ initZero tb (v++".r") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty
- accumRef tb prj' (v++".r") i k
-
- accumRef (SMTMaybe tj) (SAPJust prj') v i k = do
- ((), stmtsInit1) <- scope $ initZero tj (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i k
-
- accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do
- when emitChecks $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
- let a .||. b = CEBinop a "||" b
- emit $ SIf (CEBinop ixcomp "<" (CELit "0")
- .||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
- v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
- "); " ++
- "return false;")
- mempty
-
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k
+ let -- Initialise an uninitialised accumulation value, potentially already
+ -- with the addend, potentially to zero depending on the nature of the
+ -- projection.
+ -- 1. If the projection indexes only through dense monoids before
+ -- reaching SAPHere, the thing cannot be initialised to zero with
+ -- only an AcIdx; it would need to model a zero after the addend,
+ -- which is stupid and redundant. In this case, we return Left:
+ -- (accumulation value) (AcIdx value) (addend value).
+ -- The addend is copied, not consumed. (We can't reliably _always_
+ -- consume it, so it's not worth trying to do it sometimes.)
+ -- 2. Otherwise, a sparse monoid is found along the way, and we can
+ -- initalise the dense prefix of the path to zero by setting the
+ -- indexed-through sparse value to a sparse zero. Afterwards, the
+ -- main recursion can proceed further. In this case, we return
+ -- Right: (accumulation value) (AcIdx value)
+ -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
+ initZeroChunk :: SMTy a -> SAcPrj p a b
+ -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
+ (String -> String -> CompM ()) -- zero initialisation of sparse chunk
+ initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
+ -- reached target before the first sparse constructor
+ (t1 , SAPHere ) -> Left $ \v _ addend -> do
+ incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
+ emit $ SAsg v (CELit addend)
+ -- sparse types
+ (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ -- dense types
+ (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
+ f (v++".a") (i++".a")
+ initZeroZI t2 (v++".b") (i++".b")
+ (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
+ initZeroZI t1 (v++".a") (i++".a")
+ f (v++".b") (i++".b")
+ (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
+ initZeroArray n t1 v (i++".a.b")
+ linidxvar <- genName' "li"
+ emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
+ f (v++".buf->xs["++linidxvar++"]") (i++".b")
+ where
+ applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
+ applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
let -- Add a value (s) into an existing accumulation value (d). If a sparse
- -- component of d is encountered, s is simply written there.
+ -- component of d is encountered, s is copied there.
add :: SMTy a -> String -> String -> CompM ()
add SMTNil _ _ = return ()
add (SMTPair t1 t2) d s = do
@@ -1165,16 +1144,88 @@ compile' env = \case
mempty
shsizename <- genName' "acshsz"
- emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
ivar <- genName' "i"
((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
stmts1
add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ let -- | Dereference an accumulation value and add a given value to that
+ -- position. Sparse components encountered along the way are
+ -- initialised before proceeding downwards.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
+ accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
+
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
+
+ accumRef (SMTLEither ta tb) prj0 v i addend = do
+ let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
+ SAPRight prj' -> initZeroChunk tb prj'
+ subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
+ tagval = case prj0 of SAPLeft{} -> "1"
+ SAPRight{} -> "2"
+ ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
+ SAPRight prj' -> accumRef tb prj' subv i addend
+ case chunkres of
+ Left densef -> do
+ ((), stmtsSet) <- scope $ densef subv i addend
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
+ stmtsAdd -- TODO: emit check for consistency of tags?
+ Right sparsef -> do
+ ((), stmtsInit) <- scope $ sparsef subv i
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
+ forM_ stmtsAdd emit
+
+ accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
+ case initZeroChunk tj prj' of
+ Left densef -> do
+ ((), stmtsSet1) <- scope $ densef (v++".j") i addend
+ ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
+ stmtsAdd1
+ Right sparsef -> do
+ ((), stmtsInit1) <- scope $ sparsef (v++".j") i
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
+ accumRef tj prj' (v++".j") i addend
+
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ (zip3 [0::Int ..]
+ (indexTupleComponents n (i++".a.a"))
+ (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ let a .||. b = CEBinop a "||" b
+ emit $ SIf (CEBinop ixcomp "<" (CELit "0")
+ .||.
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ .||.
+ CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
+
+ nameidx <- compileAssign "acidx" env eidx
+ nameval <- compileAssign "acval" env eval
+ nameacc <- compileAssign "acac" env eacc
+
emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
- accumRef t prj (nameacc++".buf->ac") nameidx $ \dest ->
- add (acPrjTy prj t) dest nameval
+ accumRef t prj (nameacc++".buf->ac") nameidx nameval
emit $ SVerbatim $ "// compile EAccum end"
incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs
index 12bbd98..206e534 100644
--- a/src/Example/GMM.hs
+++ b/src/Example/GMM.hs
@@ -31,10 +31,10 @@ import Language
-- <https://tomsmeding.com/f/master.pdf>
--
-- The 'wrong' argument, when set to True, changes the objective function to
--- one with a bug that makes a certain `build` result unused. This triggers
+-- one with a bug that makes a certain `build` result unused. This
-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's
-- dense, even though it may be a zero (i.e. empty). The "unused" test in
--- test/Main.hs tries to isolate this test, but the wrong version of
+-- test/Main.hs tries to isolate this case, but the wrong version of
-- gmmObjective is here to check (after that bug is fixed) whether it really
-- fixes the original bug.
gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index b7036dd..5756f96 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -29,6 +29,7 @@ type family Tan t where
Tan (TMaybe t) = TMaybe (Tan t)
Tan (TArr n t) = TArr n (Tan t)
Tan (TScal t) = TanS t
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
type family TanS t where
TanS TI32 = TNil
@@ -54,6 +55,11 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
+
+tanenv :: SList STy env -> SList STy (TanE env)
+tanenv SNil = SNil
+tanenv (t `SCons` env) = tanty t `SCons` tanenv env
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
@@ -69,6 +75,9 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
@@ -84,6 +93,9 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -111,6 +123,10 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -137,6 +153,12 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -174,6 +196,7 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -204,6 +227,11 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index af11de8..d7916d8 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -35,7 +35,6 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
-import CHAD.Types
import Data
import Interpreter.Rep
@@ -253,7 +252,7 @@ withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Re
withAccum t _ initval f = AcM $ do
accum <- newAcDense t initval
out <- unAcM $ f accum
- val <- readAcSparse t accum
+ val <- readAc t accum
return (out, val)
newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t)
@@ -300,81 +299,62 @@ onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
!linindex = toLinearIndex arrsh arrindex
in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i))
-readAcSparse :: SMTy t -> RepAc t -> IO (Rep t)
-readAcSparse typ val = case typ of
+readAc :: SMTy t -> RepAc t -> IO (Rep t)
+readAc typ val = case typ of
SMTNil -> return ()
- SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val
- SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
- SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val
- SMTArr _ t -> traverse (readAcSparse t) val
+ SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val
+ SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val
+ SMTMaybe t -> traverse (readAc t) =<< readIORef val
+ SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (STNil, SAPHere) -> return ()
-
- (STPair t1 t2, SAPHere) ->
- case val of
- Nothing -> return ()
- Just (val1, val2) ->
- realiseMaybeSparse ref ((,) <$> newAcDense t1 val1
- <*> newAcDense t2 val2)
- (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1
- accumAddSparse t2 SAPHere ac2 () val2)
- (STPair t1 t2, SAPFst prj') ->
- realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
- (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)
- (STPair t1 t2, SAPSnd prj') ->
- realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
- (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val)
-
- (STEither{}, SAPHere) ->
+accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s ()
+accumAddDense typ ref val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> do
+ accumAddDense t1 (fst ref) (fst val)
+ accumAddDense t2 (snd ref) (snd val)
+ SMTLEither{} ->
case val of
Nothing -> return ()
Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- (STEither t1 _, SAPLeft prj') ->
+ SMTMaybe{} ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
+ SMTArr _ t1 ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
+ SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+
+accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
+accumAddSparse typ prj ref idx val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
(\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
- (STEither _ t2, SAPRight prj') ->
+ (SMTLEither _ t2, SAPRight prj') ->
realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
(\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
- (STMaybe{}, SAPHere) ->
- case val of
- Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- (STMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
+ (\ac -> accumAddSparse t1 prj' ac idx val)
- (STArr _ t1, SAPHere) ->
- case val of
- Nothing -> return ()
- Just val' ->
- realiseMaybeSparse ref
- (arrayMapM (newAcDense t1) val')
- (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
- accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
- (STArr n t1, SAPArrIdx prj') ->
- let ((arrindex', arrsh'), idx') = idx
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let ((arrindex', ziarr), idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ arrsh = arrayShape ziarr
linindex = toLinearIndex arrsh arrindex
- in realiseMaybeSparse ref
- (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx)
- (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)
-
- (STScal sty, SAPHere) -> AcM $ case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STBool -> return ()
-
- (STAccum{}, _) -> error "Accumulators not allowed in source program"
+ in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val
+
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 1226b0c..070ba4c 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -1,11 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where
+import Control.DeepSeq
+import Data.Coerce (coerce)
import Data.List (intersperse, intercalate)
import Data.Foldable (toList)
import Data.IORef
+import GHC.Exts (withDict)
import Array
import AST
@@ -58,12 +63,12 @@ showValue d (STArr _ t) arr = showParen (d > 10) $
. showString " ["
. foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
. showString "]"
-showValue _ (STScal sty) x = case sty of
- STF32 -> shows x
- STF64 -> shows x
- STI32 -> shows x
- STI64 -> shows x
- STBool -> shows x
+showValue d (STScal sty) x = case sty of
+ STF32 -> showsPrec d x
+ STF64 -> showsPrec d x
+ STI32 -> showsPrec d x
+ STI64 -> showsPrec d x
+ STBool -> showsPrec d x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"
showValue _ (STLEither _ _) Nothing = showString "LNil"
showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
@@ -75,3 +80,26 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
showEntries :: SList STy env -> SList Value env -> [String]
showEntries SNil SNil = []
showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
+
+rnfRep :: STy t -> Rep t -> ()
+rnfRep STNil () = ()
+rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y
+rnfRep (STEither a _) (Left x) = rnfRep a x
+rnfRep (STEither _ b) (Right y) = rnfRep b y
+rnfRep (STMaybe _) Nothing = ()
+rnfRep (STMaybe t) (Just x) = rnfRep t x
+rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr =
+ withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
+rnfRep (STScal t) x = case t of
+ STI32 -> rnf x
+ STI64 -> rnf x
+ STF32 -> rnf x
+ STF64 -> rnf x
+ STBool -> rnf x
+rnfRep STAccum{} _ = error "Cannot rnf accumulators"
+rnfRep (STLEither _ _) Nothing = ()
+rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
+rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
+
+instance KnownTy t => NFData (Value t) where
+ rnf (Value x) = rnfRep (knownTy @t) x
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 228f265..2a1d3b6 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -10,7 +10,7 @@
{-# LANGUAGE TypeOperators #-}
module Simplify (
simplifyN, simplifyFix,
- SimplifyConfig(..), simplifyWith, simplifyFixWith,
+ SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
) where
import Data.Function (fix)
diff --git a/test/Main.hs b/test/Main.hs
index afbd79b..f5e4a3c 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -42,6 +42,18 @@ import Language
import Simplify
+data TypedValue t = TypedValue (STy t) (Rep t)
+instance Show (TypedValue t) where
+ showsPrec d (TypedValue t x) = showValue d t x
+
+data TypedEnv env = TypedEnv (SList STy env) (SList Value env)
+instance Show (TypedEnv env) where
+ show (TypedEnv env xs) = showEnv env xs
+
+unTypedEnv :: TypedEnv env -> SList Value env
+unTypedEnv (TypedEnv _ xs) = xs
+
+
data SimplIters = SimplIters Int | SimplFix
deriving (Show)
@@ -67,6 +79,7 @@ gradientByCHAD' simplIters env term input =
gradientByForward :: FwdADArtifact env R -> SList Value env -> SList Value (TanE env)
gradientByForward art input = drevByFwd art input 1.0
+-- | Generate input tangents for this primal
extendDN :: STy t -> Rep t -> Gen (Rep (DN t))
extendDN STNil () = pure ()
extendDN (STPair a b) (x, y) = (,) <$> extendDN a x <*> extendDN b y
@@ -82,6 +95,9 @@ extendDN (STScal sty) x = case sty of
STI64 -> pure x
STBool -> pure x
extendDN (STAccum _) _ = error "Accumulators not supported in input program"
+extendDN (STLEither _ _) Nothing = pure Nothing
+extendDN (STLEither a _) (Just (Left x)) = Just . Left <$> extendDN a x
+extendDN (STLEither _ b) (Just (Right y)) = Just . Right <$> extendDN b y
extendDNE :: SList STy env -> SList Value env -> Gen (SList Value (DNE env))
extendDNE SNil SNil = pure SNil
@@ -112,10 +128,19 @@ closeIshT' h (STScal STF32) x y = closeIsh' h (realToFrac x) (realToFrac y)
closeIshT' h (STScal STF64) x y = closeIsh' h x y
closeIshT' _ (STScal STBool) x y = x == y
closeIshT' _ STAccum{} _ _ = error "closeIshT': Cannot compare accumulators"
+closeIshT' _ (STLEither _ _) Nothing Nothing = True
+closeIshT' h (STLEither a _) (Just (Left x)) (Just (Left x')) = closeIshT' h a x x'
+closeIshT' h (STLEither _ b) (Just (Right y)) (Just (Right y')) = closeIshT' h b y y'
+closeIshT' _ STLEither{} _ _ = False
closeIshT :: STy t -> Rep t -> Rep t -> Bool
closeIshT = closeIshT' 1e-5
+closeIshE :: SList STy t -> SList Value t -> SList Value t -> Bool
+closeIshE SNil SNil SNil = True
+closeIshE (t `SCons` env) (Value x `SCons` xs) (Value y `SCons` ys) =
+ closeIshT t x y && closeIshE env xs ys
+
data a :$ b = a :$ b deriving (Show) ; infixl :$
-- | The type index is just a marker that helps typed holes show what (type of)
@@ -218,6 +243,9 @@ genValue topty tpl = case topty of
STI64 -> genInt
STBool -> Gen.choice [return (Value False), return (Value True)]
STAccum{} -> error "Cannot generate inputs for accumulators"
+ STLEither a b -> Gen.frequency [(1, pure (Value Nothing))
+ ,(8, liftV (Just . Left) <$> genValue a (emptyTpl a))
+ ,(8, liftV (Just . Right) <$> genValue b (emptyTpl b))]
where
genInt :: (Integral (Rep t), Tpl t ~ TplConstr _q) => StateT (Map String Int) Gen (Value t)
genInt = do
@@ -237,10 +265,6 @@ genEnv SNil () = return SNil
genEnv (t `SCons` SNil) tpl = SCons <$> genValue t tpl <*> pure SNil
genEnv (t `SCons` env@SCons{}) (tmpl :& tpl) = SCons <$> genValue t tpl <*> genEnv env tmpl
-data TypedValue t = TypedValue (STy t) (Rep t)
-instance Show (TypedValue t) where
- showsPrec d (TypedValue t x) = showValue d t x
-
compileTest :: KnownEnv env => TestName -> Ex env t -> TestTree
compileTest name (expr :: Ex env t) = compileTestTp name (emptyTemplateE (knownEnv @env)) expr
@@ -337,22 +361,22 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
let unpackGrad :: Rep (Tup (D2E env)) -> SList Value (D2E env)
unpackGrad = unTup vUnpair (d2e env) . Value
- let scFwd = tanEScalars env $ gradientByForward fwdartifactC input
+ let tansFwd = TypedEnv (tanenv env) $ gradientByForward fwdartifactC input
let (outChad0 , gradChad0) = second unpackGrad $ interpretOpen False env input dtermChad0
(outChadS , gradChadS) = second unpackGrad $ interpretOpen False env input dtermChadS
(outSChad0, gradSChad0) = second unpackGrad $ interpretOpen False env input dtermSChad0
(outSChadS, gradSChadS) = second unpackGrad $ interpretOpen False env input dtermSChadS
- scChad = tanEScalars env $ toTanE env input gradChad0
- scChadS = tanEScalars env $ toTanE env input gradChadS
- scSChad = tanEScalars env $ toTanE env input gradSChad0
- scSChadS = tanEScalars env $ toTanE env input gradSChadS
+ tansChad = TypedEnv (tanenv env) $ toTanE env input gradChad0
+ tansChadS = TypedEnv (tanenv env) $ toTanE env input gradChadS
+ tansSChad = TypedEnv (tanenv env) $ toTanE env input gradSChad0
+ tansSChadS = TypedEnv (tanenv env) $ toTanE env input gradSChadS
(outCompSChadS, gradCompSChadS) <- second unpackGrad <$> evalIO (dcompSChadS input)
- let scCompSChadS = tanEScalars env $ toTanE env input gradCompSChadS
+ let tansCompSChadS = TypedEnv (tanenv env) $ toTanE env input gradCompSChadS
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChad0))
- -- annotate (showSList (\d (Product.Pair ty (Value x)) -> showValue d ty x "") (slistZip (d2e env) gradChadS))
+ -- annotate (showEnv (d2e env) gradChad0)
+ -- annotate (showEnv (d2e env) gradChadS)
-- annotate (ppExpr knownEnv expr)
-- annotate (ppExpr env dtermChad0)
-- annotate (ppExpr env dtermChadS)
@@ -362,13 +386,12 @@ adTestGenChad env envGenerator expr exprS primalSfun | Dict <- envKnown env =
diff outSChad0 closeIsh outPrimal
diff outSChadS closeIsh outPrimal
diff outCompSChadS closeIsh outPrimal
- -- TODO: use closeIshT
- let closeIshList x y = and (zipWith closeIsh x y)
- diff scChad closeIshList scFwd
- diff scChadS closeIshList scFwd
- diff scSChad closeIshList scFwd
- diff scSChadS closeIshList scFwd
- diff scCompSChadS closeIshList scFwd
+ let closeIshE' e1 e2 = closeIshE (tanenv env) (unTypedEnv e1) (unTypedEnv e2)
+ diff tansChad closeIshE' tansFwd
+ diff tansChadS closeIshE' tansFwd
+ diff tansSChad closeIshE' tansFwd
+ diff tansSChadS closeIshE' tansFwd
+ diff tansCompSChadS closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
@@ -478,18 +501,25 @@ tests_Compile = testGroup "Compile"
nil
,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
- with @(TPair R R) nothing $ #ac :->
- let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
+ with @(TPair R R) (pair 0.0 0.0) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $
+ let_ #_ (accum SAPHere nil #x #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $
+ nil
+
+ ,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $
+ with @(TMaybe (TPair R R)) nothing $ #ac :->
+ let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
+ let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $
nil
- ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda #x $ body $
+ ,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $
let_ #len (snd_ (shape #x)) $
- with @(TVec R) nothing $ #ac :->
- let_ #_ (if_ #b (accum (SAPArrIdx SAPHere (SS SZ)) (pair (pair (pair nil 2) (pair nil #len)) nil) 6.0 #ac)
+ with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac)
nil) $
- let_ #_ (accum SAPHere nil (just #x) #ac) $
+ let_ #_ (accum SAPHere nil #x #ac) $
nil
]
@@ -567,8 +597,6 @@ tests_AD = testGroup "AD"
,adTestGen "neural" Example.neural gen_neural
- ,adTestGen "neural-unMonoid" (unMonoid (simplifyFix Example.neural)) gen_neural
-
,adTestTp "logsumexp" (C "" 1) $
fromNamed $ lambda @(TVec _) #vec $ body $
let_ #m (maximum1i #vec) $
@@ -578,11 +606,7 @@ tests_AD = testGroup "AD"
,adTestGen "gmm-wrong" (Example.gmmObjective True) gen_gmm
- ,adTestGen "gmm-wrong-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective True))) gen_gmm
-
,adTestGen "gmm" (Example.gmmObjective False) gen_gmm
-
- ,adTestGen "gmm-unMonoid" (unMonoid (simplifyFix (Example.gmmObjective False))) gen_gmm
]
main :: IO ()