Polik Jikop

Reputation: 11

How to compute the Jacobian of BertForMaskedLM using jacrev

I tried the code below to compute the Jacobian of BertForMaskedLM using jacrev:

import numpy as np
from transformers import BertTokenizer,BertForMaskedLM
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers, vmap, vjp, jvp, jacrev
device = 'cuda:2'
torch.cuda.empty_cache()


model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForMaskedLM.from_pretrained(model_name)

net = bert_model.to(device)
fnet, params, buffers = make_functional_with_buffers(net)

def fnet_single(params,x,y):
    result = fnet(params, buffers, x.unsqueeze(0).unsqueeze(0),y.unsqueeze(0).unsqueeze(0))['logits']
    return result.squeeze(0).squeeze(0)

text = u'大肠杆菌是人和许多动物肠道中最主要的一种细菌'
inputs = tokenizer.encode_plus(text)

segment_ids = inputs['token_type_ids']
token_ids = inputs['input_ids']
length = len(token_ids) - 2


batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
batch_segment_ids = torch.zeros_like(batch_token_ids).to(device)

for i in range(length):
    if i > 0:
        batch_token_ids[2 * i - 1, i] = 103
        batch_token_ids[2 * i - 1, i + 1] = 103
    batch_token_ids[2 * i, i + 1] = 103
threshold = 100
word_token_ids = [[token_ids[1]]]
for i in range(1, length):
    x,y = batch_token_ids[2 * i],batch_segment_ids[2*i]
    jacobian1 = jacrev(fnet_single,argnums=1)(params,x,y)
    x,y = batch_token_ids[2 * i - 1],batch_segment_ids[2*i-1]
    jacobian2 = jacrev(fnet_single,argnums=1)(params,x,y)

However, an error appeared:

Traceback (most recent call last):
  File "study_jacrev.py", line 49, in <module>
    batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients

Can anyone help me?

Upvotes: 1

Views: 165

Answers (1)

Valentin Goldité

Reputation: 1209

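This happens because you are asking for the Jacobian with respect to the input data, but gradient tracking cannot be enabled on that data: the token IDs form an integer tensor, and only floating-point or complex tensors can require gradients.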

  • If you want to get the jacobian wrt parameters: jacrev(fnet_single, argnums=0)(params, x, y)

  • If you want to get the jacobian wrt data: x = x.to(torch.float32).requires_grad_(True) (note that converting x to a float dtype is mandatory before gradients can be enabled on it)
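The two options above can be sketched on a toy model. This is a minimal illustration under stated assumptions, not the original setup: TinyLM is a hypothetical stand-in for BertForMaskedLM, and the sketch uses torch.func (functional_call, jacrev), which assumes PyTorch 2.x, instead of functorch.make_functional_with_buffers. Note also that an embedding lookup needs integer indices, so token IDs cast to float will likely fail inside the model. A common workaround, which is an assumption here and not part of the answer, is to differentiate with respect to the input embeddings. BertForMaskedLM's forward accepts an inputs_embeds argument that could play the same role.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev

torch.manual_seed(0)

class TinyLM(nn.Module):
    """Hypothetical stand-in for a masked LM: embedding -> tanh -> vocab logits."""
    def __init__(self, vocab_size=10, hidden=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids=None, inputs_embeds=None):
        # Mirror the "ids or embeddings" style of input used by BERT models.
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)
        return self.head(torch.tanh(inputs_embeds))

model = TinyLM()
params = {name: p.detach() for name, p in model.named_parameters()}
token_ids = torch.tensor([1, 5, 3])  # integer (int64) token IDs

# 1) Reproduce the error: integer tensors cannot require gradients.
try:
    torch.tensor([1, 5, 3], requires_grad=True)
    raised = False
except RuntimeError:
    raised = True

# 2) Jacobian of the logits with respect to the parameters (argnums=0).
def logits_from_params(p, ids):
    return functional_call(model, p, (ids.unsqueeze(0),)).squeeze(0)

jac_params = jacrev(logits_from_params, argnums=0)(params, token_ids)

# 3) Jacobian with respect to the (float) input embeddings, not the IDs.
embeds = model.embed(token_ids).detach()

def logits_from_embeds(e):
    out = functional_call(model, params, (), {"inputs_embeds": e.unsqueeze(0)})
    return out.squeeze(0)

jac_embeds = jacrev(logits_from_embeds)(embeds)

print(raised)
print({name: tuple(j.shape) for name, j in jac_params.items()})
print(tuple(jac_embeds.shape))
```

Each entry of jac_params has the output shape (sequence length, vocabulary size) followed by the shape of the corresponding parameter, and jac_embeds has the output shape followed by the embedding shape (sequence length, hidden size). For a full-size BERT the parameter Jacobian is very large, so restricting it to a few parameters or a few output positions may be necessary.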

Upvotes: 0
