Patrick Lee

Reputation: 165

How to shift a tensor like pandas' DataFrame.shift in TensorFlow / Keras? (Without moving the last row to the first row, as tf.roll does)

I want to shift a tensor along a given axis. This is easy to do in pandas or NumPy, for example:

import numpy as np
import pandas as pd

data = np.arange(0, 6).reshape(-1, 2)
pd.DataFrame(data).shift(1).fillna(0).values

Output is:

array([[0., 0.],
       [0., 1.],
       [2., 3.]])
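For comparison, the same shift in plain NumPy (mentioned above as easy) could be sketched as follows. This is a minimal sketch, not a library function: the helper name `np_shift` and its `fill` parameter are hypothetical, chosen to mirror `DataFrame.shift(periods).fillna(fill)`. Unlike pandas, it keeps the integer dtype when the fill value allows it.

```python
import numpy as np

def np_shift(arr, periods=1, axis=0, fill=0):
    """Hypothetical helper: shift `arr` by `periods` along `axis`,
    filling the vacated slots with `fill` (like shift(periods).fillna(fill))."""
    arr = np.asarray(arr)
    # Start from an array filled with `fill`; widen the dtype if needed.
    out = np.full(arr.shape, fill, dtype=np.result_type(arr, fill))
    n = arr.shape[axis]
    k = min(abs(periods), n)  # shifting by >= n leaves only fill values
    src = [slice(None)] * arr.ndim
    dst = [slice(None)] * arr.ndim
    if periods > 0:
        src[axis] = slice(0, n - k)   # keep the first n - k entries...
        dst[axis] = slice(k, n)       # ...and move them forward by k
    else:
        src[axis] = slice(k, n)       # keep the last n - k entries...
        dst[axis] = slice(0, n - k)   # ...and move them back by k
    out[tuple(dst)] = arr[tuple(src)]
    return out

data = np.arange(0, 6).reshape(-1, 2)
print(np_shift(data, 1))
```

Negative `periods` shifts backwards, as in pandas.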

But in TensorFlow, the closest thing I found is tf.roll, and it moves the last row to the first row, which I don't want. So I have to use something like

tf.roll + tf.slice (to remove the last row) + tf.concat (to add a row of tf.zeros at the start).

It's really ugly.
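For reference, a minimal sketch of that workaround for the 2-D example, shifting by one row along axis 0. Once the last row is sliced off, the tf.roll step turns out not to be needed:

```python
import tensorflow as tf

# Same values as the pandas example: [[0, 1], [2, 3], [4, 5]] as float32.
data = tf.reshape(tf.range(6, dtype=tf.float32), (-1, 2))

# Drop the last row, then prepend one row of zeros with the same width and dtype.
shifted = tf.concat([tf.zeros_like(data[:1]), data[:-1]], axis=0)
print(shifted.numpy())
```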

Is there a better way to shift a tensor in TensorFlow or Keras?

Thanks.

Upvotes: 4

Views: 888

Answers (5)

Worthy7

Reputation: 1561

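I think you can use pad with negative widths. This shifts backwards by one and puts a 0 at the end. Note that it relies on negative pad widths: the PyTorch backend (which produced the outputs below) treats them as cropping, but as far as I know tf.pad and np.pad reject negative widths, so this may not work with the TensorFlow or NumPy backends.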

keras.ops.pad([1, 2, 3], (-1, 1), 'constant', 0.)

tensor([2, 3, 0], device='cuda:0', dtype=torch.int32)

This pads one 0 at the start (the length grows to 4, so on its own this is not yet a shift):

keras.ops.pad([1, 2, 3], (1, 0), 'constant', 0.)

tensor([0, 1, 2, 3], device='cuda:0', dtype=torch.int32)

You get the idea.

https://www.tensorflow.org/api_docs/python/tf/keras/ops/pad

So the equivalent of

pd.DataFrame([1,2,3]).shift(1).fillna(0)

is therefore:

keras.ops.pad([1, 2, 3], (1, -1), 'constant', 0.)
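Because, as far as I know, tf.pad rejects negative pad widths, a variant that also works in plain TensorFlow is to pad with a non-negative width and then slice back to the original length. A minimal 1-D sketch; `pad_shift` is a hypothetical helper name, not a TensorFlow or Keras API:

```python
import tensorflow as tf

def pad_shift(x, periods=1, fill=0):
    """Hypothetical helper: shift a 1-D tensor by `periods`, filling with `fill`.
    Uses only non-negative pad widths, then slices back to the original length."""
    x = tf.convert_to_tensor(x)
    n = x.shape[0]
    k = min(abs(periods), n)  # number of slots to fill
    if periods > 0:
        # Pad k values at the start, keep the first n entries.
        return tf.pad(x, [[k, 0]], constant_values=fill)[:n]
    # Pad k values at the end, drop the first k entries.
    return tf.pad(x, [[0, k]], constant_values=fill)[k:]

print(pad_shift([1, 2, 3], 1).numpy())   # same values as pd.Series([1, 2, 3]).shift(1).fillna(0)
print(pad_shift([1, 2, 3], -1).numpy())
```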

Upvotes: 0

cy789

Reputation: 55

Generalizing the accepted answer to arbitrary tensor shapes, shift amounts, and axes:

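import tensorflow as tf

def tf_shift(tensor, shift=1, axis=0):
    dim = len(tensor.shape)

    # axis == dim would make tensor.shape[-0:] the whole shape, so reject it too.
    if not 0 <= axis < dim:
        raise ValueError(
            f'Value of axis ({axis}) must be in [0, {dim})'
        )

    # The mask covers axes axis..dim-1; its first `shift` slices are zero.
    mask_dim = dim - axis
    mask_shape = tensor.shape[-mask_dim:]
    zero_dim = min(shift, mask_shape[0])

    # Build the mask in the tensor's dtype so tf.multiply also works for ints.
    mask = tf.concat(
        [tf.zeros(tf.TensorShape(zero_dim) + mask_shape[1:], dtype=tensor.dtype),
         tf.ones(tf.TensorShape(mask_shape[0] - zero_dim) + mask_shape[1:], dtype=tensor.dtype)],
        axis=0
    )

    # Add leading size-1 axes so the mask broadcasts over axes 0..axis-1.
    for _ in range(axis):
        mask = tf.expand_dims(mask, axis=0)

    return tf.multiply(
        tf.roll(tensor, shift, axis),
        mask
    )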

EDIT: The code above doesn't support negative shift values and is fairly slow. Here is a more efficient version that uses tf.slice and tf.concat instead of building a mask and multiplying the tensor of interest by it.

import tensorflow as tf

def tf_shift(values: tf.Tensor, shift: int = 1, axis: int = 0):
    pad = tf.zeros([val if i != axis else abs(shift) for i, val in enumerate(values.shape)],
                   dtype=values.dtype)
    size = [-1 if i != axis else val - abs(shift) for i, val in enumerate(values.shape)]

    if shift > 0:
        shifted = tf.concat(
            [pad, tf.slice(values, [0] * len(values.shape), size)],
            axis=axis
        )
    elif shift < 0:
        shifted = tf.concat(
            [tf.slice(values, [0 if i != axis else abs(shift) for i, _ in enumerate(values.shape)], size), pad],
            axis=axis
        )
    else:
        shifted = values

    return shifted

Upvotes: 2

Viktor Zhou

Reputation: 11

Assuming a 2-D tensor, this function should mimic DataFrame.shift along the rows (axis 0):

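import tensorflow as tf

def shift_tensor(tensor, periods, fill_value):
    # Shift the rows of a 2-D tensor like DataFrame.shift(periods),
    # filling the vacated rows with fill_value.
    num_row, num_col = tensor.shape
    n = min(abs(periods), num_row)  # rows to fill (tf.fill needs a non-negative size)
    pad = tf.fill([n, num_col], tf.cast(fill_value, tensor.dtype))

    if periods > 0:
        # Prepend fill rows, drop the last n rows.
        shifted_tensor = tf.concat((pad, tensor[:num_row - n, :]), axis=0)
    else:
        # Drop the first n rows, append fill rows.
        shifted_tensor = tf.concat((tensor[n:, :], pad), axis=0)

    return shifted_tensor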

Upvotes: 1

Patrick Lee

Reputation: 165

I think I found a better way to do this.

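Use tf.roll, then multiply (tf.math.multiply) by a 0/1 mask to set the rows that wrapped around to zero. In the example below the shift is along axis 1, so the first row of each 3 x 3 slice is zeroed.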

Sample code is as follows:

Original tensor:

A = tf.cast(tf.reshape(tf.range(27), (-1, 3, 3)), dtype=tf.float32)
A

Output:

<tf.Tensor: id=117, shape=(3, 3, 3), dtype=float32, numpy=
array([[[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.]],

       [[ 9., 10., 11.],
        [12., 13., 14.],
        [15., 16., 17.]],

       [[18., 19., 20.],
        [21., 22., 23.],
        [24., 25., 26.]]], dtype=float32)>

Shift (like pd.shift):

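# Mask of shape (3, 3): first row 0, remaining rows 1.
B = tf.concat((tf.zeros((1, 3)), tf.ones((2, 3))), axis=0)
# Add a leading axis so the (1, 3, 3) mask broadcasts over axis 0 of A.
C = tf.expand_dims(B, axis=0)
# Roll along axis 1, then zero out the row that wrapped around.
tf.math.multiply(tf.roll(A, 1, axis=1), C)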

Output:

<tf.Tensor: id=128, shape=(3, 3, 3), dtype=float32, numpy=
array([[[ 0.,  0.,  0.],
        [ 0.,  1.,  2.],
        [ 3.,  4.,  5.]],

       [[ 0.,  0.,  0.],
        [ 9., 10., 11.],
        [12., 13., 14.]],

       [[ 0.,  0.,  0.],
        [18., 19., 20.],
        [21., 22., 23.]]], dtype=float32)>
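One caveat with the multiply-by-mask trick: `0 * nan` is still `nan` (and `0 * inf` is `nan`), so non-finite values that wrap around from the end are not cleared. A sketch that uses a boolean mask with `tf.where` instead, for a non-negative shift along any axis of a tensor with a known static shape; `roll_shift` is a hypothetical helper name:

```python
import tensorflow as tf

def roll_shift(x, periods=1, axis=0):
    """Hypothetical helper: shift `x` forward by `periods` >= 0 along `axis`,
    zero-filling the start, via tf.roll plus tf.where with a boolean mask."""
    x = tf.convert_to_tensor(x)
    n = x.shape[axis]
    rolled = tf.roll(x, shift=periods, axis=axis)
    # keep[i] is True for positions that received real (non-wrapped) data.
    keep = tf.range(n) >= periods
    # Reshape the 1-D mask so it broadcasts along `axis` only.
    mask_shape = [1] * x.shape.rank
    mask_shape[axis] = n
    keep = tf.reshape(keep, mask_shape)
    return tf.where(keep, rolled, tf.zeros_like(rolled))

A = tf.cast(tf.reshape(tf.range(27), (-1, 3, 3)), dtype=tf.float32)
print(roll_shift(A, 1, axis=1).numpy())  # same values as the tf.math.multiply version

with_nan = tf.constant([[1.0], [2.0], [float("nan")]])
print(roll_shift(with_nan, 1).numpy())  # first row is 0.0; a 0/1 multiply would leave nan
```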

Upvotes: 3

Andrey

Reputation: 6367

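Try slicing:

import tensorflow as tf

x = tf.constant([[0, 1, 3], [4, 5, 6], [7, 8, 9]])
shifted_0dim = x[1:]     # drop the first row: shape (2, 3)
shifted_1dim = x[:, 1:]  # drop the first column: shape (3, 2)
shifted2 = x[2:]         # drop the first two rows: shape (1, 3)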
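Plain slicing shrinks the tensor rather than filling with zeros. To keep the original shape (like `shift(-1).fillna(0)`), the slice can be padded back with `tf.pad`. A minimal sketch; the `paddings` values below assume a 2-D tensor and a shift of one:

```python
import tensorflow as tf

x = tf.constant([[0, 1, 3], [4, 5, 6], [7, 8, 9]])

# Shift up by one row: drop row 0, append a row of zeros at the bottom.
up = tf.pad(x[1:], [[0, 1], [0, 0]])
# Shift left by one column: drop column 0, append a column of zeros on the right.
left = tf.pad(x[:, 1:], [[0, 0], [0, 1]])
print(up.numpy())
print(left.numpy())
```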

Upvotes: 2
