Taha Mahjoubi
Taha Mahjoubi

Reputation: 402

Training different branches of model network with tf.switch_case

I want to create a neural network in which different branches of the network are trained depending on `t_input`. `t_input` can be either 0 or 1, and depending on its value only the corresponding branch should be trained:

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense

x = np.random.uniform(size=(10, 10))  # NOTE(review): no `import numpy as np` is visible in this snippet
t = np.random.binomial(100, 0.5)  # Binomial(100, 0.5) draw, not limited to {0, 1}; `t` is never used below

t_input = Input(batch_shape=(1,), dtype='int32', name="t_input")  # integer input meant to select the branch
x_input = Input(shape=(x.shape[0]), name='x_input')  # NOTE(review): `(x.shape[0])` is a plain int, not a 1-tuple

x = Dense(32)(x_input)  # shared trunk; rebinds `x` from the NumPy array to a Keras tensor

x1 = Dense(16)(x)  # branch 1: 16 -> 8 -> 1 units
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

x2 = Dense(16)(x)  # branch 2: same layout as branch 1, separate layers
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

x1 = lambda: x1  # rebinds `x1` to the lambda itself, so calling it returns a function, not the branch output
x2 = lambda: x2  # same self-reference for `x2`; consistent with the reported "found return value of type function"

r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)  # index 0 -> x1, index 1 -> x2

# r = tf.case([(tf.equal(t_input, 1), x1), (tf.equal(t_input, 0), x2)], default=x2, exclusive=True)

model = tf.keras.models.Model(inputs=t_input, outputs=r)  # only `t_input` is declared as a model input; `x_input` is not

print(model.predict([1]))  # runs the model with the selector value 1

However, I cannot make this work: `tf.switch_case` does not seem flexible enough to be used with KerasTensors:

Traceback (most recent call last):
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-59-92db0d55c181>", line 23, in <module>
    r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 952, in __call__
    input_list)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1091, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 822, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 863, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\layers\core.py", line 917, in call
    result = self.function(inputs, **kwargs)
  File "<ipython-input-59-92db0d55c181>", line 23, in <lambda>
    r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3616, in switch_case
    return _indexed_case_helper(branch_fns, default, branch_index, name)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3326, in _indexed_case_helper
    lower_using_switch_merge=lower_using_switch_merge)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\cond_v2.py", line 1040, in indexed_case
    op_return_value=branch_index))
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 995, in func_graph_from_py_func
    expand_composites=True)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\util\nest.py", line 659, in map_structure
    structure[0], [func(*x) for x in entries],
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\util\nest.py", line 659, in <listcomp>
    structure[0], [func(*x) for x in entries],
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 952, in convert
    (str(python_func), type(x)))
TypeError: To be compatible with tf.eager.defun, Python functions must return zero or more Tensors; in compilation of <function <lambda> at 0x000001ED0876EAF8>, found return value of type <class 'function'>, which is not a Tensor.

Upvotes: 1

Views: 583

Answers (2)

Taha Mahjoubi
Taha Mahjoubi

Reputation: 402

I found a way to support more than two branches:

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
import tensorflow.keras as K
import numpy as np

def h_chooser(inpus):
  """Select, for every sample, the output of the branch named by its index.

  Args:
    inpus: pair ``(t_inp, h_s)``. ``t_inp`` is an integer tensor holding one
      branch index per sample; ``h_s`` is a list of per-branch outputs that
      are concatenated along axis 1 (in this snippet each one is the output
      of a ``Dense(1)`` layer).

  Returns:
    A rank-1 tensor whose i-th entry is the i-th row of the concatenated
    branch outputs, taken at column ``t_inp[i]``.
  """
  t_inp, h_s = inpus
  # One column per branch. No blanket tf.squeeze here: it would also drop the
  # batch axis when the batch holds a single sample (or the branch axis when
  # there is a single branch) and break the indexing below.
  h_out = tf.concat(h_s, axis=1)
  # Flatten the selector explicitly so a batch of one stays rank 1, and cast
  # so it can be stacked with the int32 row numbers from tf.range.
  t_flat = tf.cast(tf.reshape(t_inp, [-1]), tf.int32)
  # Pair every row number with its requested column: [[0, t0], [1, t1], ...].
  t_inds = tf.stack([tf.range(tf.size(t_flat)), t_flat], axis=1)
  return tf.gather_nd(h_out, t_inds)

# Toy data: 10 samples with 10 features each.
x_test = np.random.uniform(size=(10, 10))
# One branch index in {0, 1, 2} (randint's upper bound is exclusive). A
# Binomial(100, 0.5) draw would almost always fall outside the three branches
# and make the gather in h_chooser index out of range.
t_test = np.array([np.random.randint(0, 3)], dtype=np.int32)

t_input = Input(shape=(1,), dtype=tf.int32, name="t_input")
x_input = Input(shape=(x_test.shape[1],), name='x_input')

# Shared trunk feeding every branch.
x = Dense(32)(x_input)

# Three structurally identical branches with their own weights.
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

x3 = Dense(16)(x)
x3 = Dense(8)(x3)
x3 = Dense(1)(x3)

# h_chooser already takes the single `inputs` argument Lambda passes along.
h_switch_case = Lambda(h_chooser)

# Inputs are [selector, [branch outputs]]; both list brackets must be closed.
r = h_switch_case([t_input, [x1, x2, x3]])

model = tf.keras.models.Model(inputs=[t_input, x_input], outputs=r)

# Repeat the selector once per sample so both inputs share the batch size.
print(model.predict([np.tile(t_test, x_test.shape[0]), x_test]))

Upvotes: 1

jhso
jhso

Reputation: 3283

I got your code working by changing your tf.switch_case to a Keras backend switch (`K.backend.switch`), and by passing both inputs to the model (you only passed one of them in your code). Note that I had to tile your t_test input because the model expects the two inputs to have the same batch dimension. I am also not sure that you want np.random.binomial, because this samples from the binomial distribution and will almost never return 0. You should probably look at np.random.randint and limit it to values of 0 or 1.

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
import tensorflow.keras as K
import numpy as np

# Toy data: 10 samples with 10 features each.
x_test = np.random.uniform(size=(10, 10))
# Selector restricted to 0 or 1 (randint's upper bound is exclusive). A
# Binomial(100, 0.5) draw is almost never 0, so the else branch of the switch
# would never be exercised.
t_test = np.array([np.random.randint(0, 2)])

t_input = Input(shape=(1,), dtype=tf.int32, name="t_input")
x_input = Input(shape=(x_test.shape[1],), name='x_input')

# Shared trunk feeding both branches.
x = Dense(32)(x_input)

# Two structurally identical branches with their own weights.
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

# switch(condition, then, else): t_input is the condition, x1 the "then"
# expression and x2 the "else" expression.
r = K.backend.switch(t_input, x1, x2)

model = tf.keras.models.Model(inputs=[t_input, x_input], outputs=r)

# Tile the selector once per sample so both inputs share the batch size.
print(model.predict([np.tile(t_test, x_test.shape[0]), x_test]))

Upvotes: 2

Related Questions