user24934971

Reputation: 1

How to perform reliable large data transfer under TCP protocol through socket API?

Recently I have been working on a project on secure large-model inference in a client-server setting. The client and server have to communicate a lot to carry out the collaborative computation.

For example, the server frequently sends model parameters (i.e., matrices) to the client, which requires reliable transfer of large amounts of data over TCP sockets. Simply calling socket.socket.sendall() and socket.socket.recv() repeatedly is not suitable, because the receiver cannot tell where one message ends and the next begins (TCP provides a byte stream, not message boundaries).

I wrote this:

import socket
import pickle
import torch

class BetterSocket:
    def __init__(self, s):
        self.socket = s
        self.msg_len = 2 ** 12

    def sendall(self, obj):
        pkl = pickle.dumps(obj)
        l = len(pkl)
        init_l = len(pkl)
        lbs = l.to_bytes(4)
        print(f"With LBS = {lbs}")
        self.socket.sendall(lbs)
        print("Going to send...")
        while l > self.msg_len:
            self.socket.sendall(pkl[0:self.msg_len])
            pkl = pkl[self.msg_len:]
            l = l - self.msg_len
            self.socket.recv(len(b'0'))
        self.socket.sendall(pkl)
        print(f"Done. {4 + init_l} bytes have been sent.")
    
    def recv(self):
        print("Waiting to receive...")
        lbs = self.socket.recv(4)
        print(f"LBS {lbs} received.")
        l = int.from_bytes(lbs)
        print(f"Length {l} received.")
        
        pkl = b''
        while l > self.msg_len:
            pkl = pkl + self.socket.recv(self.msg_len)
            l = l - self.msg_len
            self.socket.sendall(b'0')
        pkl = pkl + self.socket.recv(l)
        obj = pickle.loads(pkl)
        return obj

This class can transfer objects of any picklable type. A 4-byte header is prepended to every "data unit" to indicate its total size. The "inverse confirmation" in the while loop seems to be necessary to balance the speed of sending and receiving; without it, _pickle.UnpicklingError: pickle data was truncated is thrown, which indicates that the receiver split the byte stream incorrectly. The same error also occurs when self.msg_len is set too large.

However, the above solution is too slow. Transferring a 50000 * 800 matrix takes about 5 minutes because the client and the server exchange "inverse confirmations" so frequently. How can I transfer the data correctly without sacrificing performance so badly? Please help me!

(Sorry about my English...)

Upvotes: 0

Views: 80

Answers (1)

Mark Tolonen

Reputation: 177461

The pickle format already encodes where a single pickled object ends, so no separate length header is needed. The socket.makefile method wraps a socket in a file-like object that can be used directly with pickle.dump and pickle.load, which makes reading and writing pickled objects easy.

Here is an example:

server.py

import socket
import pickle

with socket.socket() as s:
    s.bind(('', 5000))
    s.listen()
    while True:
        client, addr = s.accept()
        with client, client.makefile('rb') as rfile:
            while True:
                try:
                    obj = pickle.load(rfile)
                    print(f'{addr}: {obj}')
                except EOFError:  # raised by pickle.load when socket is closed
                    break

client.py

import socket
import pickle

def send_message(wfile, obj):
    wfile.write(pickle.dumps(obj))
    wfile.flush()  # ensures buffered writes are sent to socket.

with socket.socket() as s:
    s.connect(('localhost', 5000))
    with s.makefile('wb') as wfile:
        send_message(wfile, [1, 2, 3, 'abc', 'def'])
        send_message(wfile, [complex(1,2), complex(3,4)])
        send_message(wfile, dict(zip('abc def ghi'.split(), [123, 456, 789])))

Output of server after one run of client:

('127.0.0.1', 3010): [1, 2, 3, 'abc', 'def']
('127.0.0.1', 3010): [(1+2j), (3+4j)]
('127.0.0.1', 3010): {'abc': 123, 'def': 456, 'ghi': 789}
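
If you would rather keep an explicit length header as in the question, the likely cause of the "pickle data was truncated" error is that socket.recv(n) may return fewer than n bytes, so a single call is not guaranteed to deliver a whole chunk. Looping until the full count has arrived removes the need for the per-chunk confirmation. Below is a minimal sketch of that idea, not a tuned implementation. The names recv_exact, send_obj, recv_obj and demo are hypothetical helpers, and the 8-byte big-endian header is an assumed format (the question uses 4 bytes).

```python
import pickle
import socket
import struct
import threading

# Assumed header format: 8-byte unsigned length in network byte order.
HEADER = struct.Struct("!Q")

def recv_exact(sock, n):
    """Read exactly n bytes, looping because recv may return fewer."""
    buf = bytearray(n)
    view = memoryview(buf)
    got = 0
    while got < n:
        read = sock.recv_into(view[got:], n - got)
        if read == 0:
            raise ConnectionError("socket closed before the full message arrived")
        got += read
    return bytes(buf)

def send_obj(sock, obj):
    """Send one pickled object preceded by its length."""
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    sock.sendall(HEADER.pack(len(payload)))
    sock.sendall(payload)  # sendall already loops until everything is sent

def recv_obj(sock):
    """Receive one length-prefixed pickled object."""
    (length,) = HEADER.unpack(recv_exact(sock, HEADER.size))
    return pickle.loads(recv_exact(sock, length))

def demo():
    """Round-trip two objects over a local socket pair (illustration only)."""
    a, b = socket.socketpair()
    big = list(range(200_000))  # large enough to need many recv calls
    small = {"abc": 123}

    def sender():
        with a:
            send_obj(a, big)
            send_obj(a, small)

    t = threading.Thread(target=sender)
    t.start()
    with b:
        first = recv_obj(b)
        second = recv_obj(b)
    t.join()
    return first == big and second == small
```

With this framing the sender never waits for an acknowledgement; TCP's own flow control paces the two sides.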

Upvotes: 0
