tda

Interpolate a 3D array in Python

I have a 3D NumPy array that looks like this:

import numpy as np

arr = np.empty((4,4,5))
arr[:] = np.nan
arr[0] = 1
arr[3] = 4

arr
>>> [[[ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]]

     [[ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]]

     [[ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]]

     [[ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]]]

I would like to interpolate along axis=0 so that I get the following:

>>> [[[ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]]

     [[ 2.  2.  2.  2.  2.]
      [ 2.  2.  2.  2.  2.]
      [ 2.  2.  2.  2.  2.]
      [ 2.  2.  2.  2.  2.]]

     [[ 3.  3.  3.  3.  3.]
      [ 3.  3.  3.  3.  3.]
      [ 3.  3.  3.  3.  3.]
      [ 3.  3.  3.  3.  3.]]

     [[ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]]]
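The desired result above is plain linear interpolation along axis 0. For this particular layout, where only the first and last slices are known, it reduces to a weighted blend of those two slices. A minimal sketch, assuming exactly that layout (known values only at index 0 and at the last index); the names t and expected are illustrative only:

```python
import numpy as np

arr = np.empty((4, 4, 5))
arr[:] = np.nan
arr[0] = 1
arr[3] = 4

# Fractional position of each slice between the first (t=0) and the last (t=1) slice
t = np.linspace(0.0, 1.0, arr.shape[0])

# Blend the two known slices; t[:, None, None] broadcasts over axes 1 and 2
expected = (1 - t[:, None, None]) * arr[0] + t[:, None, None] * arr[-1]

print(expected[:, 0, 0])
```

Every element of slice 1 comes out as 2 and every element of slice 2 as 3, matching the desired output. This shortcut does not handle NaNs scattered elsewhere, which is why the general approaches below interpolate each vector separately.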

I've been looking at the SciPy module, and there seem to be methods for doing this on 1D and 2D arrays, but not on 3D arrays like mine, though I may have missed something.

Upvotes: 1

Answers (2)

xdze2

A solution using np.apply_along_axis, which calls a 1D interpolation function on every vector along the chosen axis:

import numpy as np

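def pad(data):
    # Fill the NaNs of a 1D vector by linear interpolation between its finite values
    good = np.isfinite(data)
    interpolated = np.interp(np.arange(data.shape[0]),
                             np.flatnonzero(good),
                             data[good])
    return interpolated


arr = np.arange(6, dtype=float).reshape((3,2))
arr[1, 1] = np.nan
print(arr)

# Apply pad to each column, i.e. each 1D vector along axis 0
new = np.apply_along_axis(pad, 0, arr)
print(arr)  # the input is unchanged, because pad returns a new array
print(new)

Output: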

[[ 0.  1.]
 [ 2. nan]
 [ 4.  5.]]

[[ 0.  1.]
 [ 2. nan]
 [ 4.  5.]]

[[0. 1.]
 [2. 3.]
 [4. 5.]]
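The same pad function applies directly to the 3D array from the question, because apply_along_axis hands it one 1D vector arr[:, j, k] at a time. A sketch reusing the question's setup; it assumes every vector has at least one finite value, since np.interp raises a ValueError when its array of sample points is empty:

```python
import numpy as np

def pad(data):
    # data is one 1D vector taken along axis 0; fill its NaNs by linear interpolation
    good = np.isfinite(data)
    return np.interp(np.arange(data.shape[0]), np.flatnonzero(good), data[good])

arr = np.empty((4, 4, 5))
arr[:] = np.nan
arr[0] = 1
arr[3] = 4

# pad is called once for every (j, k) vector arr[:, j, k]
new = np.apply_along_axis(pad, 0, arr)
print(new[:, 0, 0])
```

Slices 1 and 2 of new are filled with 2 and 3, and arr itself still contains its NaNs.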

[edit] The first proposed solution, a modified version of the code from this answer:

import numpy as np
from scipy import interpolate

A = np.empty((4,4,5))
A[:] = np.nan
A[0] = 1
A[3] = 4

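indexes = np.arange(A.shape[0])
# True for each slice along axis 0 that contains no NaN at all
good = np.isfinite(A).all(axis=(1, 2))

# Build an interpolator along axis 0 from the complete slices only;
# with bounds_error=False, points outside the known range become NaN instead of raising
f = interpolate.interp1d(indexes[good], A[good],
                         bounds_error=False,
                         axis=0)

B = f(indexes)
print(B)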

gives:

[[[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]

 [[2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2.]]

 [[3. 3. 3. 3. 3.]
  [3. 3. 3. 3. 3.]
  [3. 3. 3. 3. 3.]
  [3. 3. 3. 3. 3.]]

 [[4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]]]

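It works only when the NaNs fill whole slices along axis 0. A slice that contains even one isolated NaN is treated as entirely missing: its finite values are not used, and the whole slice is replaced by interpolated values.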
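To see the limitation concretely, the sketch below uses hypothetical data (slice 1 filled with 10, then given a single NaN) and compares the interp1d approach with the per-vector apply_along_axis approach from the top of this answer:

```python
import numpy as np
from scipy import interpolate

A = np.empty((4, 4, 5))
A[:] = np.nan
A[0] = 1
A[3] = 4
A[1] = 10            # hypothetical: slice 1 is known...
A[1, 0, 0] = np.nan  # ...except for one isolated NaN

# Slice-based approach: slice 1 is flagged as bad because of the single NaN
indexes = np.arange(A.shape[0])
good = np.isfinite(A).all(axis=(1, 2))
f = interpolate.interp1d(indexes[good], A[good], bounds_error=False, axis=0)
B = f(indexes)

# Per-vector approach: each A[:, j, k] keeps whatever finite values it has
def pad(data):
    ok = np.isfinite(data)
    return np.interp(np.arange(data.shape[0]), np.flatnonzero(ok), data[ok])

C = np.apply_along_axis(pad, 0, A)

print(good)
print(B[1, 1, 1], C[1, 1, 1])
```

In B the known value 10 at B[1, 1, 1] is discarded and replaced by 2, while C keeps it and only fills the actual NaNs.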

Upvotes: 2

tda

Based on xdze2's comment and the previous answer, I came up with this:

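import numpy as np

def pad(data):
    # data is a single 1D vector taken along axis 0
    # np.apply_along_axis may pass a view, so copy to avoid modifying arr in place
    data = data.copy()
    bad_indexes = np.isnan(data)
    good_indexes = np.logical_not(bad_indexes)
    good_data = data[good_indexes]
    # Linearly interpolate the NaN positions from the known positions
    interpolated = np.interp(bad_indexes.nonzero()[0],
                             good_indexes.nonzero()[0],
                             good_data)
    data[bad_indexes] = interpolated
    return data

arr = np.empty((4,4,5))
arr[:] = np.nan

arr[0] = 25
arr[3] = 32.5

# Apply pad to every 1D vector along axis 0
new = np.apply_along_axis(pad, 0, arr)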

The pad function performs the interpolation on a single 1D vector, and np.apply_along_axis calls it on every 1D vector along axis 0 of the 3D array (here 20 vectors of length 4, one per position in the 4x5 slices).
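To make the role of np.apply_along_axis concrete, the sketch below does the same work with an explicit loop over axes 1 and 2, calling pad on each vector arr[:, j, k]. It only illustrates what apply_along_axis does conceptually; the helper name fill_along_axis0 is hypothetical. The pad here copies its input first, because np.apply_along_axis may pass views into the original array, and an in-place assignment on such a view would also change arr:

```python
import numpy as np

def pad(data):
    # Work on a copy so the caller's array is not modified in place
    data = data.copy()
    bad = np.isnan(data)
    data[bad] = np.interp(bad.nonzero()[0], (~bad).nonzero()[0], data[~bad])
    return data

def fill_along_axis0(arr):
    # Hypothetical helper: interpolate each vector arr[:, j, k] one at a time
    out = np.empty_like(arr)
    for j in range(arr.shape[1]):
        for k in range(arr.shape[2]):
            out[:, j, k] = pad(arr[:, j, k])
    return out

arr = np.empty((4, 4, 5))
arr[:] = np.nan
arr[0] = 25
arr[3] = 32.5

looped = fill_along_axis0(arr)
applied = np.apply_along_axis(pad, 0, arr)
print(looped[:, 0, 0])
```

Both versions give 25, 27.5, 30, 32.5 along axis 0 for every (j, k); apply_along_axis simply hides the nested loop.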

Upvotes: 0
