cyberface

Reputation: 101

How do I create a new class that inherits from tf.Tensor?

I would like to create a new class that is basically an array but with some extra attributes.

Specifically, I would like to write a class, based on TensorFlow objects, that describes time-series data. As such, it will have an associated time spacing (delta_t) and a time vector, which I compute with the sample_times property.

In Python/NumPy I do the following:

import numpy as np
class TimeSeries(object):
    def __init__(self, initial_array, delta_t):
        self.initial_array = initial_array
        self.delta_t = delta_t

    @property
    def sample_times(self):
        return np.arange(self.initial_array.shape[0]) * self.delta_t
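To make the intended behaviour concrete, here is the NumPy class above applied to a small array. The sample values and delta_t = 0.5 are arbitrary choices for illustration only.

```python
import numpy as np


class TimeSeries(object):
    def __init__(self, initial_array, delta_t):
        self.initial_array = initial_array
        self.delta_t = delta_t

    @property
    def sample_times(self):
        # Time of each sample: sample index multiplied by the spacing
        return np.arange(self.initial_array.shape[0]) * self.delta_t


ts = TimeSeries(np.array([3.0, 1.0, 4.0, 1.0]), delta_t=0.5)
print(ts.sample_times)  # [0.  0.5 1.  1.5]
```

Because sample_times is a property, it is recomputed from the current array length each time it is accessed.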

Is it possible to do something similar by inheriting from tf.Tensor? The reason is that I believe it would make life simpler: when performing an analysis on these TimeSeries objects, I could take advantage of various TensorFlow features like the tf.function decorator.

My basic attempt at a minimal working example is as follows; it simply tries to be a subclass of tf.Tensor.

import tensorflow as tf

class TFTimeSeries(tf.Tensor):
    def __init__(self):
        super().__init__()

tf_ts = TFTimeSeries()

I get the following error when instantiating TFTimeSeries:

TypeError: __init__() missing 3 required positional arguments: 'op', 'value_index', and 'dtype'

dtype is easy enough, but I am not sure what to do about the other two, op and value_index.

I should say that my knowledge of TensorFlow is not very advanced, and I would appreciate any help with this. Thanks!

EDIT:

Hi @Filippo Grazioli, and thanks for your answer! I think this is the best way forward. After thinking a bit more about the design of my code, I don't think the approach I had in mind is very TensorFlow-like, so simply making a class whose attributes are Tensors makes more sense.

I will mark this as answered now.

Upvotes: 4

Views: 657

Answers (1)

Filippo Grazioli

Reputation: 385

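The error comes from the super().__init__() call: tf.Tensor.__init__ requires op, value_index and dtype, and they are never passed, neither to TFTimeSeries when you instantiate your tf_ts object nor from TFTimeSeries.__init__ to its parent class.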

The same error would be thrown if you tried to instantiate tf.Tensor() in the same way.
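In practice, tensors are usually not built by calling the tf.Tensor constructor directly; they are produced by operations such as tf.constant or tf.convert_to_tensor. A minimal sketch, assuming TensorFlow 2.x with eager execution (the default):

```python
import numpy as np
import tensorflow as tf

# Build tensors through ops instead of calling tf.Tensor(op, value_index, dtype)
a = tf.constant([1.0, 2.0, 3.0])
b = tf.convert_to_tensor(np.array([0.0, 1.0, 2.0], dtype=np.float32))

# Both are tf.Tensor instances (concretely EagerTensor objects in eager mode)
print(isinstance(a, tf.Tensor), isinstance(b, tf.Tensor))  # True True
print((a + b).numpy())  # [1. 3. 5.]
```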

Regarding how to implement your TFTimeSeries class, tf.Variable and tf.constant might be worth looking at.

Here is their documentation: tf.Variable, tf.constant

Here is a question in which their differences are explained: TensorFlow Variables and Constants
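The practical difference, sketched briefly under the assumption of TensorFlow 2.x: a tf.Variable holds mutable state that can be updated in place, while tf.constant returns an immutable tensor, so "changing" it means creating a new one.

```python
import tensorflow as tf

# tf.Variable: mutable state, updated in place with assign / assign_add
v = tf.Variable([1.0, 2.0, 3.0])
v.assign([4.0, 5.0, 6.0])
v.assign_add([1.0, 1.0, 1.0])
print(v.numpy())  # [5. 6. 7.]

# tf.constant: an immutable tensor with no assign method;
# arithmetic returns a new tensor and leaves c unchanged
c = tf.constant([1.0, 2.0, 3.0])
c2 = c + 1.0
print(hasattr(c, "assign"), c2.numpy())  # False [2. 3. 4.]
```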

I am not sure if I correctly understood what you need to do, but this might be a starting point:

import tensorflow as tf
import numpy as np

class TimeSeries(object):
    def __init__(self, initial_array, dt):
        # Keep the data as a TensorFlow variable, plus the time spacing
        self.tensor = tf.Variable(initial_array, dtype=tf.float32)
        self.dt = dt
        self.initial_array = initial_array

    def sample_times(self):
        # Return the time vector instead of overwriting self.tensor with it
        return tf.constant(np.arange(self.initial_array.shape[0]) * self.dt, dtype=tf.float32)
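A minimal usage sketch of this composition approach, assuming TensorFlow 2.x: the class keeps the data in a tf.Variable and computes the time vector on demand, while analysis code is written as ordinary tf.function-decorated functions that take the underlying tensor. The function name mean_power and the sample values are hypothetical, chosen only for illustration.

```python
import numpy as np
import tensorflow as tf


class TimeSeries:
    """Plain Python container whose attributes are TensorFlow objects."""

    def __init__(self, initial_array, dt):
        self.tensor = tf.Variable(initial_array, dtype=tf.float32)
        self.dt = dt

    def sample_times(self):
        # Time of each sample: index * dt, derived from the tensor's length
        n = tf.shape(self.tensor)[0]
        return tf.cast(tf.range(n), tf.float32) * self.dt


@tf.function
def mean_power(values):
    # Hypothetical analysis step; traced into a graph by tf.function
    return tf.reduce_mean(tf.square(values))


ts = TimeSeries(np.array([1.0, -2.0, 3.0, 0.0], dtype=np.float32), dt=0.25)
print(ts.sample_times().numpy())  # [0.   0.25 0.5  0.75]
print(mean_power(ts.tensor).numpy())  # 3.5
```

Passing the tensor attribute, rather than the TimeSeries object itself, into the tf.function keeps the compiled function's inputs as plain tensors.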

Upvotes: 3
