Christos Kalonakis
Christos Kalonakis

Reputation: 1

Java Vector API (Mandelbrot calculation)

I was experimenting with benchmarking a serial version of the Mandelbrot calculation against a version using the Vector API.

Here's the code:

public class Main {

    /**
     * Vector shape used by {@link #renderVector}. It is kept in a {@code static final} field so
     * the JIT can treat the species as a constant. A species held in a local variable can keep
     * the runtime compiler from fully optimizing the vector operations.
     */
    private static final VectorSpecies<Double> DOUBLE_SPECIES = DoubleVector.SPECIES_PREFERRED;

    /**
     * Linear mapping from pixel coordinates to points c of the complex plane.
     *
     * @param width  image width in pixels; also used to split a flat pixel index into (x, y)
     * @param reLeft real part of c at pixel column 0
     * @param imTop  imaginary part of c at pixel row 0
     * @param reStep increase of the real part per column
     * @param imStep decrease of the imaginary part per row
     */
    private record Viewport(int width, double reLeft, double imTop, double reStep, double imStep) {
        double re(int x) {
            return reLeft + reStep * x;
        }

        double im(int y) {
            return imTop - imStep * y;
        }
    }

    /**
     * Renders the same Mandelbrot view twice with identical parameters.
     *
     * <p>The first render uses plain scalar code and is written to {@code out.png}. The second
     * uses the Vector API and is written to {@code out2.png}. The wall-clock time of each render
     * is printed in nanoseconds.
     *
     * <p>NOTE: each render runs exactly once, so the printed times include interpreter and JIT
     * warm-up. Use a benchmark harness such as JMH for a meaningful comparison.
     */
    public static void main(String[] args) {
        int width = 1920;
        int height = 1080;
        int maxIterations = 5000;

        final int[] defaultPalette = {-16774636, -16510682, -16246472, -15982517, -15718307, -15454353, -15190398, -14860652, -14596442, -14332487, -14068277, -13804323, -13474576, -13673506, -13806644, -14005318, -14138456, -14337130, -14470267, -14668941, -14802079, -15000753, -15133891, -15332565, -15465702, -14350567, -13169639, -12054248, -10873320, -9758185, -8642793, -7461866, -6280938, -5165803, -3984875, -2869484, -1688556, -2869485, -3984622, -5099759, -6214896, -7330033, -8445170, -9560563, -10675700, -11790837, -12905974, -14021111, -15136247, -14018808, -12901369, -11783674, -10666234, -9548795, -8431356, -7313661, -6196221, -5078526, -3961087, -2843648, -1660416, -2842880, -3959807, -5142270, -6259197, -7441660, -8558332, -9740795, -10857722, -12040185, -13157112, -14339575, -15456246, -15586039, -15715832, -15780089, -15909882, -16039675, -16103931, -16233724, -16363517, -16427774, -16557567, -16687360, -16751616, -16753664, -16755455, -16691966, -16693757, -16630268, -16632060, -16634107, -16570362, -16572409, -16508664, -16510711, -16446966, -15331573, -14216179, -13100785, -11985392, -10869998, -9754604, -8639467, -7523817, -6408423, -5293030, -4177636, -2996706, -3653601, -4310495, -4967389, -5624284, -6281178, -6937816, -7594711, -8251605, -8908499, -9565394, -10287824, -10878926, -10418377, -9892035, -9431485, -8905143, -8444593, -7918507, -7457702, -6931360, -6470810, -5944468, -5483918, -4957576, -5942159, -6926742, -7911324, -8895907, -9880490, -10864816, -11849399, -12833982, -13818564, -14803147, -15787730, -16772056, -16639947, -16442302, -16310192, -16112547, -15980438, -15848073, -15650427, -15452782, -15320672, -15123027, -14990918, -14793016, -14990919, -15123286, -15321189, -15453556, -15651459, -15783826, -15981729, -16114096, -16311999, -16444366, -16642525};

        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        BufferedImage image2 = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);

