Pm740

Keras custom activation function with additional parameter / argument

How can I define an activation function in Keras that takes additional arguments? My initial custom activation function generates points from a polynomial of degree N; its input is the polynomial's coefficients. It looks like this:

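import numpy as np
import tensorflow as tf

def poly_transfer(x):
    # Sample points from 0 to 1 in steps of 0.05
    a = np.arange(0, 1.05, 0.05)
    # Row i holds a**i, so b has shape (n_coeffs, n_points)
    b = []
    for i in range(x.shape[1]):
        b.append(a**i)
    b = np.asarray(b)
    b = b.astype(np.float32)
    # (batch, n_coeffs) @ (n_coeffs, n_points): evaluates each polynomial at every point
    c = tf.matmul(x, b)
    return c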
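To see what this activation computes, here is a NumPy-only sketch (no Keras involved) that checks that multiplying the coefficient matrix by the matrix of powers evaluates each row's polynomial at every sample point. The helper name poly_transfer_np and the random test data are illustrative assumptions; numpy.polynomial.polynomial.polyval is used only as an independent reference.

```python
import numpy as np
from numpy.polynomial import polynomial as P

def poly_transfer_np(x):
    # Same construction as poly_transfer, with NumPy in place of TensorFlow
    a = np.arange(0, 1.05, 0.05)
    # Row i of b is a**i, so b has shape (n_coeffs, n_points)
    b = np.asarray([a**i for i in range(x.shape[1])], dtype=np.float32)
    return x @ b  # (batch, n_coeffs) @ (n_coeffs, n_points)

rng = np.random.default_rng(0)
# 4 samples, each holding the 3 coefficients of a degree-2 polynomial
coeffs = rng.uniform(-1, 1, size=(4, 3)).astype(np.float32)
out = poly_transfer_np(coeffs)

# polyval takes coefficients in increasing-degree order, matching a**0, a**1, a**2
t = np.arange(0, 1.05, 0.05)
reference = np.stack([P.polyval(t, c) for c in coeffs])
print(out.shape, np.allclose(out, reference, atol=1e-5))
```

So column j of the output is the value of each sample's polynomial at the j-th sample point, which is why the output width depends on the sampling range.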

Now I want to set the length of the output from outside the function, something like this:

import numpy as np
import tensorflow as tf

def poly_transfer(x, lenght):
    # Sample points from 0 to lenght in steps of 0.05
    a = np.arange(0, lenght + 0.05, 0.05)
    b = []
    for i in range(x.shape[1]):
        b.append(a**i)
    b = np.asarray(b)
    b = b.astype(np.float32)
    c = tf.matmul(x, b)
    return c
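One caveat with np.arange(0, lenght + 0.05, 0.05): with a non-integer step, the NumPy documentation warns that the number of returned points can be inconsistent because of floating-point rounding, so the output width may be off by one for some values of lenght. A minimal sketch that sidesteps this by computing the point count explicitly and using np.linspace instead; the helper name poly_basis and its step parameter are hypothetical, not part of the original code.

```python
import numpy as np

def poly_basis(n_coeffs, lenght, step=0.05):
    # Number of samples from 0 to lenght inclusive; round() absorbs float error in lenght / step
    n_points = int(round(lenght / step)) + 1
    a = np.linspace(0.0, lenght, n_points)
    # Row i holds a**i, giving shape (n_coeffs, n_points)
    return np.stack([a**i for i in range(n_coeffs)]).astype(np.float32)

basis = poly_basis(3, lenght=2)
print(basis.shape)  # (3, 41)
```

Inside poly_transfer, b = poly_basis(x.shape[1], lenght) would then replace the loop before the tf.matmul call.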

How can I implement this functionality, and how do I use it? At the moment I call it like this:

speed_out = Lambda(poly_transfer)(speed_concat_layer)

What I imagined is something like this:

speed_out = Lambda(poly_transfer(lenght=lenght))(speed_concat_layer)


Answers (2)

Marco Cerliani

You can simply wrap the call in a lambda that fixes the extra argument, like this:

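import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
from tensorflow.keras.models import Model

# Toy data: 100 samples with 10 polynomial coefficients each
X = np.random.uniform(0, 1, (100, 10))
y = np.random.uniform(0, 1, (100,))

def poly_transfer(x, lenght):
    # Sample points from 0 to lenght in steps of 0.05
    a = np.arange(0, lenght + 0.05, 0.05)

    # Row i holds a**i, so b has shape (n_coeffs, n_points)
    b = []
    for i in range(x.shape[1]):
        b.append(a**i)

    b = tf.constant(np.asarray(b), dtype=tf.float32)
    # Evaluate each sample's polynomial at every point
    c = tf.matmul(x, b)

    return c

inp = Input((10,))
# The lambda fixes lenght=1, so Keras only has to pass the input tensor
poly = Lambda(lambda x: poly_transfer(x, lenght=1))(inp)
out = Dense(1)(poly)

model = Model(inp, out)
model.compile('adam', 'mse')
model.fit(X, y, epochs=3)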


Lescurel

You can use functools.partial to fix the extra keyword argument ahead of time, so Keras only has to pass the input tensor:

from functools import partial

poly_transfer_set_length = partial(poly_transfer, lenght=lenght)
speed_out = Lambda(poly_transfer_set_length)(speed_concat_layer)

or use a lambda function:

speed_out = Lambda(lambda x: poly_transfer(x, lenght=lenght))(speed_concat_layer)
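Either form works for a single layer. If you build several such layers in a loop with different lengths, though, note that a bare lambda looks up the loop variable when it is called, not when it is created; any call made after the loop has finished (Keras may call the layer function again, for example when it traces the model) would see the last value. partial, or a default argument, binds the value immediately. A pure-Python sketch of the difference, using a hypothetical stand-in fake_transfer in place of poly_transfer so it runs without Keras:

```python
from functools import partial

def fake_transfer(x, lenght):
    # Hypothetical stand-in for poly_transfer: just reports which lenght it received
    return (x, lenght)

lengths = [1, 2, 3]

# Bare lambdas: `l` is looked up at call time, after the comprehension has finished
late_bound = [lambda x: fake_transfer(x, lenght=l) for l in lengths]

# partial binds the current value of `l` immediately
bound_partial = [partial(fake_transfer, lenght=l) for l in lengths]

# A default argument is also evaluated when the lambda is created
bound_default = [lambda x, l=l: fake_transfer(x, lenght=l) for l in lengths]

print([f("x")[1] for f in late_bound])     # [3, 3, 3]
print([f("x")[1] for f in bound_partial])  # [1, 2, 3]
print([f("x")[1] for f in bound_default])  # [1, 2, 3]
```

Keras's Lambda layer also has an arguments parameter for this purpose, for example Lambda(poly_transfer, arguments={'lenght': lenght}), which passes the dictionary to the function as keyword arguments on every call.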
