Alex H
Alex H

Reputation: 76

Dynamic Time Warping implementation in Tensorflow

I've rewritten a Dynamic Time Warping implementation from plain Python into TensorFlow, but it is really slow -- much slower than pre-computing the distances and loading them into TensorFlow as data. I can't figure out why it is slow or how to improve it.

I have also tried converting other DTW implementations with AutoGraph, without success. Any suggestions?

def tfDTW(s1, s2):
  """Build graph ops that compute a dynamic time warping distance of s1 and s2.

  The accumulated cost uses squared differences between elements of the two
  series; the returned value is the square root of the final accumulated
  cost. Only two rows of the cost matrix are kept (a [2, length] tensor) and
  they are alternated via the row selectors `i0` / `i1`.

  Args:
    s1: tensor indexed along axis 0 (assumed to be a 1-D numeric series --
      TODO confirm; the demo passes 1-D integer constants).
    s2: tensor indexed along axis 0, same assumption as `s1`.

  Returns:
    A float64 tensor holding the square root of the selected cell of the
    rolling cost buffer. It must still be evaluated (e.g. with `.eval()`).
  """
  # Lengths of both series along axis 0.
  r = tf.cast(tf.shape(s1)[0], tf.int32)
  c = tf.cast(tf.shape(s2)[0], tf.int32)
  # The window is the longer of the two lengths, so the band computed from it
  # below appears to span every column (j_start == 0, j_end == c).
  window = tf.math.reduce_max([r,c])
  # NOTE(review): max_step, max_dist and psi are assigned but never read by
  # any executable line below; only `penalty` (zero) is used.
  max_step = max_dist = 1e7
  penalty = psi = tf.constant(0, dtype=tf.float64)
  # Width of each of the two rolling cost rows.
  length =  tf.math.reduce_min([c + 1, tf.math.abs(r - c) + 2 * (window - 1) + 1 + 1 + 1])
  # Initial buffer: row 0 is 0.0 at column 0 and 1e7 elsewhere; the index -1
  # is out of range for one_hot, so row 1 is entirely the off_value 1e7.
  # 1e7 stands in for "infinity" (see the np.inf remark further down).
  indices = [0,-1]
  dtw = tf.one_hot(indices, depth = length,
             on_value=0.0, off_value=1e7,
             axis=-1)  # output: [2,length]
  dtw=tf.cast(dtw, tf.float64)
  # Initial values of the loop-carried state.
  last_under_max_dist = tf.constant(0)
  skip = tf.constant(0)
  i0 = tf.constant(1)  # selector of the previous row (toggled every outer step)
  i1 = tf.constant(0)  # selector of the row being written (toggled every outer step)
  # NOTE(review): psi_shortest is never read below.
  psi_shortest = 1e7
  # Outer loop: one iteration per element i of s1.
  def condition1(i, r, dtw, i0, i1, skip, last_under_max_dist):
    return tf.less(i, r)
  def body1(i, r, dtw, i0, i1, skip, last_under_max_dist):
    # Carried into the inner loop but only passed through there; it does not
    # influence any computed value in the visible code.
    prev_last_under_max_dist = tf.cond(tf.equal(last_under_max_dist, -1), lambda: tf.cast(tf.constant(1e7), tf.int32), lambda: last_under_max_dist)
    last_under_max_dist = tf.constant(-1)
    # skipp keeps the column offset of the previous row; the new `skip`
    # computed here is overwritten with a constant 0 a few lines below.
    skipp = skip
    skip = tf.reduce_max([0, i - tf.reduce_max([0, r - c]) - window + 1])
    # Swap the roles of the two buffer rows.
    i0 = 1 - i0
    i1 = 1 - i1
    # Reset row i1 to 1e7 by rebuilding the whole [2, length] tensor: the
    # reset row is replaced by a filled row, the other row is kept.
    dtw = tf.cond(tf.equal(i1, 0), lambda: tf.concat([tf.fill([1, length], tf.constant(1e7, dtype=tf.float64)), [dtw[1]]], 0), lambda: tf.concat([[dtw[0]], tf.fill([1, length], tf.constant(1e7, dtype=tf.float64))], 0) ) #dtw[i1, :] = np.inf
    # Half-open column range [j_start, j_end) visited for this row.
    j_start = tf.reduce_max([0, i - tf.reduce_max([0, r - c]) - window + 1])
    j_end = tf.reduce_min([c, i + tf.reduce_max([0, c - r]) + window])
    # The column offset is forced to 0 unconditionally; the conditional
    # version is left commented out at the end of the line.
    skip = tf.constant(0) #tf.cond(tf.equal(dtw.get_shape()[1], c+1), lambda: 0, lambda: skip )
    #if psi != 0 and j_start == 0 and i < psi:            dtw[i1, 0] = 0 #psi always ==0    
    # Inner loop: one iteration per column j in [j_start, j_end).
    def condition2(j, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp):
      return tf.math.logical_and(tf.greater(j, j_start-1), tf.less(j,j_end))    
    def body2(j, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp):
      # Local cost: squared difference of the two current elements.
      d = (tf.gather(s1, i) - tf.gather(s2, j))*(tf.gather(s1, i) - tf.gather(s2, j))
      d = tf.cast(d, tf.float64)
      # Cheapest of the three predecessor cells: two from the previous row
      # (i0) and one from the current row (i1), the latter two plus penalty.
      minval = tf.cast(tf.math.reduce_min([dtw[i0, j - skipp],
                                           dtw[i0, j + 1 - skipp] + penalty,
                                           dtw[i1, j - skip] + penalty]), tf.float64)
      # Emulate the single-cell assignment dtw[i1, j + 1 - skip] = value:
      # the target column goes into the slot of row i1 and -1 (all
      # off_value) into the other row's slot; then the old cell value is
      # subtracted and the new one added.
      # NOTE(review): this creates two full [2, length] tensors per visited
      # cell, which is presumably a major contributor to the reported
      # slowness -- confirm by profiling.
      indices = tf.cond(tf.equal(i1, 0), lambda: tf.stack([j + 1 - skip, -1] ), lambda: tf.stack([-1, j + 1 - skip]) )
      minusdtw = tf.one_hot(indices, depth = length,
                            on_value=-1*dtw[i1, j + 1 - skip], off_value=tf.constant(0.0, dtype=tf.float64),
                            axis=-1)     # output: [2,length]
      # The new cell value is capped at 1e7.
      replacement = tf.one_hot(indices, depth = length,
                               on_value=tf.reduce_min([d + minval, 1e7]), off_value=tf.constant(0.0, dtype=tf.float64),
                               axis=-1)  # output: [2,length]
      dtw = dtw + minusdtw + replacement
      # Remember the last column written in this row.
      last_under_max_dist = j
      return tf.add(j, 1), dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp    
    # The second list holds the shape invariants; dtw is declared as
    # (2, None) because its width is only known at run time.
    b = tf.while_loop(condition2, body2, [j_start, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp ],
                      [j_start.get_shape(), tf.TensorShape((2,None)), j_start.get_shape(), j_end.get_shape(), last_under_max_dist.get_shape(), prev_last_under_max_dist.get_shape(), skip.get_shape(), skipp.get_shape() ])
    # b[1] is the updated buffer, b[4] the last column written by the inner loop.
    return tf.add(i, 1), r, b[1], i0, i1, skip, b[4]
  a = tf.while_loop(condition1, body1, [tf.constant(0), r, dtw, i0, i1, skip, tf.constant(0) ],
                    [tf.constant(0).get_shape(), r.get_shape(), tf.TensorShape((None,None)), i0.get_shape(), i1.get_shape(), skip.get_shape(), tf.constant(0).get_shape() ])
  # a[2] is the final buffer and a[4] the selector of the last row written.
  maindtw = a[2]
  # NOTE(review): `skip` here is the constant 0 defined before the loop, not
  # the loop output a[5]; body1 also always returns 0 for it, so the values
  # coincide in the visible code.
  d = tf.math.sqrt(maindtw [a[4]][ tf.reduce_min([c, c + window - 1]) - skip])
  return d

import tensorflow as tf
import numpy as np

# NOTE(review): this Graph object is not referenced again in this snippet.
graph = tf.Graph()
# An InteractiveSession installs itself as the default session, which is what
# allows the bare `.eval()` call below; it is never closed in this snippet.
sess = tf.InteractiveSession()

# Two example integer series of different lengths (11 and 7 elements).
s1 = tf.constant([10, 0, 1, 2, 1, 0, 1, 0, 0,14,22])
s2 = tf.constant([10, 1, 2, 0, 0, 0, 0])
# Evaluate the DTW distance; the trailing number is the author's reported result.
tfDTW(s1, s2).eval() #26.13426869074396

Upvotes: 5

Views: 2483

Answers (1)


Reputation: 254

If you are doing one DTW it is hard to speed it up.

However, if you are doing many DTW invocations, you can make it amortized O(1).


See also

Upvotes: 2

Related Questions