        double size = 8.94069671630859375e-8;
        double xCenter = 0.32199330953881166728080768;
        double yCenter = 0.48903524968773126602172851;

        // `size` is the extent of the shorter image side. The longer side is stretched by the
        // aspect ratio, so both axes use the same step (size / imageSize) and pixels stay square.
        int imageSize = Math.min(width, height);
        double coefX = width == imageSize ? 0.5 : (1 + (width - (double) height) / height) * 0.5;
        double coefY = height == imageSize ? 0.5 : (1 + (height - (double) width) / width) * 0.5;
        double step = size / imageSize;
        Viewport view = new Viewport(width, xCenter - size * coefX, yCenter + size * coefY, step, step);

        // TYPE_INT_ARGB images are backed by a DataBufferInt, so we can write pixels directly.
        int[] rgbs = ((DataBufferInt) image.getRaster().getDataBuffer()).getData();
        int[] rgbs2 = ((DataBufferInt) image2.getRaster().getDataBuffer()).getData();

        long time = System.nanoTime();
        renderScalar(rgbs, view, maxIterations, defaultPalette);
        System.out.println(System.nanoTime() - time);

        time = System.nanoTime();
        renderVector(rgbs2, view, maxIterations, defaultPalette);
        System.out.println(System.nanoTime() - time);

