summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/CHAD/Types/ToTan.hs5
-rw-r--r--src/Compile.hs205
-rw-r--r--src/Example/GMM.hs4
-rw-r--r--src/ForwardAD.hs28
-rw-r--r--src/Interpreter.hs100
-rw-r--r--src/Interpreter/Rep.hs40
-rw-r--r--src/Simplify.hs2
7 files changed, 238 insertions, 146 deletions
diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs
index f843206..87c01cb 100644
--- a/src/CHAD/Types/ToTan.hs
+++ b/src/CHAD/Types/ToTan.hs
@@ -40,3 +40,8 @@ toTan typ primal der = case typ of
STScal sty -> case sty of
STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der
STAccum{} -> error "Accumulators not allowed in input program"
+ STLEither t1 t2 -> case (primal, der) of
+ (_, Nothing) -> Nothing
+ (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d))
+ (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d))
+ _ -> error "Primal and cotangent disagree on LEither alternative"
diff --git a/src/Compile.hs b/src/Compile.hs
index 503c342..6ba3a39 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1001,16 +1001,6 @@ compile' env = \case
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
EAccum _ t prj eidx eval eacc -> do
- nameidx <- compileAssign "acidx" env eidx
- nameval <- compileAssign "acval" env eval
-
- -- Generate the variable manually because this one has to be non-const.
- -- TODO: old code:
- -- eacc' <- compile' env eacc
- -- nameacc <- genName' "acac"
- -- emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc'
- nameacc <- compileAssign "acac" env eacc
-
let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
-- full zero array with the given zero info (for the type SMTArr n t1).
initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
@@ -1041,77 +1031,66 @@ compile' env = \case
initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
initZeroFromMemset SMTScal{} = Nothing
- let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZero :: SMTy a -> String -> String -> CompM ()
- initZero SMTNil _ _ = return ()
- initZero (SMTPair t1 t2) v vzi = do
- initZero t1 (v++".a") (vzi++".a")
- initZero t2 (v++".b") (vzi++".b")
- initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZero (SMTScal sty) v _ = case sty of
+ let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
+ initZeroZI :: SMTy a -> String -> String -> CompM ()
+ initZeroZI SMTNil _ _ = return ()
+ initZeroZI (SMTPair t1 t2) v vzi = do
+ initZeroZI t1 (v++".a") (vzi++".a")
+ initZeroZI t2 (v++".b") (vzi++".b")
+ initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
+ initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
+ initZeroZI (SMTScal sty) v _ = case sty of
STI32 -> emit $ SAsg v (CELit "0")
STI64 -> emit $ SAsg v (CELit "0l")
STF32 -> emit $ SAsg v (CELit "0.0f")
STF64 -> emit $ SAsg v (CELit "0.0")
- let -- | Dereference an accumulation value. Sparse components encountered
- -- along the way are initialised before proceeding downwards. At the
- -- point where we have the projected accumulator position available,
- -- the handler will be invoked with a variable name pointing to the
- -- projected position.
- -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler)
- accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM ()
- accumRef _ SAPHere v _ k = k v
-
- accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k
- accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k
-
- accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do
- ((), stmtsInit1) <- scope $ initZero ta (v++".l") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef ta prj' (v++".l") i k
- accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do
- ((), stmtsInit2) <- scope $ initZero tb (v++".r") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty
- accumRef tb prj' (v++".r") i k
-
- accumRef (SMTMaybe tj) (SAPJust prj') v i k = do
- ((), stmtsInit1) <- scope $ initZero tj (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i k
-
- accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do
- when emitChecks $ do
- let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
- let a .||. b = CEBinop a "||" b
- emit $ SIf (CEBinop ixcomp "<" (CELit "0")
- .||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
- (pure $ SVerbatim $
- "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
- v ++ ".buf" ++
- concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
- "); " ++
- "return false;")
- mempty
-
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k
+ let -- Initialise an uninitialised accumulation value, potentially already
+ -- with the addend, potentially to zero depending on the nature of the
+ -- projection.
+ -- 1. If the projection indexes only through dense monoids before
+ -- reaching SAPHere, the thing cannot be initialised to zero with
+ -- only an AcIdx; it would need to model a zero after the addend,
+ -- which is stupid and redundant. In this case, we return Left:
+ -- (accumulation value) (AcIdx value) (addend value).
+ -- The addend is copied, not consumed. (We can't reliably _always_
+ -- consume it, so it's not worth trying to do it sometimes.)
+ -- 2. Otherwise, a sparse monoid is found along the way, and we can
+ -- initalise the dense prefix of the path to zero by setting the
+ -- indexed-through sparse value to a sparse zero. Afterwards, the
+ -- main recursion can proceed further. In this case, we return
+ -- Right: (accumulation value) (AcIdx value)
+ -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
+ initZeroChunk :: SMTy a -> SAcPrj p a b
+ -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
+ (String -> String -> CompM ()) -- zero initialisation of sparse chunk
+ initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
+ -- reached target before the first sparse constructor
+ (t1 , SAPHere ) -> Left $ \v _ addend -> do
+ incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
+ emit $ SAsg v (CELit addend)
+ -- sparse types
+ (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
+ -- dense types
+ (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
+ f (v++".a") (i++".a")
+ initZeroZI t2 (v++".b") (i++".b")
+ (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
+ initZeroZI t1 (v++".a") (i++".a")
+ f (v++".b") (i++".b")
+ (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
+ initZeroArray n t1 v (i++".a.b")
+ linidxvar <- genName' "li"
+ emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
+ f (v++".buf->xs["++linidxvar++"]") (i++".b")
+ where
+ applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
+ applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
let -- Add a value (s) into an existing accumulation value (d). If a sparse
- -- component of d is encountered, s is simply written there.
+ -- component of d is encountered, s is copied there.
add :: SMTy a -> String -> String -> CompM ()
add SMTNil _ _ = return ()
add (SMTPair t1 t2) d s = do
@@ -1165,16 +1144,88 @@ compile' env = \case
mempty
shsizename <- genName' "acshsz"
- emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))
+ emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)
ivar <- genName' "i"
((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)
stmts1
add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ let -- | Dereference an accumulation value and add a given value to that
+ -- position. Sparse components encountered along the way are
+ -- initialised before proceeding downwards.
+ -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there)
+ accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
+ accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
+
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
+
+ accumRef (SMTLEither ta tb) prj0 v i addend = do
+ let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
+ SAPRight prj' -> initZeroChunk tb prj'
+ subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
+ tagval = case prj0 of SAPLeft{} -> "1"
+ SAPRight{} -> "2"
+ ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
+ SAPRight prj' -> accumRef tb prj' subv i addend
+ case chunkres of
+ Left densef -> do
+ ((), stmtsSet) <- scope $ densef subv i addend
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
+ stmtsAdd -- TODO: emit check for consistency of tags?
+ Right sparsef -> do
+ ((), stmtsInit) <- scope $ sparsef subv i
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
+ forM_ stmtsAdd emit
+
+ accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
+ case initZeroChunk tj prj' of
+ Left densef -> do
+ ((), stmtsSet1) <- scope $ densef (v++".j") i addend
+ ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
+ stmtsAdd1
+ Right sparsef -> do
+ ((), stmtsInit1) <- scope $ sparsef (v++".j") i
+ emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
+ (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
+ accumRef tj prj' (v++".j") i addend
+
+ accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
+ when emitChecks $ do
+ let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
+ forM_ (zip3 [0::Int ..]
+ (indexTupleComponents n (i++".a.a"))
+ (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ let a .||. b = CEBinop a "||" b
+ emit $ SIf (CEBinop ixcomp "<" (CELit "0")
+ .||.
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ .||.
+ CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ v ++ ".buf" ++
+ concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
+ "); " ++
+ "return false;")
+ mempty
+
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
+
+ nameidx <- compileAssign "acidx" env eidx
+ nameval <- compileAssign "acval" env eval
+ nameacc <- compileAssign "acac" env eacc
+
emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")"
- accumRef t prj (nameacc++".buf->ac") nameidx $ \dest ->
- add (acPrjTy prj t) dest nameval
+ accumRef t prj (nameacc++".buf->ac") nameidx nameval
emit $ SVerbatim $ "// compile EAccum end"
incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval
diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs
index 12bbd98..206e534 100644
--- a/src/Example/GMM.hs
+++ b/src/Example/GMM.hs
@@ -31,10 +31,10 @@ import Language
-- <https://tomsmeding.com/f/master.pdf>
--
-- The 'wrong' argument, when set to True, changes the objective function to
--- one with a bug that makes a certain `build` result unused. This triggers
+-- one with a bug that makes a certain `build` result unused. This
-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's
-- dense, even though it may be a zero (i.e. empty). The "unused" test in
--- test/Main.hs tries to isolate this test, but the wrong version of
+-- test/Main.hs tries to isolate this case, but the wrong version of
-- gmmObjective is here to check (after that bug is fixed) whether it really
-- fixes the original bug.
gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index b7036dd..5756f96 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -29,6 +29,7 @@ type family Tan t where
Tan (TMaybe t) = TMaybe (Tan t)
Tan (TArr n t) = TArr n (Tan t)
Tan (TScal t) = TanS t
+ Tan (TLEither a b) = TLEither (Tan a) (Tan b)
type family TanS t where
TanS TI32 = TNil
@@ -54,6 +55,11 @@ tanty (STScal t) = case t of
STF64 -> STScal STF64
STBool -> STNil
tanty STAccum{} = error "Accumulators not allowed in input program"
+tanty (STLEither a b) = STLEither (tanty a) (tanty b)
+
+tanenv :: SList STy env -> SList STy (TanE env)
+tanenv SNil = SNil
+tanenv (t `SCons` env) = tanty t `SCons` tanenv env
zeroTan :: STy t -> Rep t -> Rep (Tan t)
zeroTan STNil () = ()
@@ -69,6 +75,9 @@ zeroTan (STScal STF32) _ = 0.0
zeroTan (STScal STF64) _ = 0.0
zeroTan (STScal STBool) _ = ()
zeroTan STAccum{} _ = error "Accumulators not allowed in input program"
+zeroTan (STLEither _ _) Nothing = Nothing
+zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x))
+zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y))
tanScalars :: STy t -> Rep (Tan t) -> [Double]
tanScalars STNil () = []
@@ -84,6 +93,9 @@ tanScalars (STScal STF32) x = [realToFrac x]
tanScalars (STScal STF64) x = [x]
tanScalars (STScal STBool) _ = []
tanScalars STAccum{} _ = error "Accumulators not allowed in input program"
+tanScalars (STLEither _ _) Nothing = []
+tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x
+tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y
tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double]
tanEScalars SNil SNil = []
@@ -111,6 +123,10 @@ unzipDN (STScal ty) d = case ty of
STF64 -> d
STBool -> (d, ())
unzipDN STAccum{} _ = error "Accumulators not allowed in input program"
+unzipDN (STLEither a b) d = case d of
+ Nothing -> (Nothing, Nothing)
+ Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x)
+ Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y)
dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double
dotprodTan STNil _ _ = 0.0
@@ -137,6 +153,12 @@ dotprodTan (STScal ty) x y = case ty of
STF64 -> x * y
STBool -> 0.0
dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program"
+dotprodTan (STLEither a b) x y = case (x, y) of
+ (Nothing, _) -> 0.0 -- 0 * y = 0
+ (_, Nothing) -> 0.0 -- x * 0 = 0
+ (Just (Left x'), Just (Left y')) -> dotprodTan a x' y'
+ (Just (Right x'), Just (Right y')) -> dotprodTan b x' y'
+ _ -> error "dotprodTan: incompatible LEither alternatives"
-- -- Primal expression must be duplicable
-- dnConstE :: STy t -> Ex env t -> Ex env (DN t)
@@ -174,6 +196,7 @@ dnConst (STScal t) = case t of
STF64 -> (,0.0)
STBool -> id
dnConst STAccum{} = error "Accumulators not allowed in input program"
+dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2))
-- | Given a function that computes the forward derivative for a particular
-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this
@@ -204,6 +227,11 @@ dnOnehots (STScal t) x = case t of
STF64 -> \f -> f (x, 1.0)
STBool -> \_ -> ()
dnOnehots STAccum{} _ = error "Accumulators not allowed in input program"
+dnOnehots (STLEither t1 t2) e =
+ case e of
+ Nothing -> \_ -> Nothing
+ Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left)))
+ Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right)))
dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env)
dnConstEnv SNil SNil = SNil
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index af11de8..d7916d8 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -35,7 +35,6 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
-import CHAD.Types
import Data
import Interpreter.Rep
@@ -253,7 +252,7 @@ withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Re
withAccum t _ initval f = AcM $ do
accum <- newAcDense t initval
out <- unAcM $ f accum
- val <- readAcSparse t accum
+ val <- readAc t accum
return (out, val)
newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t)
@@ -300,81 +299,62 @@ onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) =
!linindex = toLinearIndex arrsh arrindex
in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i))
-readAcSparse :: SMTy t -> RepAc t -> IO (Rep t)
-readAcSparse typ val = case typ of
+readAc :: SMTy t -> RepAc t -> IO (Rep t)
+readAc typ val = case typ of
SMTNil -> return ()
- SMTPair t1 t2 -> bitraverse (readAcSparse t1) (readAcSparse t2) val
- SMTLEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
- SMTMaybe t -> traverse (readAcSparse t) =<< readIORef val
- SMTArr _ t -> traverse (readAcSparse t) val
+ SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val
+ SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val
+ SMTMaybe t -> traverse (readAc t) =<< readIORef val
+ SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (STNil, SAPHere) -> return ()
-
- (STPair t1 t2, SAPHere) ->
- case val of
- Nothing -> return ()
- Just (val1, val2) ->
- realiseMaybeSparse ref ((,) <$> newAcDense t1 val1
- <*> newAcDense t2 val2)
- (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1
- accumAddSparse t2 SAPHere ac2 () val2)
- (STPair t1 t2, SAPFst prj') ->
- realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
- (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val)
- (STPair t1 t2, SAPSnd prj') ->
- realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
- (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val)
-
- (STEither{}, SAPHere) ->
+accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s ()
+accumAddDense typ ref val = case typ of
+ SMTNil -> return ()
+ SMTPair t1 t2 -> do
+ accumAddDense t1 (fst ref) (fst val)
+ accumAddDense t2 (snd ref) (snd val)
+ SMTLEither{} ->
case val of
Nothing -> return ()
Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- (STEither t1 _, SAPLeft prj') ->
+ SMTMaybe{} ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
+ SMTArr _ t1 ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
+ SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+
+accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep b -> AcM s ()
+accumAddSparse typ prj ref idx val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref val
+
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val
+
+ (SMTLEither t1 _, SAPLeft prj') ->
realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
(\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
- (STEither _ t2, SAPRight prj') ->
+ (SMTLEither _ t2, SAPRight prj') ->
realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
(\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
- (STMaybe{}, SAPHere) ->
- case val of
- Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- (STMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
+ (SMTMaybe t1, SAPJust prj') ->
+ realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
+ (\ac -> accumAddSparse t1 prj' ac idx val)
- (STArr _ t1, SAPHere) ->
- case val of
- Nothing -> return ()
- Just val' ->
- realiseMaybeSparse ref
- (arrayMapM (newAcDense t1) val')
- (\ac -> forM_ [0 .. arraySize ac - 1] $ \i ->
- accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i))
- (STArr n t1, SAPArrIdx prj') ->
- let ((arrindex', arrsh'), idx') = idx
+ (SMTArr n t1, SAPArrIdx prj') ->
+ let ((arrindex', ziarr), idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ arrsh = arrayShape ziarr
linindex = toLinearIndex arrsh arrindex
- in realiseMaybeSparse ref
- (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx)
- (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val)
-
- (STScal sty, SAPHere) -> AcM $ case sty of
- STI32 -> return ()
- STI64 -> return ()
- STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
- STBool -> return ()
-
- (STAccum{}, _) -> error "Accumulators not allowed in source program"
+ in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val
+
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 1226b0c..070ba4c 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -1,11 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Interpreter.Rep where
+import Control.DeepSeq
+import Data.Coerce (coerce)
import Data.List (intersperse, intercalate)
import Data.Foldable (toList)
import Data.IORef
+import GHC.Exts (withDict)
import Array
import AST
@@ -58,12 +63,12 @@ showValue d (STArr _ t) arr = showParen (d > 10) $
. showString " ["
. foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr))
. showString "]"
-showValue _ (STScal sty) x = case sty of
- STF32 -> shows x
- STF64 -> shows x
- STI32 -> shows x
- STI64 -> shows x
- STBool -> shows x
+showValue d (STScal sty) x = case sty of
+ STF32 -> showsPrec d x
+ STF64 -> showsPrec d x
+ STI32 -> showsPrec d x
+ STI64 -> showsPrec d x
+ STBool -> showsPrec d x
showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">"
showValue _ (STLEither _ _) Nothing = showString "LNil"
showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x
@@ -75,3 +80,26 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]"
showEntries :: SList STy env -> SList Value env -> [String]
showEntries SNil SNil = []
showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs
+
+rnfRep :: STy t -> Rep t -> ()
+rnfRep STNil () = ()
+rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y
+rnfRep (STEither a _) (Left x) = rnfRep a x
+rnfRep (STEither _ b) (Right y) = rnfRep b y
+rnfRep (STMaybe _) Nothing = ()
+rnfRep (STMaybe t) (Just x) = rnfRep t x
+rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr =
+ withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr)
+rnfRep (STScal t) x = case t of
+ STI32 -> rnf x
+ STI64 -> rnf x
+ STF32 -> rnf x
+ STF64 -> rnf x
+ STBool -> rnf x
+rnfRep STAccum{} _ = error "Cannot rnf accumulators"
+rnfRep (STLEither _ _) Nothing = ()
+rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x
+rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y
+
+instance KnownTy t => NFData (Value t) where
+ rnf (Value x) = rnfRep (knownTy @t) x
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 228f265..2a1d3b6 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -10,7 +10,7 @@
{-# LANGUAGE TypeOperators #-}
module Simplify (
simplifyN, simplifyFix,
- SimplifyConfig(..), simplifyWith, simplifyFixWith,
+ SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
) where
import Data.Function (fix)