I am trying to perform datetime-related calculations element-wise on timestamps contained in a Tensor, using tf.map_fn. This requires conversion to a datetime and back to a TensorFlow-compatible type.
For example, let's say we want to get the number of the month from a Tensor of timestamps:
import tensorflow as tf
from datetime import datetime

dates = [datetime(2016, 1, 1).timestamp(),
         datetime(2016, 2, 1).timestamp()]

def timestamp_to_month(timestamp):
    return datetime.fromtimestamp(timestamp).month

def month(x, name=None):
    with tf.op_scope([x], name, "Month") as scope:
        return tf.map_fn(timestamp_to_month, x, back_prop=False)

month(dates)
This does not work because the timestamp parameter in timestamp_to_month is passed as a Tensor with shape [] rather than as a float, so it would have to be evaluated first.
One solution would be to call timestamp.eval() before using the actual value, but then I would have to get hold of the current session, probably through an additional session parameter, which would be inconvenient.
Additionally, this month op actually fails during the graph-building phase, not even during its execution, which means the mapped timestamp_to_month function is invoked while the graph is being built. Including a timestamp.eval() call would therefore trigger execution of the graph when I only want to build it.
How can I include such base Python (or Numpy) steps inside an op while still deferring the execution of the graph?
You can't insert arbitrary Python code into the TF graph like this. While some datetime computations (mostly differences between timestamps) can be expressed with TF operations, calendar fields such as the month will have to be computed during preprocessing for the time being.
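As a rough illustration of moving the month extraction into preprocessing, NumPy's datetime64 type can compute month numbers for a whole array before the data ever reaches TensorFlow. This is a sketch under assumptions: `months_from_timestamps` is a hypothetical helper, inputs are Unix timestamps in seconds, and months are computed in UTC (unlike `datetime.fromtimestamp`, which uses the local timezone). The resulting integer array could then be fed to the graph like any other input feature.

```python
import numpy as np

def months_from_timestamps(timestamps_s):
    """Return month numbers (1-12, UTC) for Unix timestamps given in seconds.

    Hypothetical preprocessing helper; runs in NumPy, outside the TF graph.
    """
    # Drop fractional seconds and interpret the integers as datetime64[s].
    seconds = np.floor(np.asarray(timestamps_s, dtype=np.float64)).astype(np.int64)
    as_datetime = seconds.astype("datetime64[s]")
    # Casting to datetime64[M] counts whole calendar months since 1970-01.
    months_since_epoch = as_datetime.astype("datetime64[M]").astype(np.int64)
    # Month 0 since the epoch is January, so wrap modulo 12 and shift to 1-12.
    return months_since_epoch % 12 + 1

# 2016-01-01 and 2016-02-01 at 00:00 UTC
print(months_from_timestamps([1451606400.0, 1454284800.0]))
```

Doing this in NumPy keeps the calendar logic (variable month lengths, leap years) out of the graph entirely, which is the trade-off the answer recommends.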
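That said, some datetime computations are easier than others. Months are really tricky because their length varies, but hours and weekdays repeat on fixed-length cycles of milliseconds, so they reduce to integer division and modulo relative to a known Monday midnight. Here's an example of extracting the day of the week and the hour of the day in TensorFlow. Minutes and seconds should also be doable.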
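import tensorflow as tf

MILIS_PER_HOUR = 1000 * 60 * 60
MILIS_PER_DAY = MILIS_PER_HOUR * 24
HOUR_CYCLE = 24
DAY_CYCLE = 7
# 2019-05-20 00:00:00 UTC, a Monday, in milliseconds since the epoch.
MONDAY_MIDNIGHT = 1558310400000

def extract_timestamps(timestamps_ms):
    """Converts batched ms timestamps into one-hot encoded hours+weekdays.

    timestamps_ms: int64 tensor of Unix timestamps in milliseconds
    (int32 overflows at this scale).
    """
    monday_diff = tf.subtract(timestamps_ms, MONDAY_MIDNIGHT)
    # Whole hours/days since the reference Monday, wrapped to each cycle.
    hours = _int_mod(monday_diff, MILIS_PER_HOUR, HOUR_CYCLE)
    weekdays = _int_mod(monday_diff, MILIS_PER_DAY, DAY_CYCLE)  # 0 = Monday
    return tf.concat([
        _compact_one_hot(hours, HOUR_CYCLE),
        _compact_one_hot(weekdays, DAY_CYCLE),
    ], axis=1)

def _compact_one_hot(indices, depth):
    # Flatten the one-hot output to [batch, depth].
    return tf.reshape(tf.one_hot(indices, depth), [-1, depth])

def _int_mod(x, x_scale, y):
    # Divide into units of x_scale, wrap modulo y (floormod keeps the result
    # non-negative), then drop the fraction by casting to an integer type.
    # tf.math.floormod is the same op as TF 1's tf.mod and also works in TF 2.
    return tf.cast(tf.math.floormod(tf.divide(x, x_scale), y), dtype=tf.uint8)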
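To sanity-check the arithmetic without building a graph, the same computation can be mirrored in NumPy. This is a sketch, not part of the original answer: `time_fields`, `extract_timestamps_np` and `MS_PER_MINUTE` are hypothetical names, and the minute field is included only to illustrate the remark that minutes work the same way. It uses integer floor division and modulo, whereas the TF version divides in floating point first; for int64 millisecond inputs the two should agree. NumPy's `//` and `%` round toward negative infinity, so timestamps earlier than MONDAY_MIDNIGHT still land in 0..23 and 0..6.

```python
import numpy as np

MS_PER_MINUTE = 1000 * 60
MS_PER_HOUR = MS_PER_MINUTE * 60
MS_PER_DAY = MS_PER_HOUR * 24
MONDAY_MIDNIGHT = 1558310400000  # 2019-05-20 00:00:00 UTC, a Monday

def time_fields(timestamps_ms):
    """Return (minute, hour, weekday) arrays in UTC; weekday 0 = Monday."""
    diff = np.asarray(timestamps_ms, dtype=np.int64) - MONDAY_MIDNIGHT
    # Floor division counts whole units elapsed; modulo wraps to the cycle.
    minutes = diff // MS_PER_MINUTE % 60
    hours = diff // MS_PER_HOUR % 24
    weekdays = diff // MS_PER_DAY % 7
    return minutes, hours, weekdays

def extract_timestamps_np(timestamps_ms):
    """NumPy mirror of extract_timestamps: [N, 24 + 7] one-hot hours+weekdays."""
    _, hours, weekdays = time_fields(timestamps_ms)
    # Row i of an identity matrix is the one-hot vector for index i.
    return np.concatenate([np.eye(24)[hours], np.eye(7)[weekdays]], axis=1)

# Tuesday 2019-05-21 02:05 UTC, and one millisecond before the reference Monday.
print(time_fields([MONDAY_MIDNIGHT + 26 * MS_PER_HOUR + 5 * MS_PER_MINUTE,
                   MONDAY_MIDNIGHT - 1]))
```

Comparing `extract_timestamps_np` with the TF output on the same batch is a quick way to catch dtype problems, such as passing int32 or float32 timestamps.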