Reputation: 31
I have set up my model for hyperparameter tuning using either Nevergrad or Optuna (both result in the same problem).
The problem: after the first model has finished training (i.e. the first trial of the parameter search), and the second one is about to begin training, I get the following error message:
MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":49:0)): error: 'mps.scatter_nd' op invalid input tensor shape: updates tensor shape and data tensor shape must match along inner dimensions
/AppleInternal/Library/BuildRoots/97f6331a-ba75-11ed-a4bc-863efbbaf80d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1710: failed assertion `Error: MLIR pass manager failed'
Seemingly there cannot be anything wrong with the model itself, because it trained successfully the first time, so I am wondering what causes the issue. I have also found that everything works when I set the device to cpu; the problem only occurs when I use mps as the device.
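For context, this is roughly how I switch between the two devices when testing; a minimal sketch (the pick_device helper is just for illustration and is not part of the script below), assuming torch.backends.mps.is_available() exists in the installed PyTorch version:

import torch

def pick_device(requested: str = "mps") -> torch.device:
    # Illustrative helper: fall back to CPU when MPS is not requested or not available.
    if requested == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# With pick_device("cpu") every trial runs fine;
# with pick_device("mps") the second trial crashes as described above.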
Here's a code snippet of my training script:
import os
import time
import torch
import argparse
import numpy as np
from model import SASRec
from utils import *
import wandb
import nevergrad as ng
def str2bool(s):
    if s not in {'false', 'true'}:
        raise ValueError('Not a valid boolean string')
    return s == 'true'
def nevergrad_optimization(args, n_trials):
    # Define the search space for the hyperparameters
    def create_search_space():
        hidden_units_choices = [32, 48, 64, 80, 96, 112, 128]
        num_heads_choices = [i for i in range(1, len(hidden_units_choices) + 1) if hidden_units_choices[-1] % i == 0]
        search_space = ng.p.Dict(
            dropout_rate=ng.p.Scalar(lower=0.1, upper=0.5).set_mutation(sigma=0.1).set_name("dropout_rate"),
            maxlen=ng.p.Scalar(lower=50, upper=200).set_mutation(sigma=25).set_integer_casting().set_name("maxlen"),
            lr=ng.p.Log(lower=1e-6, upper=1e-2).set_mutation(sigma=1e-3).set_name("lr"),
            hidden_units=ng.p.Choice(hidden_units_choices).set_name("hidden_units"),
            num_heads=ng.p.Choice(num_heads_choices).set_name("num_heads"),
            num_blocks=ng.p.Scalar(lower=1, upper=6).set_mutation(sigma=1).set_integer_casting().set_name("num_blocks"),
        )
        return search_space
    # Define the optimization function
    def objective(hyperparameters):
        args.dropout_rate = hyperparameters["dropout_rate"]
        args.maxlen = int(hyperparameters["maxlen"])
        args.lr = hyperparameters["lr"]
        args.hidden_units = hyperparameters["hidden_units"]
        n_heads_max = args.hidden_units // 2
        var_heads = [i for i in range(1, n_heads_max + 1) if args.hidden_units % i == 0]
        args.num_heads = hyperparameters["num_heads"] % len(var_heads)
        args.num_heads = var_heads[args.num_heads]
        args.num_blocks = int(hyperparameters["num_blocks"])
        hit_10 = main(args)
        return -hit_10
    # Create the search space
    search_space = create_search_space()
    # Create the optimizer
    optimizer = ng.optimizers.CMA(parametrization=search_space, budget=n_trials)
    # Perform the optimization
    recommendation = optimizer.minimize(objective)
    # Retrieve the best trial (use recommendation.value to get plain values, not Parameter objects)
    best_hyperparams = recommendation.value
    best_value = -objective(best_hyperparams)
    best_trial = {
        "dropout_rate": best_hyperparams["dropout_rate"],
        "maxlen": int(best_hyperparams["maxlen"]),
        "lr": best_hyperparams["lr"],
        "hidden_units": best_hyperparams["hidden_units"],
        "num_heads": best_hyperparams["num_heads"],
        "num_blocks": int(best_hyperparams["num_blocks"]),
        "value": best_value,  # stored so it can be printed after the search
    }
    return best_trial
def main(args):
    if not os.path.isdir(args.dataset + '_' + args.train_dir):
        os.makedirs(args.dataset + '_' + args.train_dir)
    with open(os.path.join(args.dataset + '_' + args.train_dir, 'args.txt'), 'w') as f:
        f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))
    # init wandb
    wandb.init(project='SASRec_RND', config=args)
    # global dataset
    dataset = data_partition(args.dataset)
    [user_train, user_valid, user_test, usernum, itemnum] = dataset
    num_batch = len(user_train) // args.batch_size
    cc = 0.0
    for u in user_train:
        cc += len(user_train[u])
    print('average sequence length: %.2f' % (cc / len(user_train)))
    f = open(os.path.join(args.dataset + '_' + args.train_dir, 'log.txt'), 'w')
    sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)
    model = SASRec(usernum, itemnum, args).to(args.device)
    # init weights
    for _, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except:
            pass
    model.train()
    epoch_start_idx = 1
    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[:tail.find('.')]) + 1
        except:
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            print('pdb enabled for your quick check, pls type exit() if you do not need it')
            import pdb; pdb.set_trace()
    if args.inference_only:
        model.eval()
        t_test = evaluate_full(model, dataset, args, ks=[1, 3, 5, 10])
        for k in t_test:
            print('test (NDCG@%d: %.4f, HR@%d: %.4f)' % (k, t_test[k][0], k, t_test[k][1]))
    bce_criterion = torch.nn.BCEWithLogitsLoss()
    adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
    # Early stopping parameters
    early_stopping_patience = 2  # Stop if there is no improvement for this many evaluations
    best_hit_10 = 0
    epochs_without_improvement = 0
    T = 0.0
    t0 = time.time()
    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        if args.inference_only:
            break
        for step in range(num_batch):
            u, seq, pos, neg = sampler.next_batch()
            u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
            pos_logits, neg_logits = model(u, seq, pos, neg)
            pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)
            adam_optimizer.zero_grad()
            indices = np.where(pos != 0)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            adam_optimizer.step()
            print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item()))
            wandb.log({"loss": loss.item()})
        if epoch % 10 == 0:
            model.eval()
            t1 = time.time() - t0
            T += t1
            print('Evaluating', end='')
            t_valid = evaluate_valid_full(model, dataset, args, ks=[10])
            hit_10 = t_valid[10][1]
            print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f)' % (epoch, T, t_valid[10][0], hit_10))
            wandb.log({'epoch': epoch, 'valid_NDCG@10': t_valid[10][0], 'valid_HR@10': hit_10})
            # Check for early stopping
            if hit_10 > best_hit_10:
                best_hit_10 = hit_10
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print("Early stopping: No improvement for {} epochs".format(early_stopping_patience))
                break
            f.write('valid (NDCG@10: %.4f, HR@10: %.4f)\n' % (t_valid[10][0], hit_10))
            f.flush()
            t0 = time.time()
            model.train()
        if epoch == args.num_epochs:
            folder = args.dataset + '_' + args.train_dir
            fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'
            fname = fname.format(args.num_epochs, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
            torch.save(model.state_dict(), os.path.join(folder, fname))
    f.close()
    sampler.close()
    print("Done")
    return hit_10
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_trials', default=100, type=int, help='Number of random search trials')
    parser.add_argument('--dataset', default='clean_cart', type=str)
    parser.add_argument('--train_dir', default='default', type=str)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--num_epochs', default=50, type=int)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--inference_only', default=False, action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--device', default='mps', type=str)
    args = parser.parse_args()
    n_trials = args.n_trials
    best_trial = nevergrad_optimization(args, n_trials)
    print("Best trial:")
    print(f"  Value: {best_trial['value']}")
    print("  Params: ")
    for key, value in best_trial.items():
        if key != "value":
            print(f"    {key}: {value}")
Upvotes: 2
Views: 324