user7867665

Weighted sum of an input inside the network

I have a network with multiple inputs. I split out the first 10 inputs, calculate their weighted sum, and then concatenate the result with the original input:

first = Lambda(lambda z: z[:, 0:10])(d_inputs)  # the first 10 columns
wsum_first = Lambda(calcWSumF)(first)
d_input = concatenate([d_inputs, wsum_first], axis=-1)

with the function defined as:

import numpy as np
from keras import backend as K

w_vec = K.constant(np.array([range(10)]*64).reshape(10, 64))  # batch size is 64
def calcWSumF(x):
    y = K.dot(w_vec, x)
    y = K.expand_dims(y, -1)
    return y

I want to use a constant vector of weights to calculate the weighted sum of the first part of the input. However, the concatenation fails because the shapes don't match. How can I implement this correctly?
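To see where the shapes go wrong, here is a minimal NumPy sketch that repeats the question's operations on plain arrays. It assumes NumPy as a stand-in for the Keras backend, a batch size of 64, 15 input features, and the intended 10-column slice; x_full and x_first are illustrative names only.

```python
import numpy as np

batch_size, n_features, n_first = 64, 15, 10

# Illustrative input batch standing in for d_inputs.
x_full = np.zeros((batch_size, n_features))
x_first = x_full[:, :n_first]                     # shape (64, 10)

# Mirrors the question's w_vec: a (10, 64) matrix tied to the batch size.
w_vec = np.array([range(10)] * 64).reshape(10, 64)

# (10, 64) . (64, 10): the contraction runs over the batch axis of x_first.
y = np.dot(w_vec, x_first)                        # -> (10, 10)
y = np.expand_dims(y, -1)                         # -> (10, 10, 1)
print(x_full.shape, y.shape)  # (64, 15) (10, 10, 1)

# Concatenating a rank-2 array with a rank-3 array fails.
failed = False
try:
    np.concatenate([x_full, y], axis=-1)
except ValueError as err:
    failed = True
    print("concatenate failed:", err)
```

A weighted sum should give one value per sample, i.e. shape (64, 1). Because the weight matrix comes first in the dot product, the batch axis is summed away instead, so the result cannot be concatenated with the input.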


Answers (1)

today

You can write this much more simply using K.sum and a single vector of coefficients. There is also no need to fix the batch size; the model then works with a batch of any size:

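import numpy as np
from keras import backend as K
from keras.layers import Input, Lambda, concatenate
from keras.models import Model

def calcWSumF(x, idx):
    # weights 0, 1, ..., idx-1; independent of the batch size
    w_vec = K.constant(np.arange(idx))
    # (batch, idx) * (idx,) broadcasts; summing the last axis gives (batch, 1)
    y = K.sum(x[:, 0:idx] * w_vec, axis=-1, keepdims=True)
    return y

d_inputs = Input((15,))
wsum_first = Lambda(calcWSumF, arguments={'idx': 10})(d_inputs)
d_input = concatenate([d_inputs, wsum_first], axis=-1)

model = Model(d_inputs, d_input)
model.predict(np.arange(15).reshape(1, 15))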

# output:
array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14., 285.]], dtype=float32)

# Note: 0*0 + 1*1 + 2*2 + ... + 9*9 = 285
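The same computation can be checked outside Keras. This is a minimal NumPy sketch, assuming NumPy calls as stand-ins for the K backend functions; calc_wsum and with_wsum are hypothetical helper names. It reproduces the 285 result and shows that the batch size is free.

```python
import numpy as np

def calc_wsum(x, idx):
    # Weights 0, 1, ..., idx-1, like K.constant(np.arange(idx)) in the answer.
    w_vec = np.arange(idx, dtype=np.float32)
    # (batch, idx) * (idx,) broadcasts; summing the last axis with
    # keepdims=True leaves one value per sample: shape (batch, 1).
    return np.sum(x[:, :idx] * w_vec, axis=-1, keepdims=True)

def with_wsum(x, idx=10):
    # Equivalent of concatenate([d_inputs, wsum_first], axis=-1).
    return np.concatenate([x, calc_wsum(x, idx)], axis=-1)

single = np.arange(15, dtype=np.float32).reshape(1, 15)
print(with_wsum(single))  # last column is 285.0

batch = np.ones((7, 15), dtype=np.float32)  # any batch size works
print(with_wsum(batch).shape)  # (7, 16)
```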

Note that, to make it more general, we added another argument, idx, to the function (passed through the arguments parameter of the Lambda layer). It specifies how many elements from the beginning of the input are included in the weighted sum.
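Changing idx only changes how many leading columns are weighted. The NumPy sketch below (wsum_sum and wsum_dot are hypothetical names, NumPy standing in for the backend) compares the element-wise-multiply-and-sum formulation with an equivalent matrix-vector product. In the dot form the input comes first and the weights are an (idx, 1) column, so the batch axis is kept, unlike the question's K.dot(w_vec, x).

```python
import numpy as np

def wsum_sum(x, idx):
    # Element-wise multiply, then sum over the feature axis.
    w = np.arange(idx, dtype=np.float32)
    return np.sum(x[:, :idx] * w, axis=-1, keepdims=True)

def wsum_dot(x, idx):
    # Same weights as an (idx, 1) column; (batch, idx) . (idx, 1) -> (batch, 1).
    w_col = np.arange(idx, dtype=np.float32).reshape(idx, 1)
    return np.dot(x[:, :idx], w_col)

# Two samples: the values 0..14 and 15..29.
x = np.arange(30, dtype=np.float32).reshape(2, 15)
for idx in (5, 10):
    print(idx, wsum_sum(x, idx).ravel(), wsum_dot(x, idx).ravel())
```

Both forms give the same (batch, 1) result; the K.sum version simply avoids reshaping the weights.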
