Macfli

Reputation: 190

TensorFlow: base Python calculation inside an op

I am trying to perform datetime-related calculations element-wise on timestamps contained in a Tensor, using tf.map_fn. This requires converting each value to a datetime and back to a TensorFlow-compatible type.

For example, let's say we want to get the number of the month from a Tensor of timestamps:

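import tensorflow as tf
from datetime import datetime

dates = [datetime(2016, 1, 1).timestamp(),
         datetime(2016, 2, 1).timestamp()]

def timestamp_to_month(timestamp):
    # Expects a Python float, but tf.map_fn passes a scalar (shape []) tf.Tensor.
    return datetime.fromtimestamp(timestamp).month

def month(x, name=None):
    # tf.op_scope is a pre-1.0 API; later releases replaced it with tf.name_scope.
    with tf.op_scope([x], name, "Month") as scope:
        return tf.map_fn(timestamp_to_month, x, back_prop=False)

month(dates)  # fails while the graph is being built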

This does not work because the timestamp argument of timestamp_to_month is passed as a scalar Tensor (shape []) rather than a float, and would first have to be evaluated.

One workaround would be to call timestamp.eval() to get the actual value, but that requires access to the current session, probably through an extra session parameter, which is inconvenient.

Moreover, the month op fails while the graph is being built, not even when it is executed: the mapped timestamp_to_month function is invoked during graph construction. Calling timestamp.eval() there would therefore trigger execution of the graph when I only want to build it.

How can I include such base Python (or Numpy) steps inside an op while still deferring the execution of the graph?

Upvotes: 3

Views: 364

Answers (1)

hlidka

Reputation: 2365

You can't insert arbitrary Python code into the TF graph like this. While some datetime functions (mostly timestamp differences) can be expressed with TF operations, datetime fields such as the month will have to be computed during preprocessing for the time being.
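A minimal sketch of that preprocessing step: the month is computed in plain Python before anything reaches the graph, so no tensor ever has to be evaluated. The helper name `timestamps_to_months` is hypothetical, and it assumes the timestamps are POSIX seconds to be read as UTC (the question's `datetime.fromtimestamp` uses local time instead).

```python
from datetime import datetime, timezone


def timestamps_to_months(timestamps):
    """Hypothetical helper: month numbers (1-12) computed outside the TF graph.

    Assumption: timestamps are POSIX seconds interpreted in UTC.
    """
    return [datetime.fromtimestamp(ts, tz=timezone.utc).month for ts in timestamps]


dates = [datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp(),
         datetime(2016, 2, 1, tzinfo=timezone.utc).timestamp()]
months = timestamps_to_months(dates)
print(months)  # [1, 2]
# `months` is a plain list of ints, so it can enter the graph as ordinary data,
# e.g. via tf.constant(months) or a feed_dict in TF 1.x.
```

The graph then only sees integer month values, which keeps graph construction free of Python-side evaluation.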

That being said, some datetime computations are easier than others. While months are really tricky, here's an example of extracting the day of week and hour of day in TensorFlow. Minutes and seconds should also be doable.

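import tensorflow as tf

MILLIS_PER_HOUR = 1000 * 60 * 60
MILLIS_PER_DAY = MILLIS_PER_HOUR * 24

HOUR_CYCLE = 24
DAY_CYCLE = 7

# 2019-05-20 00:00:00 UTC (a Monday), in milliseconds since the epoch.
# The extracted hours and weekdays are therefore UTC, with 0 = Monday.
MONDAY_MIDNIGHT = 1558310400000


def extract_timestamps(timestamps_ms):
  """Converts batched ms timestamps into one-hot encoded hours+weekdays."""
  # timestamps_ms should be int64: MONDAY_MIDNIGHT does not fit in int32.
  monday_diff = tf.subtract(timestamps_ms, MONDAY_MIDNIGHT)
  hours = _int_mod(monday_diff, MILLIS_PER_HOUR, HOUR_CYCLE)
  weekdays = _int_mod(monday_diff, MILLIS_PER_DAY, DAY_CYCLE)
  # For a [batch] input the result has shape [batch, HOUR_CYCLE + DAY_CYCLE].
  return tf.concat(
      [_compact_one_hot(hours, HOUR_CYCLE),
       _compact_one_hot(weekdays, DAY_CYCLE)],
      axis=1)


def _compact_one_hot(indices, depth):
  # Flatten to [batch, depth] whether indices has shape [batch] or [batch, 1].
  return tf.reshape(tf.one_hot(indices, depth), [-1, depth])


def _int_mod(x, x_scale, y):
  # floor(x / x_scale) mod y, e.g. the hour of day for x_scale = 1 hour, y = 24.
  # floormod keeps the result in [0, y) even for timestamps before the anchor.
  return tf.cast(tf.math.floormod(tf.divide(x, x_scale), y), dtype=tf.uint8)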
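Months are harder because their lengths vary, but they can still be derived with the same floor-division-and-modulo style of arithmetic. The sketch below is not part of the answer: it adapts Howard Hinnant's `civil_from_days` date algorithm, and the function name `timestamp_to_month_utc` is hypothetical. It assumes integer POSIX seconds interpreted as UTC. Because it uses only `+`, `-`, `*`, `//` and `%`, the same function should also accept an int64 `tf.Tensor` (TensorFlow overloads `//` and `%` as `tf.math.floordiv` and `tf.math.floormod`), though that tensor path is an untested assumption here.

```python
from datetime import datetime, timedelta, timezone

SECONDS_PER_DAY = 86400


def timestamp_to_month_utc(timestamp_s):
    """Hypothetical helper: month (1-12, UTC) of an integer POSIX timestamp.

    Adapted from Howard Hinnant's civil_from_days; uses only +, -, *, //
    and %, with no branches or lookup tables.
    """
    # Days since 0000-03-01; starting the year in March puts Feb 29 last.
    z = timestamp_s // SECONDS_PER_DAY + 719468
    era = z // 146097                      # 400-year era (146097 days each)
    doe = z - era * 146097                 # day of era, [0, 146096]
    # Year of era, [0, 399], correcting for 4/100/400-year leap rules.
    yoe = (doe - doe // 1460 + doe // 36524 - doe // 146096) // 365
    doy = doe - (365 * yoe + yoe // 4 - yoe // 100)  # day of March-based year
    mp = (5 * doy + 2) // 153              # month index with March = 0, [0, 11]
    return (mp + 2) % 12 + 1               # convert to January = 1 ... December = 12


jan = int(datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp())
feb = jan + int(timedelta(days=31).total_seconds())
print(timestamp_to_month_utc(jan), timestamp_to_month_utc(feb))  # 1 2
```

Avoiding `if` branches is the design choice that matters here: every step maps to an element-wise op, so no Python logic runs on tensor values.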

Upvotes: 1
