aboutsummaryrefslogtreecommitdiff
path: root/cbits
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-02-23 21:44:23 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-02-23 21:44:23 +0100
commit5f7a81acc7f75415d62dac86c5b50c848ab15341 (patch)
tree641ed54ce426ed8a1d9a5da12a9cde512b32bedc /cbits
parenta17bd53598ee5266fc3a1c45f8f4bb4798dc495e (diff)
Optimise by backpropagating scalar computation in C
Diffstat (limited to 'cbits')
-rw-r--r--cbits/backprop.c25
1 files changed, 25 insertions, 0 deletions
diff --git a/cbits/backprop.c b/cbits/backprop.c
new file mode 100644
index 0000000..0ca62e3
--- /dev/null
+++ b/cbits/backprop.c
@@ -0,0 +1,25 @@
+// #include <stdio.h>
+#include <stdint.h>
+// #include <inttypes.h>
+
+struct Contrib {
+ int64_t i1;
+ double dx;
+ int64_t i2;
+ double dy;
+};
+
+void ad_dual_hs_backpropagate_double(
+ double *accums,
+ int64_t id_base, int64_t topid, const void *contribs_buf
+) {
+ // fprintf(stderr, "< ci0=%" PRIi64 " topid=%" PRIi64 " >\n", id_base, topid);
+ const struct Contrib *contribs = (const struct Contrib*)contribs_buf;
+
+ for (int64_t i = topid - id_base; i >= 0; i--) {
+ double d = accums[id_base + i];
+ // fprintf(stderr, "ACC i=%" PRIi64 " d=%g C={%" PRIi64 ", %g, %" PRIi64 ", %g}\n", id_base + i, d, contribs[i].i1, contribs[i].dx, contribs[i].i2, contribs[i].dy);
+ if (contribs[i].i1 != -1) accums[contribs[i].i1] += d * contribs[i].dx;
+ if (contribs[i].i2 != -1) accums[contribs[i].i2] += d * contribs[i].dy;
+ }
+}