mckris

Reputation: 31

Pytorch MPS: 'mps.scatter_nd' op invalid input tensor shape

I have set up my model for hyperparameter tuning using either nevergrad or optuna (both result in the same problem).

The problem: after the first model has trained (i.e. the first parameter-search trial has finished) and the second is about to begin training, I get the following error message:

MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":49:0)): error: 'mps.scatter_nd' op invalid input tensor shape: updates tensor shape and data tensor shape must match along inner dimensions /AppleInternal/Library/BuildRoots/97f6331a-ba75-11ed-a4bc-863efbbaf80d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1710: failed assertion `Error: MLIR pass manager failed'

Seemingly there cannot be anything wrong with the model itself, because it trained successfully the first time, so I am wondering what causes the issue. I have also found that everything works when I set the device to CPU; the problem only occurs when I use mps as the device.
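For reference, the device only enters the script through the --device flag and model.to(args.device). A minimal sketch of the CPU fallback I mean (the torch.backends.mps.is_available() check is just an illustration, not part of my script):

import torch

def pick_device(requested='mps'):
    # Use MPS only when the backend is actually available; otherwise fall
    # back to CPU, where the script runs without the scatter_nd error.
    if requested == 'mps' and torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

device = pick_device('mps')  # would replace the plain '--device' string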

Here's a code snippet of my training script:

import os
import time
import torch
import argparse
import numpy as np
from model import SASRec
from utils import *
import wandb
import nevergrad as ng

def str2bool(s):
    if s not in {'false', 'true'}:
        raise ValueError('Not a valid boolean string')
    return s == 'true'

def nevergrad_optimization(args, n_trials):
    # Define the search space for the hyperparameters
    def create_search_space():
        hidden_units_choices = [32, 48, 64, 80, 96, 112, 128]
        num_heads_choices = [i for i in range(1, len(hidden_units_choices)+1) if hidden_units_choices[-1] % i == 0]
        
        search_space = ng.p.Dict(
            dropout_rate=ng.p.Scalar(lower=0.1, upper=0.5).set_mutation(sigma=0.1).set_name("dropout_rate"),
            maxlen=ng.p.Scalar(lower=50, upper=200).set_mutation(sigma=25).set_integer_casting().set_name("maxlen"),
            lr=ng.p.Log(lower=1e-6, upper=1e-2).set_mutation(sigma=1e-3).set_name("lr"),
            hidden_units=ng.p.Choice(hidden_units_choices).set_name("hidden_units"),
            num_heads=ng.p.Choice(num_heads_choices).set_name("num_heads"),
            num_blocks=ng.p.Scalar(lower=1, upper=6).set_mutation(sigma=1).set_integer_casting().set_name("num_blocks"),
        )
        return search_space


    # Define the optimization function
    def objective(hyperparameters):
        args.dropout_rate = hyperparameters["dropout_rate"]
        args.maxlen = int(hyperparameters["maxlen"])
        args.lr = hyperparameters["lr"]
        args.hidden_units = hyperparameters["hidden_units"]

        n_heads_max = args.hidden_units // 2
        var_heads = [i for i in range(1, n_heads_max + 1) if args.hidden_units % i == 0]
        args.num_heads = hyperparameters["num_heads"] % len(var_heads)
        args.num_heads = var_heads[args.num_heads]

        args.num_blocks = int(hyperparameters["num_blocks"])

        hit_10 = main(args)
        return -hit_10

    # create search space
    search_space = create_search_space()
    # Create the optimizer
    optimizer = ng.optimizers.CMA(parametrization=search_space, budget=n_trials)

    # Perform the optimization
    recommendation = optimizer.minimize(objective)

    # Retrieve the best trial
    best_hyperparams = recommendation.value
    best_value = -objective(best_hyperparams)

    best_trial = {
        "value": best_value,
        "dropout_rate": best_hyperparams["dropout_rate"],
        "maxlen": int(best_hyperparams["maxlen"]),
        "lr": best_hyperparams["lr"],
        "hidden_units": best_hyperparams["hidden_units"],
        "num_heads": best_hyperparams["num_heads"],
        "num_blocks": int(best_hyperparams["num_blocks"]),
    }

    return best_trial


def main(args):

    if not os.path.isdir(args.dataset + '_' + args.train_dir):
        os.makedirs(args.dataset + '_' + args.train_dir)
    with open(os.path.join(args.dataset + '_' + args.train_dir, 'args.txt'), 'w') as f:
        f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

    # init wandb
    wandb.init(project='SASRec_RND', config=args)

    # global dataset
    dataset = data_partition(args.dataset)

    [user_train, user_valid, user_test, usernum, itemnum] = dataset
    num_batch = len(user_train) // args.batch_size
    cc = 0.0
    for u in user_train:
        cc += len(user_train[u])
    print('average sequence length: %.2f' % (cc / len(user_train)))

    f = open(os.path.join(args.dataset + '_' + args.train_dir, 'log.txt'), 'w')

    sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)
    model = SASRec(usernum, itemnum, args).to(args.device)

    # init weights
    for _, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except:
            pass

    model.train()

    epoch_start_idx = 1
    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[:tail.find('.')]) + 1
        except:
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            print('pdb enabled for your quick check, pls type exit() if you do not need it')
            import pdb; pdb.set_trace()

    if args.inference_only:
        model.eval()
        t_test = evaluate_full(model, dataset, args, ks=[1, 3, 5, 10])
        for k in t_test:
            print('test (NDCG@%d: %.4f, HR@%d: %.4f)' % (k, t_test[k][0], k, t_test[k][1]))

    bce_criterion = torch.nn.BCEWithLogitsLoss()
    adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    # Early stopping parameters
    early_stopping_patience = 2  # Stop if there is no improvement for this many evaluations
    best_hit_10 = 0
    epochs_without_improvement = 0
    hit_10 = 0.0  # returned even if training ends before the first evaluation

    T = 0.0
    t0 = time.time()

    

    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        if args.inference_only:
            break
        for step in range(num_batch):
            u, seq, pos, neg = sampler.next_batch()
            u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)

            pos_logits, neg_logits = model(u, seq, pos, neg)
            pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)
            adam_optimizer.zero_grad()
            indices = np.where(pos != 0)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            adam_optimizer.step()
            print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item()))
            wandb.log({"loss": loss.item()})
        
        if epoch % 10 == 0:
            model.eval()
            t1 = time.time() - t0
            T += t1
            print('Evaluating', end='')

            t_valid = evaluate_valid_full(model, dataset, args, ks=[10])
            hit_10 = t_valid[10][1]

            print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f)' % (epoch, T, t_valid[10][0], hit_10))
            wandb.log({'epoch': epoch, 'valid_NDCG@10': t_valid[10][0], 'valid_HR@10': hit_10})
            # Check for early stopping
            if hit_10 > best_hit_10:
                best_hit_10 = hit_10
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= early_stopping_patience:
                print("Early stopping: No improvement for {} epochs".format(early_stopping_patience))
                break

            f.write('valid (NDCG@10: %.4f, HR@10: %.4f)\n' % (t_valid[10][0], hit_10))
            f.flush()

            t0 = time.time()
            model.train()

        if epoch == args.num_epochs:
            folder = args.dataset + '_' + args.train_dir
            fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'
            fname = fname.format(args.num_epochs, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
            torch.save(model.state_dict(), os.path.join(folder, fname))

    f.close()
    sampler.close()
    print("Done")

    return hit_10

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--n_trials', default=100, type=int, help='Number of random search trials')
    parser.add_argument('--dataset', default='clean_cart', type=str)
    parser.add_argument('--train_dir', default='default', type=str)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--num_epochs', default=50, type=int)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--inference_only', default=False, action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--device', default='mps', type=str)

    args = parser.parse_args()

    n_trials = args.n_trials
    best_trial = nevergrad_optimization(args, n_trials)

    print("Best trial:")
    print(f"  Value: {best_trial['value']}")
    print("  Params: ")
    for key, value in best_trial.items():
        if key != "value":
            print(f"    {key}: {value}")

Upvotes: 2

Views: 324

Answers (0)
