msi_gerva

Reputation: 2078

Fast conversion of Java array to NumPy array (Py4J)

There are some nice examples of how to convert a NumPy array to a Java array, but not of the reverse: how to convert data from a Java object back into a NumPy array. I have a Python script like this:

    import numpy as np
    from py4j.java_gateway import JavaGateway

    gateway = JavaGateway()              # connect to the JVM
    my_java = gateway.jvm.JavaClass()    # my Java object
    ....
    int_array = my_java.doSomething(int_array)  # do something

    # copy element by element; every int_array[jj][ii] access goes through the gateway
    my_numpy = np.zeros((size_y, size_x))
    for jj in range(size_y):
        for ii in range(size_x):
            my_numpy[jj, ii] = int_array[jj][ii]

my_numpy is the NumPy array and int_array is a Java array of integers (an int[][] array). The Java arrays are initialized in the Python script as:

    int_class=gateway.jvm.int       # make int class
    double_class=gateway.jvm.double # make double class

    int_array = gateway.new_array(int_class,size_y,size_x)
    double_array = gateway.new_array(double_class,size_y,size_x)

Although this works, it is very slow: for a ~1000x1000 array, the conversion took more than 5 minutes.

Is there a way to do this in a reasonable time?

If I try:

    test=np.array(int_array)

I get:

    ValueError: invalid __array_struct__

Upvotes: 3

Views: 5535

Answers (2)

Erlend Magnus Viggen

Reputation: 383

I had a similar problem and found a solution that is around 220 times faster for the case I tested: for transferring a 1628x120 array of short integers from Java to NumPy, the runtime was reduced from 11 seconds to 0.05 seconds. Thanks to this related StackOverflow question, I started looking into py4j byte arrays, and it turns out that py4j efficiently converts Java byte arrays to Python bytes objects and vice versa (passing by value, not by reference). It's a fairly roundabout way of doing things, but not too difficult.

Thus, if you want to transfer an integer array intArray with dimensions iMax x jMax (for the sake of the example, I assume that these are all stored as instance variables in your object), you can first write a Java function to convert it to a byte[] like so:

    // Needed at the top of the Java source file:
    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    public byte[] getByteArray() {
        // Set up a ByteBuffer called intBuffer
        ByteBuffer intBuffer = ByteBuffer.allocate(4*iMax*jMax); // 4 bytes in an int
        intBuffer.order(ByteOrder.LITTLE_ENDIAN); // Java's default is big-endian

        // Copy ints from intArray into intBuffer as bytes, row by row
        for (int i = 0; i < iMax; i++) {
            for (int j = 0; j < jMax; j++) {
                intBuffer.putInt(intArray[i][j]);
            }
        }

        // Return the byte array that backs the ByteBuffer
        return intBuffer.array();
    }

Then, you can write Python 3 code to receive the byte array and convert it to a NumPy array of the correct shape:

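    import numpy as np

    byteArray = gateway.entry_point.getByteArray()    # arrives as a Python bytes object
    intArray = np.frombuffer(byteArray, dtype='<i4')  # '<i4' = little-endian int32, as written on the Java side
    intArray = intArray.reshape((iMax, jMax))         # iMax, jMax must match the Java dimensions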
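The decoding step can be tried without a running JVM. The following is only a minimal sketch: the helper names `bytes_to_int_array` and `int_array_to_bytes` are made up for illustration, and the payload is produced in Python as a stand-in for what `getByteArray()` would return, assuming little-endian 32-bit ints in row-major order as in the Java code above.

```python
import numpy as np


def int_array_to_bytes(arr):
    """Encode a 2-D integer array as little-endian int32 bytes, row by row."""
    return np.ascontiguousarray(arr, dtype="<i4").tobytes()


def bytes_to_int_array(payload, n_rows, n_cols):
    """Decode little-endian int32 bytes into a writable (n_rows, n_cols) array."""
    expected = 4 * n_rows * n_cols  # 4 bytes per int
    if len(payload) != expected:
        raise ValueError(f"expected {expected} bytes, got {len(payload)}")
    flat = np.frombuffer(payload, dtype="<i4")  # read-only view of the bytes
    return flat.reshape((n_rows, n_cols)).copy()  # copy so the result is writable


# Stand-in for the bytes object that py4j would hand back from the JVM
original = np.arange(6).reshape((2, 3))
payload = int_array_to_bytes(original)
decoded = bytes_to_int_array(payload, 2, 3)
```

The explicit `'<i4'` dtype keeps the byte order tied to what the Java side wrote rather than to the machine running Python. The `.copy()` is there because `np.frombuffer` on a bytes object gives a read-only array; it can be dropped if the data is never modified.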

Upvotes: 4

user3035873

Reputation: 46

I've had a similar issue, just trying to plot spectral vectors (Java arrays) that I got from the Java side via py4j. Here, the conversion from a Java array to a Python list is done with the list() function. This might give some clues as to how to use it to fill NumPy arrays ...

    import matplotlib.pyplot as mp

    vectors = space.getVectorsAsArray()   # Java array (MxN)
    wvl = space.getAverageWavelengths()   # Java array (N)

    wavelengths = list(wvl)

    # successive plot() calls draw on the same axes, so no hold call is needed
    for dataset in vectors:
        mp.plot(wavelengths, list(dataset))

Whether this is faster than the nested for loops you used I cannot say, but it also does the trick:

    import numpy as np

    x = np.array(wavelengths)
    # convert each Java row to a Python list before handing it to NumPy
    v = np.array([list(row) for row in vectors])

    mp.plot(x, np.rot90(v))

Upvotes: 2