        writePng(image, "out.png");
        writePng(image2, "out2.png");
    }

    /**
     * Scalar reference renderer that computes every pixel one at a time.
     *
     * @param pixels        row-major pixel array with {@code view.width()} pixels per row
     * @param view          pixel-to-complex-plane mapping
     * @param maxIterations iteration cap; pixels reaching it are painted black
     * @param palette       colors indexed by escape iteration (modulo its length)
     */
    private static void renderScalar(int[] pixels, Viewport view, int maxIterations, int[] palette) {
        int width = view.width();
        int height = pixels.length / width;
        for (int y = 0; y < height; y++) {
            double cim = view.im(y);
            for (int x = 0; x < width; x++) {
                int iterations = escapeIterations(view.re(x), cim, maxIterations);
                setPixel(iterations, y * width + x, maxIterations, pixels, palette);
            }
        }
    }

    /**
     * Vector API renderer that processes {@code DOUBLE_SPECIES.length()} consecutive pixels per
     * step. Pixels past the last full vector are finished with the scalar kernel.
     *
     * <p>It takes the same parameters as {@link #renderScalar} and produces identical pixels:
     * the per-lane arithmetic uses the same operations in the same order.
     */
    private static void renderVector(int[] pixels, Viewport view, int maxIterations, int[] palette) {
        int laneCount = DOUBLE_SPECIES.length();
        int width = view.width();
        int bound = DOUBLE_SPECIES.loopBound(pixels.length);

        double[] re = new double[laneCount];
        double[] im = new double[laneCount];
        double[] iterationBuffer = new double[laneCount];

        for (int p = 0; p < bound; p += laneCount) {
            for (int i = 0; i < laneCount; i++) {
                re[i] = view.re((p + i) % width);
                im[i] = view.im((p + i) / width);
            }

            DoubleVector cre = DoubleVector.fromArray(DOUBLE_SPECIES, re, 0);
            DoubleVector cim = DoubleVector.fromArray(DOUBLE_SPECIES, im, 0);
            DoubleVector zre = DoubleVector.zero(DOUBLE_SPECIES);
            DoubleVector zim = DoubleVector.zero(DOUBLE_SPECIES);
            // Per-lane iteration counters. Doubles avoid mixing vector shapes, and values up to
            // maxIterations are represented exactly.
            DoubleVector counts = DoubleVector.zero(DOUBLE_SPECIES);
            VectorMask<Double> active = DOUBLE_SPECIES.maskAll(true);

            // The loop only stops when no lane is active, so every lane group runs as long as
            // its slowest pixel.
            for (int iteration = 0; iteration < maxIterations; iteration++) {
                DoubleVector zreSqr = zre.mul(zre);
                DoubleVector zimSqr = zim.mul(zim);
                // Accumulate the mask so a lane that has escaped is never counted again,
                // whatever its (no longer meaningful) z does afterwards.
                active = active.and(zreSqr.add(zimSqr).lt(4.0));
                if (!active.anyTrue()) {
                    break;
                }
                counts = counts.add(1.0, active);

                // Escaped lanes are still updated: the whole vector is computed either way.
                // Masked forms like a.sub(b, m) keep a's lane value where m is unset, so they
                // would not freeze z anyway.
                DoubleVector nextZre = zreSqr.sub(zimSqr).add(cre);
                zim = zre.mul(zim).mul(2).add(cim);
                zre = nextZre;
            }

            counts.intoArray(iterationBuffer, 0);
            for (int i = 0; i < laneCount; i++) {
                setPixel((int) iterationBuffer[i], p + i, maxIterations, pixels, palette);
            }
        }

        // Leftover pixels that do not fill a whole vector.
        for (int p = bound; p < pixels.length; p++) {
            int iterations = escapeIterations(view.re(p % width), view.im(p / width), maxIterations);
            setPixel(iterations, p, maxIterations, pixels, palette);
        }
    }

    /**
     * Iterates z -> z^2 + c starting from z = 0.
     *
     * @return the index of the first iteration at which |z|^2 is at least 4, or
     *     {@code maxIterations} if that never happens
     */
    private static int escapeIterations(double cre, double cim, int maxIterations) {
        double zre = 0;
        double zim = 0;
        int iteration;
        for (iteration = 0; iteration < maxIterations; iteration++) {
            double zreSqr = zre * zre;
            double zimSqr = zim * zim;
            if (zreSqr + zimSqr >= 4) {
                break;
            }
            double nextZre = zreSqr - zimSqr + cre;
            zim = zre * zim * 2 + cim;
            zre = nextZre;
        }
        return iteration;
    }

    /**
     * Colors one pixel.
     *
     * <p>The pixel is black when {@code iteration == max_iterations} (the point never escaped).
     * Otherwise it gets the palette entry at {@code iteration % default_palette.length}.
     */
    private static void setPixel(int iteration, int index, int max_iterations, int[] rgbs, int[] default_palette) {
        if (iteration == max_iterations) {
            rgbs[index] = Color.BLACK.getRGB();
        } else {
            rgbs[index] = default_palette[iteration % default_palette.length];
        }
    }

    /**
     * Writes {@code image} as a PNG file. Failures are reported on stderr instead of being
     * silently ignored.
     */
    private static void writePng(BufferedImage image, String fileName) {
        try {
            if (!ImageIO.write(image, "png", new File(fileName))) {
                System.err.println("No PNG writer available for " + fileName);
            }
        } catch (java.io.IOException ex) {
            System.err.println("Failed to write " + fileName + ": " + ex);
        }
    }
}

For some reason, the serial code outperforms the Vector API version. Keep in mind that the Vector API version does not yet handle the final (leftover) pixels correctly; that part is still a TODO.

Am I doing something wrong here?

The times I get on my laptop (in nanoseconds) are 2718696300 for the serial version and 3202282700 for the Vector API version.

Also, ideally I would like to stop updating the zre and zim values once the corresponding lane has escaped (i.e. has met the bailout criterion).

var temp_reals = zreals_sqr.sub(zimaginarys_sqr, maskNorm).add(creals, maskNorm); zimaginarys = zreals.mul(zimaginarys, maskNorm).mul(2, maskNorm).add(cimaginarys, maskNorm); zreals = temp_reals;

I used something like that, but this was even slower.

Serial Mandelbrot calculation is faster than the Vector API equivalent.

Upvotes: 0

Views: 58

Answers (0)

Related Questions