DongJin Shin

Reputation: 29

Separate TcpStream + SslStream into read and write components

I'm trying to write a client program that communicates with a server using a TcpStream wrapped in an openssl::ssl::SslStream (from crates.io). It should wait for reads and process data sent from the server as soon as it is received. At the same time, it should be able to send messages to the server independently of reading.

I tried several approaches:

  1. Passing a single stream to both the read and write threads. Both read and write require a mutable reference, so I couldn't share a single stream between two threads.
  2. I followed "In Rust how do I handle parallel read writes on a TcpStream", but TcpStream doesn't seem to have a clone method (it only has try_clone), and SslStream has neither.
  3. I tried making a copy of the TcpStream with as_raw_fd and from_raw_fd:
/// Reader loop from the question: repeatedly reads up to 2048 bytes from the TLS stream.
///
/// The stream is taken by value, so the same `SslStream` cannot also be handed to a
/// writer thread — which is the ownership problem this question is about.
fn irc_read(mut stream: SslStream<TcpStream>) {
    loop {
        // A fresh 2048-byte zeroed buffer is allocated on every iteration.
        let mut buf = vec![0; 2048];
        // NOTE(review): `ssl_read` is the older openssl-crate API used in the question;
        // presumably it returns the byte count or an SSL error — confirm against that version.
        let resp = stream.ssl_read(&mut buf);
        match resp {
            // Process Message
            // (match arms elided in the original snippet)
        }
    }
}

/// Writer from the question: waits three seconds, then sends a single `QUIT\n` message.
///
/// Like `irc_read`, it takes ownership of its own `SslStream`.
fn irc_write(mut stream: SslStream<TcpStream>) {
    thread::sleep(Duration::new(3, 0));
    // NOTE(review): IRC (RFC 1459) specifies CR-LF line endings; many servers also
    // accept a bare LF — confirm for the target server.
    let msg = "QUIT\n";
    let res = stream.ssl_write(msg.as_bytes());
    // The result of flushing is discarded.
    let _ = stream.flush();
    match res {
        // Process
        // (match arms elided in the original snippet)
    }
}

/// The question's attempted split: two separate TLS sessions over one TCP socket.
///
/// A second `TcpStream` is built from the same raw fd as the first, and an independent
/// TLS handshake is then run over each one. Per the question, this panics on the second
/// `SslStream::connect`.
fn main() {
    let ctx = SslContext::new(SslMethod::Sslv23).unwrap();
    // Two independent SSL objects — each negotiates its own session.
    let read_ssl = Ssl::new(&ctx).unwrap();
    let write_ssl = Ssl::new(&ctx).unwrap();

    let raw_stream = TcpStream::connect((SERVER, PORT)).unwrap();
    let mut fd_stream: TcpStream;
    unsafe {
        // `from_raw_fd` takes ownership of the descriptor, but `raw_stream` still owns it
        // too: both `TcpStream`s will close the same fd when they are dropped.
        fd_stream = TcpStream::from_raw_fd(raw_stream.as_raw_fd());
    }
    let mut read_stream = SslStream::connect(read_ssl, raw_stream).unwrap();
    // The handshake above already established a TLS session on this socket, so this
    // second handshake presumably receives something other than a ServerHello — this is
    // the call the question reports failing with "unknown protocol".
    let mut write_stream = SslStream::connect(write_ssl, fd_stream).unwrap();

    let read_thread = thread::spawn(move || {
        irc_read(read_stream);
    });

    let write_thread = thread::spawn(move || {
        irc_write(write_stream);
    });

    // Join results (including thread panics) are ignored.
    let _ = read_thread.join();
    let _ = write_thread.join();
}

This code compiles, but panics on the second SslStream::connect:

thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Failure(Ssl(ErrorStack([Error { library: "SSL routines", function: "SSL23_GET_SERVER_HELLO", reason: "unknown protocol" }])))', ../src/libcore/result.rs:788
stack backtrace:
   1:     0x556d719c6069 - std::sys::backtrace::tracing::imp::write::h00e948915d1e4c72
   2:     0x556d719c9d3c - std::panicking::default_hook::_{{closure}}::h7b8a142818383fb8
   3:     0x556d719c8f89 - std::panicking::default_hook::h41cf296f654245d7
   4:     0x556d719c9678 - std::panicking::rust_panic_with_hook::h4cbd7ca63ce1aee9
   5:     0x556d719c94d2 - std::panicking::begin_panic::h93672d0313d5e8e9
   6:     0x556d719c9440 - std::panicking::begin_panic_fmt::hd0daa02942245d81
   7:     0x556d719c93c1 - rust_begin_unwind
   8:     0x556d719ffcbf - core::panicking::panic_fmt::hbfc935564d134c1b
   9:     0x556d71899f02 - core::result::unwrap_failed::h66f79b2edc69ddfd
                        at /buildslave/rust-buildbot/slave/stable-dist-rustc-linux/build/obj/../src/libcore/result.rs:29
  10:     0x556d718952cb - _<core..result..Result<T, E>>::unwrap::h49a140af593bc4fa
                        at /buildslave/rust-buildbot/slave/stable-dist-rustc-linux/build/obj/../src/libcore/result.rs:726
  11:     0x556d718a5e3d - dbrust::main::h24a50e631826915e
                        at /home/lastone817/dbrust/src/main.rs:87
  12:     0x556d719d1826 - __rust_maybe_catch_panic
  13:     0x556d719c8702 - std::rt::lang_start::h53bf99b0829cc03c
  14:     0x556d718a6b83 - main
  15:     0x7f40a0b5082f - __libc_start_main
  16:     0x556d7188d038 - _start
  17:                0x0 - <unknown>
error: Process didn't exit successfully: `target/debug/dbrust` (exit code: 101)

The best solution I've found so far is to use non-blocking I/O. I wrapped the stream in a Mutex and passed it to both threads. The reading thread acquires the lock and calls read; if there is no message, it releases the lock so that the writing thread can use the stream. With this method, the reading thread busy-waits, resulting in 100% CPU consumption, so I don't think this is the best solution.

Is there a safe way to separate the read and write aspects of the stream?

Upvotes: 2

Views: 1743

Answers (1)

Bernd

Reputation: 107

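I split an SSL stream into a read part and a write part by using Rust's std::cell::UnsafeCell. Be aware that this hands out two mutable references to the same stream at the same time, which violates Rust's aliasing rules, and the underlying TLS library may not support a read and a write running concurrently on one connection — treat it as a workaround, not a guaranteed-safe solution.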

extern crate native_tls;

use native_tls::TlsConnector;
use std::cell::UnsafeCell;
use std::error::Error;
use std::io::Read;
use std::io::Write;
use std::marker::Sync;
use std::net::TcpStream;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;

/// Shares one value between several owners while handing out `&mut T` from `&self`.
///
/// This is a deliberate escape hatch around Rust's aliasing rules: it is used to give
/// a reader and a writer independent mutable access to the same stream. Nothing in this
/// type prevents two `&mut T` from being live at the same time, so every user of
/// `mut_value` is responsible for ensuring their accesses never conflict in a way `T`
/// cannot tolerate.
struct UnsafeMutator<T> {
    value: UnsafeCell<T>,
}

impl<T> UnsafeMutator<T> {
    /// Wraps `value` so it can later be shared (typically through an `Arc`).
    fn new(value: T) -> UnsafeMutator<T> {
        UnsafeMutator {
            value: UnsafeCell::new(value),
        }
    }

    /// Returns a mutable reference to the wrapped value.
    ///
    /// NOTE(review): this safe method can produce aliasing `&mut T` from a shared
    /// reference, which is unsound in general; it is kept for interface compatibility.
    /// Callers must never hold two overlapping references obtained from it.
    fn mut_value(&self) -> &mut T {
        // SAFETY: the pointer from `UnsafeCell::get` is valid and aligned for as long as
        // `self` is alive, which bounds the returned lifetime. Exclusivity is NOT
        // guaranteed here — it is the caller's obligation (see the note above).
        unsafe { &mut *self.value.get() }
    }
}

// SAFETY (partial): sharing `&UnsafeMutator<T>` across threads lets another thread obtain
// `&mut T`, which effectively moves access to `T` between threads, so `T` must be `Send`
// — the same requirement `Mutex<T>: Sync` has. Without this bound, a non-`Send` value
// such as `Rc` could be mutated from a foreign thread. Exclusive access is still not
// enforced by this type (see `mut_value`).
unsafe impl<T: Send> Sync for UnsafeMutator<T> {}

/// Read half of a stream that is shared through an [`UnsafeMutator`].
///
/// Each `read` call forwards directly to the wrapped value's `Read::read`. The `Read`
/// bound lives on the impl rather than on the struct, so the struct itself places no
/// requirement on `R`.
struct ReadWrapper<R> {
    inner: Arc<UnsafeMutator<R>>,
}

impl<R: Read> Read for ReadWrapper<R> {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        // Obtains `&mut R` through the shared `UnsafeMutator`; a writer holding a clone of
        // the same `Arc` may be using the underlying stream at the same time.
        self.inner.mut_value().read(buf)
    }
}
/// Write half of a stream that is shared through an [`UnsafeMutator`].
///
/// `write` and `flush` forward directly to the wrapped value. The `Write` bound lives on
/// the impl rather than on the struct, so the struct itself places no requirement on `W`.
struct WriteWrapper<W> {
    inner: Arc<UnsafeMutator<W>>,
}

impl<W: Write> Write for WriteWrapper<W> {
    fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
        // Obtains `&mut W` through the shared `UnsafeMutator`; a reader holding a clone of
        // the same `Arc` may be using the underlying stream at the same time.
        self.inner.mut_value().write(buf)
    }

    fn flush(&mut self) -> Result<(), std::io::Error> {
        self.inner.mut_value().flush()
    }
}

/// A connection split into separately lockable output (write) and input (read) halves.
///
/// Both halves are type-erased trait objects, so plain TCP and TLS connections share one
/// type. The two halves use two independent `Mutex`es: locking `output_stream` does not
/// block users of `input_stream`, and vice versa.
pub struct Socket {
    /// Write half of the connection.
    pub output_stream: Arc<Mutex<dyn Write + Send>>,
    /// Read half of the connection.
    pub input_stream: Arc<Mutex<dyn Read + Send>>,
}

impl Socket {
    pub fn bind(host: &str, port: u16, secure: bool) -> Result<Socket, Box<Error>> {
        let tcp_stream = match TcpStream::connect((host, port)) {
            Ok(x) => x,
            Err(e) => return Err(Box::new(e)),
        };
        if secure {
            let tls_connector = TlsConnector::builder().build().unwrap();
            let tls_stream = match tls_connector.connect(host, tcp_stream) {
                Ok(x) => x,
                Err(e) => return Err(Box::new(e)),
            };
            let mutator = Arc::new(UnsafeMutator::new(tls_stream));
            let input_stream = Arc::new(Mutex::new(ReadWrapper {
                inner: mutator.clone(),
            }));
            let output_stream = Arc::new(Mutex::new(WriteWrapper { inner: mutator }));

            let socket = Socket {
                output_stream,
                input_stream,
            };
            return Ok(socket);
        } else {
            let mutator = Arc::new(UnsafeMutator::new(tcp_stream));
            let input_stream = Arc::new(Mutex::new(ReadWrapper {
                inner: mutator.clone(),
            }));
            let output_stream = Arc::new(Mutex::new(WriteWrapper { inner: mutator }));

            let socket = Socket {
                output_stream,
                input_stream,
            };
            return Ok(socket);
        }
    }
}

fn main() {
    let socket = Arc::new(Socket::bind("google.com", 443, true).unwrap());

    let socket_clone = Arc::clone(&socket);

    let reader_thread = thread::spawn(move || {
        let mut res = vec![];
        let mut input_stream = socket_clone.input_stream.lock().unwrap();
        input_stream.read_to_end(&mut res).unwrap();
        println!("{}", String::from_utf8_lossy(&res));
    });

    let writer_thread = thread::spawn(move || {
        let mut output_stream = socket.output_stream.lock().unwrap();
        output_stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
    });

    writer_thread.join().unwrap();
    reader_thread.join().unwrap();
}

Upvotes: 1

Related Questions