Fengxiang Hu

Reputation: 81

a problem and how to deal with batch while creating a Model

[image: network architecture diagram with the attention block circled]

import keras
from keras.layers import Dense
import tensorflow as tf
from self_attention_layer import Encoder



## multi source attention
class Multi_source_attention(keras.Model):

    def __init__(self, read_n, embed_dim, num_heads, ff_dim, num_layers):
        super().__init__()
        self.read_n = read_n
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_layers = num_layers
        # renamed to weight_dense so it does not shadow keras.Model.get_weights()
        self.weight_dense = Dense(49, activation='relu', name="get_weights")

    def compute_output_shape(self, input_shape):
        # ([batch, 7, 7, 256], [1, 256])
        return input_shape

    def call(self, inputs):
        ## weights matrix
        # (1, 49)
        weights_res = self.weight_dense(inputs[1])
        # (1, 7, 7)
        weights = tf.reshape(weights_res, (1, 7, 7))
        # (256, 7, 7)
        weights = tf.tile(weights, [256, 1, 1])

        ## img from mobilenet
        img = tf.reshape(inputs[0], [-1, 7, 7])

        inter_res = tf.multiply(img, weights)
        inter_res = tf.reshape(inter_res, (-1, 256, 49))
        att = Encoder(self.embed_dim, self.num_heads, self.ff_dim, self.num_layers)(inter_res)

        return att

I am trying to construct a network that implements the part circled in the image. The output from the LSTM has shape (1, 256), and the output from the preceding Mobilenet has shape (batch, 7, 7, 256). The LSTM output is then transformed into a weights matrix of shape (7, 7).

But the problem is that the output from Mobilenet carries a batch dimension, while the weights matrix does not. How should I deal with the batch dimension, or how can I set up a parameter to constrain the batch size?
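To make the shapes concrete, here is a minimal standalone sketch (random tensors standing in for the Mobilenet and LSTM outputs, not my actual model): reshaping the (7, 7) weights to (1, 7, 7, 1) lets TensorFlow broadcast them over both the batch and channel axes, without any tiling tied to a fixed batch size.

```python
import tensorflow as tf

batch = 4
img = tf.random.normal((batch, 7, 7, 256))  # stand-in for the Mobilenet output
weights_res = tf.random.normal((1, 49))     # stand-in for the Dense(49) output

# add singleton batch and channel axes so multiplication broadcasts
weights = tf.reshape(weights_res, (1, 7, 7, 1))
out = img * weights  # broadcasts to (batch, 7, 7, 256)
print(out.shape)
```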

Could someone give me a tip?

Also, if I remove the function compute_output_shape(), a NotImplementedError occurs, even though the official Keras docs say I do not need to override that function. Could someone explain that to me?

Upvotes: 3

Views: 67

Answers (1)

Fengxiang Hu

Reputation: 81

compute_output_shape is crucial when customizing a layer. When summary() is called, the corresponding graph is generated, and the input and output shapes are shown for every layer; compute_output_shape is what supplies the output shape.
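For example, here is a minimal custom layer (a generic sketch, not the model above) where summary() reads the output shape from compute_output_shape:

```python
import tensorflow as tf
from tensorflow import keras

class ScaleLayer(keras.layers.Layer):
    """Multiplies its input by a constant factor; the shape is unchanged."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def compute_output_shape(self, input_shape):
        # summary() shows this shape for the layer's output
        return input_shape

model = keras.Sequential([keras.layers.Input((7, 7, 256)), ScaleLayer()])
model.summary()  # the ScaleLayer row shows output shape (None, 7, 7, 256)
```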

Upvotes: 2
