OmaymaS

Reputation: 1731

keras embedding weights lookup with categorical variables

Suppose I have a list of users and items, as in the ratings dataframe below. If I create the embedding layers item_embedding and user_embedding and then get the weights of those layers, how do the embedding vectors map to the item/user ids? Do they follow the order of the ids?

import pandas as pd
from tensorflow import keras  # or `import keras` with standalone Keras

## minimal example
ratings = pd.DataFrame({'user': [1000, 10001, 1000], 'item': [115, 112, 115], 'rating': [5, 3, 4]})


## keras model----------------------------------------------------

n_latent_factors_user = 8
n_latent_factors_item = 8

n_users = len(set(ratings['user']))  ## number of distinct users
n_items = len(set(ratings['item']))  ## number of distinct items

## items
item_input = keras.layers.Input(shape=[1], name='Item')  ## input
item_embedding = keras.layers.Embedding(n_items + 1, n_latent_factors_item, name='item-Embedding')(item_input)
item_vec = keras.layers.Flatten(name='Flattenitems')(item_embedding)

## users
user_input = keras.layers.Input(shape=[1],name='User') ## input
user_embedding = keras.layers.Embedding(n_users + 1, n_latent_factors_user, name='User-Embedding')(user_input)
user_vec = keras.layers.Flatten(name='FlattenUsers')(user_embedding)

## concat items and users
concat = keras.layers.concatenate([item_vec, user_vec])

## fully connected
dense_1 = keras.layers.Dense(20,name='FullyConnected', activation='relu')(concat)

## output
result = keras.layers.Dense(1, activation='relu',name='Activation')(dense_1)

## model with input and output
model = keras.Model([user_input, item_input], result)

In other words, if I get the weights of the item embedding layer as follows, would the first vector correspond to item 112?

## items embedding weights
model.layers[2].get_weights()[0]  ## shape (n_items + 1, 8), i.e. 3x8 here

Upvotes: 1

Views: 1738

Answers (1)

Anna Krogager

Reputation: 3588

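You need to make sure that your users are encoded as consecutive integers 0, ..., n_users - 1 and your items as 0, ..., n_items - 1. The Embedding layer knows nothing about your original ids: it simply uses the integer it receives as a row index into its weight matrix. To get the embeddings of your items you can do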
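A minimal sketch of that re-encoding, assuming pandas as in the question. `pd.factorize` is just one convenient choice (category codes or scikit-learn's LabelEncoder would work too), and the column names `user_idx`/`item_idx` are hypothetical. The encoded columns, not the raw ids, are what you would feed to the model.

```python
import pandas as pd

# Same toy data as in the question.
ratings = pd.DataFrame({'user': [1000, 10001, 1000],
                        'item': [115, 112, 115],
                        'rating': [5, 3, 4]})

# pd.factorize assigns consecutive codes 0..k-1 in order of first appearance
# and returns the unique original values, so uniques[code] gives back the id.
# (user_idx / item_idx are hypothetical column names.)
ratings['user_idx'], user_ids = pd.factorize(ratings['user'])
ratings['item_idx'], item_ids = pd.factorize(ratings['item'])

n_users = len(user_ids)
n_items = len(item_ids)

print(ratings)
print(list(item_ids))  # item_ids[i] is the original id behind embedding row i
```

With the default order of first appearance, item 115 gets index 0 and item 112 gets index 1; passing `sort=True` to `pd.factorize` would order the codes by id instead, so 112 would map to row 0.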

embeddings_items = model.get_layer('item-Embedding').get_weights()[0]

Then embeddings_items[0] gives you the embedding of item number 0.
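Going the other way, here is a minimal sketch of turning the weight matrix into a lookup keyed by the original ids. It assumes the items were encoded so that `item_ids[i]` is the original id for index i (as the uniques from `pd.factorize` would be). The random matrix is a hypothetical stand-in for the real `get_weights()` output, and the row indexing mirrors what the Embedding layer itself does.

```python
import numpy as np

# Hypothetical stand-in for model.get_layer('item-Embedding').get_weights()[0]:
# one row of 8 latent factors per encoded item index (random values here).
rng = np.random.default_rng(0)
item_ids = [115, 112]  # original ids, positioned by their encoded index
embeddings_items = rng.normal(size=(len(item_ids), 8))

# The layer is a plain row lookup: encoded index i -> embeddings_items[i].
encoded_batch = np.array([0, 1, 0])
looked_up = embeddings_items[encoded_batch]  # shape (3, 8)

# Map each original item id to its embedding vector.
item_vectors = {item_id: embeddings_items[i] for i, item_id in enumerate(item_ids)}

print(looked_up.shape)
print(np.array_equal(item_vectors[112], embeddings_items[1]))
```

If the model was built with `n_items + 1` rows, as in the question, the extra row is simply never used by this encoding.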

Upvotes: 1
