Reputation: 1
I am currently working on a project on secure large-model inference in a client-server setting, where the client and the server have to communicate a lot to perform collaborative computation.
For example, the server frequently sends model parameters (i.e., matrices) to the client, which requires reliably transferring large amounts of data over a TCP socket. Simply calling `socket.socket.sendall()` and `socket.socket.recv()` repeatedly is not suitable: TCP provides a byte-stream service, so the receiver cannot tell where one message ends and the next one begins.
I wrote this:
```python
import socket
import pickle
import torch


class BetterSocket:
    def __init__(self, s):
        self.socket = s
        self.msg_len = 2 ** 12  # chunk size: 4096 bytes

    def sendall(self, obj):
        pkl = pickle.dumps(obj)
        l = len(pkl)
        init_l = len(pkl)
        lbs = l.to_bytes(4, "big")  # 4-byte length header
        print(f"With LBS = {lbs}")
        self.socket.sendall(lbs)
        print("Going to send...")
        while l > self.msg_len:
            self.socket.sendall(pkl[0:self.msg_len])
            pkl = pkl[self.msg_len:]
            l = l - self.msg_len
            self.socket.recv(len(b'0'))  # wait for the receiver's confirmation
        self.socket.sendall(pkl)
        print(f"Done. {4 + init_l} bytes have been sent.")

    def recv(self):
        print("Waiting to receive...")
        lbs = self.socket.recv(4)
        print(f"LBS {lbs} received.")
        l = int.from_bytes(lbs, "big")
        print(f"Length {l} received.")
        pkl = b''
        while l > self.msg_len:
            pkl = pkl + self.socket.recv(self.msg_len)
            l = l - self.msg_len
            self.socket.sendall(b'0')  # confirm this chunk to the sender
        pkl = pkl + self.socket.recv(l)
        obj = pickle.loads(pkl)
        return obj
```
This class can transfer objects of any type. A 4-byte header is prepended to every "data unit" to indicate its total size. The "inverse confirmation" in the `while` loop seems to be necessary to balance the sending and receiving speed; otherwise `_pickle.UnpicklingError: pickle data was truncated` is thrown, which indicates that the receiver split the byte stream incorrectly. The same error also occurs when `self.msg_len` is set too large.
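The truncation is most likely caused by `socket.recv(n)` returning *at most* `n` bytes rather than exactly `n`, so a single call can hand back a partial chunk (even the length header can in principle arrive split). A minimal sketch of the same length-prefix idea without any per-chunk confirmation, assuming an 8-byte big-endian header and using hypothetical helper names `send_msg`, `recv_exact`, `recv_msg`, and `roundtrip`:

```python
import pickle
import socket
import struct
import threading

# Assumption: 8-byte unsigned big-endian length header (the question uses 4 bytes).
HEADER = struct.Struct("!Q")


def send_msg(sock: socket.socket, obj) -> None:
    """Pickle obj and send it with a length prefix; sendall() loops internally."""
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    sock.sendall(HEADER.pack(len(payload)))
    sock.sendall(payload)


def recv_exact(sock: socket.socket, n: int) -> bytearray:
    """Call recv_into() until exactly n bytes have arrived (one call may return fewer)."""
    buf = bytearray(n)
    view = memoryview(buf)
    got = 0
    while got < n:
        k = sock.recv_into(view[got:], n - got)
        if k == 0:
            raise ConnectionError("peer closed the socket mid-message")
        got += k
    return buf


def recv_msg(sock: socket.socket):
    """Read the length header, then exactly that many payload bytes."""
    (length,) = HEADER.unpack(recv_exact(sock, HEADER.size))
    return pickle.loads(recv_exact(sock, length))


def roundtrip(objs):
    """Demo: send objs from a thread over a local socket pair and receive them."""
    a, b = socket.socketpair()
    with a, b:
        sender = threading.Thread(target=lambda: [send_msg(a, o) for o in objs])
        sender.start()
        received = [recv_msg(b) for _ in objs]
        sender.join()
    return received
```

Because `recv_exact` keeps reading until the announced length has arrived, the chunk size no longer affects correctness and no confirmation message is needed; the receiver simply blocks in `recv_into` until more data is available.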
However, this solution is far too slow: transferring a 50000 × 800 matrix takes about 5 minutes, since the client and the server exchange an "inverse confirmation" for every chunk. How can I transfer data correctly without sacrificing so much performance? Please help!
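A rough cost estimate for that matrix, assuming float32 elements (4 bytes each), ignoring pickle overhead, and writing n for the payload size and c for `self.msg_len` = 4096 bytes. Each full chunk costs one confirmation round trip, and `pkl = pkl[self.msg_len:]` creates a new copy of the whole remaining buffer on every iteration:

```latex
\begin{aligned}
n &= 50000 \times 800 \times 4\,\text{B} = 1.6 \times 10^{8}\,\text{B} \\
N &\approx n / c = 1.6 \times 10^{8}\,\text{B} / 4096\,\text{B} \approx 3.9 \times 10^{4} \text{ round trips} \\
\text{bytes copied by slicing} &\approx \sum_{k=1}^{N} (n - kc) \approx \frac{nN}{2} \approx 3.1 \times 10^{12}\,\text{B}
\end{aligned}
```

The receiver's `pkl = pkl + ...` loop copies a comparable amount, so besides tens of thousands of round trips, both sides likely spend much of the time on copying that grows quadratically with the message size.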
(Sorry about my English...)
Upvotes: 0
Views: 80
Reputation: 177461
The pickle format is already self-delimiting: `pickle.load` reads exactly one pickled object from a stream and stops at its end, so no separate length header is needed. The socket method `socket.makefile` wraps a socket in a file-like object that can be used directly with `pickle.dump` and `pickle.load`, which makes reading and writing pickled objects easy.
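To check that no extra length header is needed, a small sketch (not part of the original answer) in which an in-memory `io.BytesIO` buffer stands in for the socket's byte stream: two pickles written back to back are read back one at a time.

```python
import io
import pickle

stream = io.BytesIO()  # stands in for the byte stream of a socket
pickle.dump([1, 2, 3], stream)     # first message
pickle.dump({"abc": 123}, stream)  # second message, written directly after the first
stream.seek(0)

first = pickle.load(stream)   # consumes exactly the first pickle
second = pickle.load(stream)  # then the second one

try:
    pickle.load(stream)       # nothing left, so pickle.load raises EOFError
    at_end = False
except EOFError:
    at_end = True
```

The `EOFError` at the end is the same signal the server below uses to detect that the client has closed the connection.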
Here is an example:
server.py
```python
import socket
import pickle

with socket.socket() as s:
    s.bind(('', 5000))
    s.listen()
    while True:
        client, addr = s.accept()
        with client, client.makefile('rb') as rfile:
            while True:
                try:
                    obj = pickle.load(rfile)
                    print(f'{addr}: {obj}')
                except EOFError:  # raised by pickle.load when socket is closed
                    break
```
client.py
```python
import socket
import pickle


def send_message(wfile, obj):
    wfile.write(pickle.dumps(obj))
    wfile.flush()  # ensures buffered writes are sent to the socket


with socket.socket() as s:
    s.connect(('localhost', 5000))
    with s.makefile('wb') as wfile:
        send_message(wfile, [1, 2, 3, 'abc', 'def'])
        send_message(wfile, [complex(1, 2), complex(3, 4)])
        send_message(wfile, dict(zip('abc def ghi'.split(), [123, 456, 789])))
```
Output of server after one run of client:
```
('127.0.0.1', 3010): [1, 2, 3, 'abc', 'def']
('127.0.0.1', 3010): [(1+2j), (3+4j)]
('127.0.0.1', 3010): {'abc': 123, 'def': 456, 'ghi': 789}
```
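Applied to the question's setting, a minimal sketch (hypothetical class name `PickleSocket` and demo function `demo`, not from the answer) that keeps the `sendall(obj)`/`recv()` interface of `BetterSocket` but relies on `makefile()` and pickle's own framing, so there is no length header and no per-chunk confirmation:

```python
import pickle
import socket
import threading


class PickleSocket:
    """Hypothetical wrapper: BetterSocket's interface on top of socket.makefile()."""

    def __init__(self, sock: socket.socket):
        self.sock = sock
        self.rfile = sock.makefile("rb")
        self.wfile = sock.makefile("wb")

    def sendall(self, obj) -> None:
        pickle.dump(obj, self.wfile, protocol=pickle.HIGHEST_PROTOCOL)
        self.wfile.flush()  # push the buffered bytes into the socket

    def recv(self):
        return pickle.load(self.rfile)  # reads exactly one pickled object

    def close(self) -> None:
        self.rfile.close()
        self.wfile.close()
        self.sock.close()


def demo(objs):
    """Send objs from a thread over a local socket pair and return what arrives."""
    a, b = socket.socketpair()
    sender, receiver = PickleSocket(a), PickleSocket(b)
    t = threading.Thread(target=lambda: [sender.sendall(o) for o in objs])
    t.start()
    received = [receiver.recv() for _ in objs]
    t.join()
    sender.close()
    receiver.close()
    return received
```

Because nothing waits for an acknowledgement, `sendall` on a large object blocks until the peer has read enough of it, which is why the demo sends from a separate thread. Note also that `pickle.load` can execute arbitrary code embedded in the data, so it should only be used with a trusted peer.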
Upvotes: 0