Reputation: 2339
Below is the most expensive part of the algorithm I'm developing.
It is definitely the hotspot, by far; I checked with VisualVM.
This is already after some optimization tricks, such as unrolling the 3 ds rows.
Of course, since GAUSS_POINT_WEIGHTS has a fixed number of elements (10), one could try to unroll it as well, but that would produce 100 possible combinations, so I'm not sure it's worth it (maybe the JIT does the job anyway).
    for (int i = 0; i < mesh.getDofsY(); i++) {
        for (int k = 0; k < GAUSS_POINT_COUNT; k++) {
            val x = GAUSS_POINTS[k] * dx + leftSegment;
            for (int l = 0; l < GAUSS_POINT_COUNT; l++) {
                val wk = GAUSS_POINT_WEIGHTS[k];
                val wl = GAUSS_POINT_WEIGHTS[l];
                val gl = GAUSS_POINTS[l];
                val gk = GAUSS_POINTS[k];
                if (i > 1) {
                    val y = (gl + (i - 2)) * dy;
                    val v = wk * wl * b1.getValue(gl) * problem.valueAt(x, y);
                    ds.add(0, i, b3.getValue(gk) * v);
                    ds.add(1, i, b2.getValue(gk) * v);
                    ds.add(2, i, b1.getValue(gk) * v);
                }
                if (i > 0 && (i - 1) < mesh.getElementsY()) {
                    val y = (gl + (i - 1)) * dy;
                    val v = wk * wl * b2.getValue(gl) * problem.valueAt(x, y);
                    ds.add(0, i, b3.getValue(gk) * v);
                    ds.add(1, i, b2.getValue(gk) * v);
                    ds.add(2, i, b1.getValue(gk) * v);
                }
                if (i < mesh.getElementsY()) {
                    val y = (gl + i) * dy;
                    val v = wk * wl * b3.getValue(gl) * problem.valueAt(x, y);
                    ds.add(0, i, b3.getValue(gk) * v);
                    ds.add(1, i, b2.getValue(gk) * v);
                    ds.add(2, i, b1.getValue(gk) * v);
                }
            }
        }
    }
Is there anything else that could be sped up in this portion of the code? Assume the external calls are as efficient as they can be; this is only about these heavily nested loops. I'm asking about compiler-specific phenomena.
I have already applied the suggestion below and it helped (see pastebin), but it seems there is a hotspot within this hotspot, which takes a substantial share of the computation time of the caller loop (above).
Namely, problem.valueAt(x, y) translates into:
    private double internalValueAt(double x, double y) {
        val ielemx = (long) (x / mesh.getDx());
        val ielemy = (long) (y / mesh.getDy());
        val localx = x - mesh.getDx() * ielemx;
        val localy = y - mesh.getDy() * ielemy;
        val sp1x = b1.getValue(localx);
        val sp1y = b1.getValue(localy);
        val sp2x = b2.getValue(localx);
        val sp2y = b2.getValue(localy);
        val sp3x = b3.getValue(localx);
        val sp3y = b3.getValue(localy);
        return coef.doubleValue(0, ielemy) * sp1x * sp1y +
                coef.doubleValue(0, ielemy + 1) * sp1x * sp2y +
                coef.doubleValue(0, ielemy + 2) * sp1x * sp3y +
                coef.doubleValue(1, ielemy) * sp2x * sp1y +
                coef.doubleValue(1, ielemy + 1) * sp2x * sp2y +
                coef.doubleValue(1, ielemy + 2) * sp2x * sp3y +
                coef.doubleValue(2, ielemy) * sp3x * sp1y +
                coef.doubleValue(2, ielemy + 1) * sp3x * sp2y +
                coef.doubleValue(2, ielemy + 2) * sp3x * sp3y;
    }
b1, b2 and b3 are the same functions as in the main loop (B-splines). Perhaps I could pass them in from the outside, but is it worth it?
It seems to me that not much can be done about it, but perhaps you can spot something worth doing? Let me know if there is any precondition for doing something.
Upvotes: 0
Views: 86
Reputation: 46392
Unrolling 100 times is beyond crazy, as your code is already pretty long. I'd bet there's nothing to gain there. On the contrary: I once achieved a huge speedup (nearly a factor of two) just by clicking "extract method".
You may hoist the wk and gk initialization out of the innermost loop.
You may be able to avoid some multiplications by extracting wk * wl, but I'd bet it's jitted anyway.
You may avoid the conditions by splitting out the first and last iterations.
You may cache some values (e.g., b3.getValue(gk)) in local variables.
That's all just guessing. Anyway, I'd start by extracting the body of the outermost loop into a method and measuring whether there's any speed difference. I'd hope it won't get worse, and then you can extract the first and last iterations easily.
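To make the hoisting and caching suggestions concrete, here is a minimal sketch. It is not your code: the 2-point Gauss rule, the quadratic B-spline formulas, and the names HoistingSketch, naive and hoisted are assumptions made up for illustration, f stands in for problem.valueAt, and only the third branch (i < mesh.getElementsY()) is modelled. Whether it is actually faster has to be measured.

```java
import java.util.function.DoubleBinaryOperator;

final class HoistingSketch {
    // Hypothetical 2-point Gauss-Legendre rule on [0, 1]; the question uses 10 points.
    static final double[] GAUSS_POINTS = {
        0.5 - 0.5 / Math.sqrt(3.0), 0.5 + 0.5 / Math.sqrt(3.0)
    };
    static final double[] GAUSS_POINT_WEIGHTS = {0.5, 0.5};
    static final int GAUSS_POINT_COUNT = GAUSS_POINTS.length;

    // Assumed quadratic B-spline pieces on [0, 1], standing in for b1, b2, b3.
    static double b1(double t) { return 0.5 * t * t; }
    static double b2(double t) { return -t * t + t + 0.5; }
    static double b3(double t) { return 0.5 * (1.0 - t) * (1.0 - t); }

    // Basis values at the Gauss points, evaluated once instead of in every iteration.
    static final double[] B1 = new double[GAUSS_POINT_COUNT];
    static final double[] B2 = new double[GAUSS_POINT_COUNT];
    static final double[] B3 = new double[GAUSS_POINT_COUNT];
    static {
        for (int k = 0; k < GAUSS_POINT_COUNT; k++) {
            B1[k] = b1(GAUSS_POINTS[k]);
            B2[k] = b2(GAUSS_POINTS[k]);
            B3[k] = b3(GAUSS_POINTS[k]);
        }
    }

    /** Shape of the original loop body: everything is looked up in the innermost loop. */
    static double[] naive(DoubleBinaryOperator f, double dx, double dy, double left, int i) {
        double[] rows = new double[3];
        for (int k = 0; k < GAUSS_POINT_COUNT; k++) {
            double x = GAUSS_POINTS[k] * dx + left;
            for (int l = 0; l < GAUSS_POINT_COUNT; l++) {
                double wk = GAUSS_POINT_WEIGHTS[k];
                double wl = GAUSS_POINT_WEIGHTS[l];
                double gl = GAUSS_POINTS[l];
                double gk = GAUSS_POINTS[k];
                double y = (gl + i) * dy;
                double v = wk * wl * b3(gl) * f.applyAsDouble(x, y);
                rows[0] += b3(gk) * v;
                rows[1] += b2(gk) * v;
                rows[2] += b1(gk) * v;
            }
        }
        return rows;
    }

    /** Same arithmetic, with k-dependent values hoisted and basis values read from tables. */
    static double[] hoisted(DoubleBinaryOperator f, double dx, double dy, double left, int i) {
        double[] rows = new double[3];
        for (int k = 0; k < GAUSS_POINT_COUNT; k++) {
            double x = GAUSS_POINTS[k] * dx + left;
            double wk = GAUSS_POINT_WEIGHTS[k];          // depends only on k
            double b1k = B1[k], b2k = B2[k], b3k = B3[k]; // cached basis values
            for (int l = 0; l < GAUSS_POINT_COUNT; l++) {
                double y = (GAUSS_POINTS[l] + i) * dy;
                double v = wk * GAUSS_POINT_WEIGHTS[l] * B3[l] * f.applyAsDouble(x, y);
                rows[0] += b3k * v;
                rows[1] += b2k * v;
                rows[2] += b1k * v;
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        DoubleBinaryOperator f = (x, y) -> x * y + 1.0; // stand-in for problem.valueAt
        double[] a = naive(f, 0.1, 0.2, 0.3, 2);
        double[] b = hoisted(f, 0.1, 0.2, 0.3, 2);
        for (int r = 0; r < 3; r++) {
            System.out.println("row " + r + ": " + a[r] + " vs " + b[r]);
        }
    }
}
```

The tables only help if the basis functions are always evaluated at the same Gauss points, which is the case in the main loop as posted.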
For the middle iterations, there's
    val y = (gl + (i - 2)) * dy;
    ...
    val y = (gl + (i - 1)) * dy;
    ...
    val y = (gl + i) * dy;
which could be computed incrementally. The JIT would (most probably) do it for integers, but not for floating point, as the result is not exactly the same. With your val, I don't know which is the case.
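A small sketch of what "incrementally" means here, assuming gl and dy are doubles: one multiplication for the first y, then an addition of dy for each following one. The class and method names are made up. The two variants agree up to rounding but are not guaranteed to be bit-for-bit identical, which is why the JIT will not do this rewrite for you.

```java
final class IncrementalY {
    /** y for offsets i-2, i-1, i, each computed with its own multiplication. */
    static double[] direct(double gl, int i, double dy) {
        return new double[] {
            (gl + (i - 2)) * dy,
            (gl + (i - 1)) * dy,
            (gl + i) * dy
        };
    }

    /** Same three values, using one multiplication and two additions. */
    static double[] incremental(double gl, int i, double dy) {
        double y0 = (gl + (i - 2)) * dy;
        double y1 = y0 + dy; // may differ from the direct value in the last bits
        double y2 = y1 + dy;
        return new double[] {y0, y1, y2};
    }

    public static void main(String[] args) {
        double[] a = direct(0.3, 5, 0.1);
        double[] b = incremental(0.3, 5, 0.1);
        for (int j = 0; j < 3; j++) {
            System.out.println(a[j] + " vs " + b[j] + " (difference " + (a[j] - b[j]) + ")");
        }
    }
}
```

Since internalValueAt truncates y / mesh.getDy() to find the element, a last-bit difference should only matter if y landed exactly on an element boundary, which interior Gauss points presumably never do. That is worth checking before relying on it.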
You may then aggregate the three calls to ds.add(0, i, ...) into a single ds.add(0, i, b3.getValue(gk) * (v + v' + v'')), with the three v values taken from the three if blocks. The same applies to rows 1 and 2.
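As a sketch of that aggregation for a middle iteration (where all three conditions hold): Rows below is a made-up stand-in for your ds, and v1, v2, v3 are the three v values. Summing first changes the order of the floating-point additions, so the result can differ in the last bits.

```java
final class AggregatedAdds {
    /** Hypothetical stand-in for ds: three rows that accumulate added values. */
    static final class Rows {
        final double[][] data;
        int addCalls = 0;

        Rows(int columns) { data = new double[3][columns]; }

        void add(int row, int col, double value) {
            data[row][col] += value;
            addCalls++;
        }
    }

    /** Original shape: three add calls per branch, nine in total. */
    static void separate(Rows ds, int i, double b1k, double b2k, double b3k,
                         double v1, double v2, double v3) {
        for (double v : new double[] {v1, v2, v3}) {
            ds.add(0, i, b3k * v);
            ds.add(1, i, b2k * v);
            ds.add(2, i, b1k * v);
        }
    }

    /** Aggregated shape: sum the three v values first, then three add calls. */
    static void aggregated(Rows ds, int i, double b1k, double b2k, double b3k,
                           double v1, double v2, double v3) {
        double v = v1 + v2 + v3;
        ds.add(0, i, b3k * v);
        ds.add(1, i, b2k * v);
        ds.add(2, i, b1k * v);
    }

    public static void main(String[] args) {
        Rows a = new Rows(4);
        Rows b = new Rows(4);
        separate(a, 2, 0.125, 0.75, 0.125, 1.5, 2.5, 3.5);
        aggregated(b, 2, 0.125, 0.75, 0.125, 1.5, 2.5, 3.5);
        System.out.println(a.addCalls + " calls vs " + b.addCalls + " calls");
        System.out.println(a.data[0][2] + " vs " + b.data[0][2]);
    }
}
```

For the first and last iterations, where some branch is skipped, the corresponding v would simply be left out of the sum.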
Upvotes: 1