Ata

Reputation: 53

How can you get the length of a TensorFlow string?

Is there any way to get the length of a TensorFlow string within TensorFlow? For example, is there a function that returns the length of a = tf.constant("Hello everyone", tf.string), i.e. 14, without passing the string back to Python?

Upvotes: 5

Views: 3135

Answers (4)

starbeamrainbowlabs

Reputation: 6106

I'm not sure which version of TensorFlow it was added in, but in TensorFlow 2.4 and above, at least, there is a function to get the length of a string: tf.strings.length(string_tensor). Here's an example of it at work:

import tensorflow as tf

str_1 = tf.constant("yaaaaay")
str_2 = tf.constant("")

print(str_1)
print(str_2)

print(tf.strings.length(str_1))
print(tf.strings.length(str_2))

Example output:

tf.Tensor(b'yaaaaay', shape=(), dtype=string)
tf.Tensor(b'', shape=(), dtype=string)
tf.Tensor(7, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
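tf.strings.length also accepts a `unit` argument: the default, "BYTE", counts UTF-8 bytes, while "UTF8_CHAR" counts Unicode code points, which makes a difference for non-ASCII text. It also works element-wise on a batch of strings. A minimal sketch, assuming TensorFlow 2.x with eager execution; the sample strings are only illustrative:

```python
import tensorflow as tf

# A batch of strings; "h\u00e9llo" is "héllo", whose "é" takes 2 bytes in UTF-8
words = tf.constant(["Hello everyone", "h\u00e9llo", ""])

# Default unit="BYTE": number of UTF-8 bytes per string (int32)
byte_lengths = tf.strings.length(words)

# unit="UTF8_CHAR": number of Unicode code points per string
char_lengths = tf.strings.length(words, unit="UTF8_CHAR")

print(byte_lengths.numpy())  # byte counts: 14, 6, 0
print(char_lengths.numpy())  # character counts: 14, 5, 0
```

For plain ASCII text both units give the same answer, so the choice only matters once accented or other multi-byte characters appear.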

Upvotes: 0

Forth Temple

Reputation: 114

This works for me:

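import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in 2.x)

x = tf.constant("Hello everyone")

# Launch the default graph.
with tf.Session() as sess:
    # An empty delimiter splits into single bytes; for one string, tf.size gives its length
    print(tf.size(tf.string_split([x], "")).eval())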
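tf.Session and tf.string_split are TensorFlow 1.x APIs. Under TensorFlow 2.x, a roughly equivalent sketch splits the string into characters with tf.strings.unicode_split (swapped in here for tf.string_split; it is not part of the answer above) and counts the pieces with tf.size. It assumes eager execution and UTF-8 input:

```python
import tensorflow as tf

x = tf.constant("Hello everyone")

# For a scalar string, unicode_split returns a 1-D tensor with one element per character
chars = tf.strings.unicode_split(x, "UTF-8")

# Count the characters (int32 scalar)
n = tf.size(chars)
print(n.numpy())
```

Splitting materialises every character as its own string, so when only the count is needed, tf.strings.length is likely the cheaper option.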

Upvotes: 5

mdaoust

Reputation: 6367

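Another, sub-optimal, option is to split your strings into a sparse tensor of single characters:

import tensorflow as tf  # TensorFlow 1.x API

strings = ['Why hello', 'world', '!']
# Empty delimiter: one sparse entry per byte of each string
chars = tf.string_split(strings, "")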

Then take the maximum column index on each line and add 1:

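# Row (string) index and column (character) index of each sparse entry
line_number = chars.indices[:, 0]
line_position = chars.indices[:, 1]
# Largest column index per row, plus one, is that row's length
lengths = tf.segment_max(data=line_position,
                         segment_ids=line_number) + 1

with tf.Session() as sess:
    print(lengths.eval())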

[9 5 1]
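In TensorFlow 2.x the same per-line lengths can be read directly off a ragged split: for a batch of strings, tf.strings.unicode_split returns a tf.RaggedTensor, and its row_lengths() method gives the number of characters in each row. This also covers empty strings, which contribute no entries to the sparse indices and so cannot get a length of 0 from the segment_max approach. A minimal sketch, assuming eager execution and UTF-8 input; the trailing empty string is added only for illustration:

```python
import tensorflow as tf

strings = tf.constant(["Why hello", "world", "!", ""])

# Ragged tensor: one row of single characters per input string
chars = tf.strings.unicode_split(strings, "UTF-8")

# Characters per row (int64); the empty string yields a row of length 0
lengths = chars.row_lengths()
print(lengths.numpy())
```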

Upvotes: 1

keveman

Reputation: 8487

No such function exists as of TensorFlow version 0.9. However, you can use tf.py_func to run arbitrary Python functions over TensorFlow tensors. Here is one way to get the length of a TensorFlow string:

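import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

def string_length(t):
  # Return a single int64 NumPy array; a plain Python list would be treated
  # by py_func as one separate output per element.
  return tf.py_func(lambda p: np.array([len(x) for x in p], dtype=np.int64),
                    [t], [tf.int64])[0]

a = tf.constant(["Hello everyone"], tf.string)
sess = tf.InteractiveSession()
print(sess.run(string_length(a)))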
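tf.py_func and tf.InteractiveSession are TensorFlow 1.x APIs (available as tf.compat.v1 in 2.x). In TensorFlow 2.x the closest equivalent is tf.py_function, which hands eager tensors to the Python function. A minimal sketch; the helper name string_length_py is hypothetical. As with any Python-side op, the function body is not serialized with the graph, so a model that relies on it cannot be restored in a different environment; this is mostly useful on versions that lack tf.strings.length:

```python
import numpy as np
import tensorflow as tf

def string_length_py(t):
    """Hypothetical helper: byte length of each string in a 1-D string tensor."""
    def _lengths(p):
        # p is an eager tf.Tensor; .numpy() yields an array of bytes objects
        # Return one int64 array so the output dtype matches Tout
        return np.array([len(s) for s in p.numpy()], dtype=np.int64)
    return tf.py_function(_lengths, [t], tf.int64)

a = tf.constant(["Hello everyone", "!"])
lengths = string_length_py(a)
print(lengths.numpy())
```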

Upvotes: 1
