kiyo

Reputation: 1999

Getting the dtype of a result array in numpy

I want to preallocate memory for the output of an array operation, so I need to know what dtype to give it. The function below does what I want, but it is terribly ugly.

import numpy as np

def array_operation(arr1, arr2):
    out_shape = arr1.shape
    # Get the dtype of the output; these are the lines I want to replace.
    index1 = ([0],) * arr1.ndim
    index2 = ([0],) * arr2.ndim
    tmp_arr = arr1[index1] * arr2[index2]
    out_dtype = tmp_arr.dtype
    # All so I can do the following.
    out_arr = np.empty(out_shape, out_dtype)
    # (the operation then writes its result into out_arr)
    return out_arr

Computing a throwaway product just to read its dtype feels clumsy. Does numpy have a function that returns the result dtype directly?
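For reference, a short illustration of why the indexing trick in the question works: a tuple of one-element lists is an advanced (fancy) index, so it returns a 1-element array that keeps the input's dtype, and the product of those tiny arrays has the same dtype as the full product. The shapes and dtypes below are arbitrary choices for the demo.

```python
import numpy as np

arr1 = np.ones((2, 3), dtype=np.int32)
arr2 = np.ones((2, 3), dtype=np.float32)

# ([0],) * ndim builds ([0], [0]) for a 2-D array: one list per axis,
# which NumPy treats as advanced indexing, returning a copy of shape (1,).
index1 = ([0],) * arr1.ndim
index2 = ([0],) * arr2.ndim
sample1 = arr1[index1]
sample2 = arr2[index2]

# The 1-element product has the same dtype the full product would have.
sample_dtype = (sample1 * sample2).dtype
full_dtype = (arr1 * arr2).dtype
print(sample1.shape, sample_dtype, full_dtype)
```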

Upvotes: 6

Views: 6279

Answers (2)

unutbu

Reputation: 879471

For those using numpy version < 1.6, you could use:

def result_type(arr1, arr2):
    x1 = arr1.flat[0]
    x2 = arr2.flat[0]
    return (x1 * x2).dtype

def array_operation(arr1, arr2):
    return np.empty(arr1.shape, result_type(arr1, arr2))

This isn't very different from the code you posted, though I think arr1.flat[0] is a slight improvement over index1 = ([0],) * arr1.ndim; arr1[index1].
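A quick check of the flat[0] helper, compared against np.result_type (available in numpy 1.6 and later). One caveat: flat[0] needs at least one element, so the helper raises IndexError on empty arrays. The name sample_result_type is hypothetical, chosen only to avoid shadowing np.result_type.

```python
import numpy as np

def sample_result_type(arr1, arr2):
    # Multiply the first element of each array (as numpy scalars)
    # and read the dtype of the product.
    x1 = arr1.flat[0]
    x2 = arr2.flat[0]
    return (x1 * x2).dtype

a = np.arange(6, dtype=np.int32).reshape(2, 3)
b = np.linspace(0.5, 3.0, 6, dtype=np.float32).reshape(2, 3)
print(sample_result_type(a, b), np.result_type(a, b))

# An empty array has no first element, so the helper fails.
raised = False
try:
    sample_result_type(np.empty((0, 3), dtype=np.int32), b)
except IndexError as exc:
    raised = True
    print("IndexError:", exc)
```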

For numpy version >= 1.6, use np.result_type, as in Mike Graham's answer.

Upvotes: 1

Mike Graham

Reputation: 76683

You are looking for numpy.result_type.
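A minimal sketch of the question's function rewritten with np.result_type, assuming (as the original does) that the operation is elementwise multiplication and that arr2 broadcasts to arr1.shape. np.result_type applies numpy's promotion rules to its arguments without doing any arithmetic. It reports the promoted input type, which matches arithmetic ufuncs such as multiply but not every operation: true division of two integer arrays, for example, returns float64.

```python
import numpy as np

def array_operation(arr1, arr2):
    # Promote the input dtypes without computing anything.
    out_dtype = np.result_type(arr1, arr2)
    # Assumes arr2 broadcasts to arr1's shape, as in the question.
    out_arr = np.empty(arr1.shape, out_dtype)
    # Write the product straight into the preallocated buffer.
    np.multiply(arr1, arr2, out=out_arr)
    return out_arr

a = np.arange(6, dtype=np.int32).reshape(2, 3)
b = np.full(3, 0.5, dtype=np.float32)
result = array_operation(a, b)
print(result.dtype)  # float64
```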

(As an aside, did you know that you can index any multidimensional array as if it were 1-D? Instead of writing x[0, 0, 0, 0, 0], you can use x.flat[0].)
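A small demonstration of the flat attribute: x.flat indexes any array as if it were 1-D, in row-major (C) order, so x.flat[0] is the same element as x[0, 0, 0, 0, 0]. np.unravel_index converts a flat position back into the multidimensional index.

```python
import numpy as np

# A 5-D array whose values equal their flat (row-major) position.
x = np.arange(32).reshape(2, 2, 2, 2, 2)

# flat[0] is the first element, with no need to spell out five indices.
print(x.flat[0], x[0, 0, 0, 0, 0])

# Flat position 5 corresponds to the 5-D index (0, 0, 1, 0, 1).
idx = np.unravel_index(5, x.shape)
print(idx, x.flat[5], x[idx])
```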

Upvotes: 8
