thaumoctopus

Reputation: 213

Getting attention weights in opennmt-py

Specifically in opennmt-py. There are many questions on this topic, such as Getting alignment/attention during translation in OpenNMT-py and this thread on the OpenNMT forum: https://github.com/OpenNMT/OpenNMT-py/issues/575. I use the code suggested by the latter, but none of these seem to address the problem I have. I try to run the following simple snippet:

import onmt
import onmt.inputters
import onmt.translate
import onmt.model_builder
from collections import namedtuple


Opt = namedtuple('Opt', ['models', 'data_type', 'reuse_copy_attn', 'gpu'])

opt = Opt("/home/Desktop/hidden-att/model/hidden-2/seed-0/LSTMlang1_step_400.pt", "text", False, 0)
fields, model, model_opt = onmt.model_builder.load_test_model(opt, {"reuse_copy_attn": False})


I get this error trace.

Traceback (most recent call last):

  File "<ipython-input-63-94c1f45c429f>", line 1, in <module>
    runfile('/home/Desktop/hidden-att/graph_hidden_exp.py', wdir='/home/Desktop/hidden-att')

  File "/home/anaconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 786, in runfile
    execfile(filename, namespace)

  File "/home/anaconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 110, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "/home/Desktop/hidden-att/graph_hidden_exp.py", line 33, in <module>
    fields, model, model_opt =  onmt.model_builder.load_test_model(opt,{"reuse_copy_attn":False})

  File "../../Documents/NMT/OpenNMT-py/onmt/model_builder.py", line 85, in load_test_model
    map_location=lambda storage, loc: storage)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 387, in load
    return _load(f, map_location, pickle_module, **pickle_load_args)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 549, in _load
    _check_seekable(f)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 194, in _check_seekable
    raise_err_msg(["seek", "tell"], e)

  File "/home/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 187, in raise_err_msg
    raise type(e)(msg)

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

So has anyone experienced and solved this problem, or does anyone know where to look? I guess it is something about the loaded file, but the model was trained with opennmt-py in a fairly standard way.
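Reading the trace, my guess is that the second argument to load_test_model ends up being treated as the model path and handed straight to torch.load, which would explain why a dict raises the seek error. For reference, here is the kind of call I suspect this version expects (a guess only; the exact signature of load_test_model depends on the OpenNMT-py version, and opt.models may need to be a list as in translate.py):

import onmt.model_builder
from collections import namedtuple

Opt = namedtuple('Opt', ['models', 'data_type', 'reuse_copy_attn', 'gpu'])
model_path = "/home/Desktop/hidden-att/model/hidden-2/seed-0/LSTMlang1_step_400.pt"
opt = Opt([model_path], "text", False, 0)  # models given as a list

# Hypothetical: pass the checkpoint path (or nothing) instead of a dict
fields, model, model_opt = onmt.model_builder.load_test_model(opt, model_path)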

Upvotes: 1

Views: 553

Answers (2)

Kostas Mouratidis

Reputation: 1255

I'm using a different interface and probably a very different version of OpenNMT by now, but this might help someone.

I'm catching the attention scores by injecting my code into ONMT like this:

from onmt.utils import misc
from onmt.translate import translator as _translator

# Collected (source tokens, predicted tokens, attention matrix) tuples
data = []

def report_matrix(srcs, preds, attns):
    # Capture the arguments, then defer to the original implementation
    data.append((srcs, preds, attns))
    return misc.report_matrix(srcs, preds, attns)

# Replace the reference the translator module actually calls
_translator.report_matrix = report_matrix

Then I translate with something like this [1]:

from io import StringIO
from onmt.translate.translator import build_translator, Translator
from onmt.utils.misc import split_corpus


filename = "path"      # source file to translate
shard_size = 10000     # sentences per shard
outfile = StringIO()   # capture translations in memory instead of on disk
onmt_config = ...  # custom class with attributes like `models`, `output`, `gpu`

translator = build_translator(onmt_config, report_score=True, out_file=outfile)
for shard in split_corpus(filename, shard_size):
    # attn_debug=True is what triggers the report_matrix calls above
    translator.translate(src=shard, batch_size=4, attn_debug=True)

This works because when you enable attn_debug, OpenNMT calls report_matrix, and that gives me all the info I need. data is a list of tuples whose elements are, in order (see the sketch after the list):

  • a list of strings containing the source tokens, len=N
  • a list of strings containing the predicted tokens, len=M
  • a list of lists of floats containing the attention score for each output/input combination, shape=(M, N)
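For example, a quick way to inspect the captured matrices (a sketch assuming numpy is available and at least one sentence has been translated):

import numpy as np

# Take the first captured example from the patched report_matrix
srcs, preds, attns = data[0]
attn = np.array(attns)  # shape (M, N): predicted tokens x source tokens
for i, tok in enumerate(preds):
    j = int(attn[i].argmax())  # source position with the highest weight
    print(f"{tok} -> {srcs[j]} ({attn[i, j]:.3f})")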

[1] This is not exactly my code; I removed a few things to make it cleaner.

Upvotes: 0

xwzhang

Reputation: 11

You can add the -attn_debug flag to your translate command to view the attention weights:

translate.py ... \
             -attn_debug \
             ...
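For context, a hypothetical full invocation could look like this (model and file names are placeholders):

python translate.py -model model_step_400.pt \
                    -src src-test.txt \
                    -output pred.txt \
                    -attn_debug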

Upvotes: 1
