Reputation: 11
I tried the approach below to compute the Jacobian of BertForMaskedLM using jacrev:
import numpy as np
from transformers import BertTokenizer,BertForMaskedLM
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers, vmap, vjp, jvp, jacrev
device = 'cuda:2'
torch.cuda.empty_cache()
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForMaskedLM.from_pretrained(model_name)
net = bert_model.to(device)
fnet, params, buffers = make_functional_with_buffers(net)
def fnet_single(params,x,y):
    result = fnet(params, buffers, x.unsqueeze(0).unsqueeze(0),y.unsqueeze(0).unsqueeze(0))['logits']
    return result.squeeze(0).squeeze(0)
text = u'大肠杆菌是人和许多动物肠道中最主要的一种细菌'
inputs = tokenizer.encode_plus(text)
segment_ids = inputs['token_type_ids']
token_ids = inputs['input_ids']
length = len(token_ids) - 2
batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
batch_segment_ids = torch.zeros_like(batch_token_ids).to(device)
for i in range(length):
    if i > 0:
        batch_token_ids[2 * i - 1, i] = 103
        batch_token_ids[2 * i - 1, i + 1] = 103
    batch_token_ids[2 * i, i + 1] = 103
threshold = 100
word_token_ids = [[token_ids[1]]]
for i in range(1, length):
    x,y = batch_token_ids[2 * i],batch_segment_ids[2*i]
    jacobian1 = jacrev(fnet_single,argnums=1)(params,x,y)
    x,y = batch_token_ids[2 * i - 1],batch_segment_ids[2*i-1]
    jacobian2 = jacrev(fnet_single,argnums=1)(params,x,y)
However, an error appeared:
Traceback (most recent call last):
  File "study_jacrev.py", line 49, in <module>
    batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
Can anyone help me?
Upvotes: 1
Views: 165
Reputation: 1209
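This happens because you are asking for the Jacobian with respect to data that cannot track gradients: batch_token_ids holds integer token ids, and only floating-point or complex tensors can require gradients.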
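The dtype restriction can be reproduced without BERT at all. The following is a minimal sketch; the token id values are made up for illustration and the variable names are not from the question.

```python
import torch

# Integer token ids, similar to tokenizer output (illustrative values only).
token_ids = [101, 1920, 5499, 102]

# Creating an integer tensor with requires_grad=True fails.
try:
    torch.tensor(token_ids, requires_grad=True)
    raised = False
    message = ""
except RuntimeError as err:
    raised = True
    message = str(err)

print(raised, message)

# A floating-point tensor with the same values can require gradients.
x = torch.tensor(token_ids, dtype=torch.float32, requires_grad=True)
print(x.dtype, x.requires_grad)
```

The failure comes from the tensor constructor itself, before jacrev is ever called.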
If you want to get the jacobian wrt parameters: jacrev(fnet_single, argnums=0)(params, x, y)
If you want to get the jacobian wrt data: x = x.to(torch.float32).requires_grad_(True)
(note that x must be converted to a floating-point dtype before it can track gradients)
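To illustrate the difference between the two argnums choices, here is a small sketch on a stand-in model. It assumes PyTorch 2.x, where torch.func provides jacrev and functional_call as the successor to functorch; the nn.Linear layer and the function f are hypothetical substitutes for the BERT setup, chosen because they accept float inputs. For BERT itself, the token ids feed an embedding lookup that expects integer indices, so a float copy of the ids may not be accepted by the model as-is.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev

torch.manual_seed(0)

# Stand-in model (assumption): a single linear layer taking float inputs.
net = nn.Linear(3, 2)
params = {name: p.detach() for name, p in net.named_parameters()}

def f(params, x):
    # Run the module with explicit parameters on a single sample.
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

x = torch.randn(3)  # floating-point input

# Jacobian with respect to the parameters: a dict with one entry per parameter.
jac_params = jacrev(f, argnums=0)(params, x)

# Jacobian with respect to the input data: a single tensor.
jac_data = jacrev(f, argnums=1)(params, x)

print({name: tuple(j.shape) for name, j in jac_params.items()})
print(tuple(jac_data.shape))
```

For a linear layer the Jacobian with respect to the input is just the weight matrix, which makes the result easy to check by hand.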
Upvotes: 0