Caden Miller

Reputation: 21

How to retrieve PyTorch tensor from queue in multiprocessing

I am trying to retrieve a tensor that I put into a queue in another process, but I get a 'Connection Refused' error whenever I do. Could you point me to any documentation that may help, or give me some suggestions?

import torch
import torch.multiprocessing as mp

def test(q):
    x = torch.normal(mean=0.0, std=1.0, size=(2, 3))
    x.share_memory_()
    q.put(x)

if __name__ == "__main__":
    mp.set_start_method("spawn")

    q = mp.Queue()
    p = mp.Process(target=test, args=(q,))
    p.start()
    p.join()

    print(q.get())

Upvotes: 2

Views: 3503

Answers (1)

Mehmet nuri

Reputation: 928

You should create the queue with Manager() to get rid of this error. A working example looks like this:

import torch
import torch.multiprocessing as mp

def test(q):
    x = torch.normal(mean=0.0, std=1.0, size=(2, 3))
    x.share_memory_()
    q.put(x)

if __name__ == "__main__":
    #mp.set_start_method("spawn")
    manager = mp.Manager()
    q = manager.Queue()
    p = mp.Process(target=test, args=(q,))
    p.start()
    p.join()

    print(q.get())
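
One likely reason for the original error is that the producer has already exited (because of p.join()) by the time the main process calls q.get(), so the handle to the tensor's shared memory can no longer be fetched from it. As an alternative to the Manager, here is a minimal sketch that keeps the producer alive until the tensor has been received. The names producer, done and main are my own, and it assumes a CPU tensor and the platform's default start method, so treat it as an illustration rather than the definitive fix:

```python
import torch
import torch.multiprocessing as mp


def producer(q, done):
    # Create the tensor and move its storage to shared memory.
    x = torch.normal(mean=0.0, std=1.0, size=(2, 3))
    x.share_memory_()
    q.put(x)
    # Stay alive until the consumer signals that it has the tensor.
    done.wait()


def main():
    q = mp.Queue()
    done = mp.Event()
    p = mp.Process(target=producer, args=(q, done))
    p.start()

    # Receive the tensor while the producer is still running.
    x = q.get()
    # Only now let the producer exit, then wait for it to finish.
    done.set()
    p.join()
    return x


if __name__ == "__main__":
    print(main())
```

The important difference from the question's code is the order: q.get() runs before p.join(), and the Event stops the producer from exiting in between.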

Upvotes: 1
