Unnikrishnan R
Unnikrishnan R

Reputation: 23

How to call a Rust function from a Python program

I'm working on an XOR-cipher app to benchmark equivalent Python and Rust functions, but I don't know how to call the Rust function from my Python code. Can anyone help? I need the output to be saved in `buf`. Here is my Rust code:

use std::convert::TryInto;

/*
    Apply a simple XOR cipher using the specified `key` of size
    `key_len` to the `msg` char/byte array of size `msg_len`,
    repeating the key cyclically: buf[i] = msg[i] ^ key[i % key_len].

    Writes the ciphertext to the externally allocated buffer `buf`.

    Declared `extern "C"` so that it uses the C calling convention
    expected by foreign callers such as Python `ctypes` (a plain
    `pub fn` uses the unspecified Rust ABI).

    Safety: `msg` and `buf` must be valid for `msg_len` elements and
    `key` must be valid for `key_len` elements. If `key_len` is 0 the
    function returns without writing to `buf`, instead of panicking on
    `% 0` (a panic must not unwind across the FFI boundary).
*/

#[no_mangle]
pub unsafe extern "C" fn cipher(msg: *const i8, key: *const i8, buf: *mut i8, msg_len: usize, key_len: usize)
{
    if key_len == 0 {
        return;
    }
    for i in 0..msg_len {
        // Wrap the *index* into the key, not the key byte: taking
        // `*key.offset(i) % key_len` read past the end of `key` whenever
        // `msg_len > key_len`, and converting `key_len` to `i8` panicked
        // for keys longer than 127 bytes.
        *buf.add(i) = *msg.add(i) ^ *key.add(i % key_len);
    }
}

Upvotes: 1

Views: 1296

Answers (1)

prog-fh
prog-fh

Reputation: 16920

Let's say your rust code is named rs_cipher.rs. If you compile it with this command

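rustc --crate-type cdylib rs_cipher.rs

you obtain a native dynamic library for your system (`librs_cipher.so` on Linux, `librs_cipher.dylib` on macOS, `rs_cipher.dll` on Windows). `cdylib` is the crate type intended for libraries consumed from other languages; a plain `dylib` is a Rust-specific dynamic library that may depend on the Rust standard library's own shared object at load time. (Of course, this can also be done with Cargo and the `crate-type = ["cdylib"]` option in the `[lib]` section.)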

Here is Python code that loads this library, finds the function and calls it (see the comments for each of these stages).

import sys
import os
import platform
import ctypes
import traceback

# choose application specific library name
lib_name='rs_cipher'

# determine system specific library name
if platform.system().startswith('Windows'):
  sys_lib_name='%s.dll'%lib_name
elif platform.system().startswith('Darwin'):
  sys_lib_name='lib%s.dylib'%lib_name
else:
  sys_lib_name='lib%s.so'%lib_name

# directory containing this script; __file__ is undefined in an
# interactive session, so fall back to the current directory there
script_file=globals().get('__file__')
if script_file:
  script_dir=os.path.dirname(os.path.abspath(script_file))
else:
  script_dir=os.path.curdir

# try to load native library from various locations
# (current directory, python file directory, system path...);
# with '' os.path.join() yields the bare file name, leaving the
# search to the system's dynamic loader
ntv_lib=None
for d in [os.path.curdir, script_dir, '']:
  path=os.path.join(d, sys_lib_name)
  try:
    ntv_lib=ctypes.CDLL(path)
    break
  except OSError:
    # not found or not loadable here: try the next location
    pass
if ntv_lib is None:
  sys.stderr.write('cannot load library %s\n'%lib_name)
  sys.exit(1)

# try to find native function
fnct_name='cipher'
try:
  ntv_fnct=ntv_lib[fnct_name]
except AttributeError:
  # ctypes raises AttributeError when the symbol is not exported
  sys.stderr.write('cannot find function %s in library %s\n'%
                   (fnct_name, lib_name))
  sys.exit(1)

# describe native function prototype
ntv_fnct.restype=None # no return value
ntv_fnct.argtypes=[ctypes.c_void_p, # msg
                   ctypes.c_void_p, # key
                   ctypes.c_void_p, # buf
                   ctypes.c_size_t, # msg_len
                   ctypes.c_size_t] # key_len

def dump_arrays(msg_arr, key_arr, buf_arr):
  """Write the current contents of msg, key and buf to stdout, one line each."""
  for name, arr in (('msg', msg_arr), ('key', key_arr), ('buf', buf_arr)):
    sys.stdout.write('%s: %s\n'%(name, list(arr)))

# use native function
msg=(ctypes.c_int8*10)()
key=(ctypes.c_int8*4)()
buf=(ctypes.c_int8*len(msg))()
for i in range(len(msg)):
  msg[i]=1+10*i
for i in range(len(key)):
  # ~x == -x-1, so the key bytes are negative (-41, -61, ...)
  key[i]=~(20*(i+2))

sys.stdout.write('~~~~ first initial state ~~~~\n')
dump_arrays(msg, key, buf)
sys.stdout.write('~~~~ first call ~~~~\n')
ntv_fnct(msg, key, buf, len(msg), len(key))
dump_arrays(msg, key, buf)

# XOR with the same key is its own inverse: feed the ciphertext back in
# as the message and decrypt it into a zeroed buffer
(msg, buf)=(buf, msg)
for i in range(len(buf)):
  buf[i]=0
sys.stdout.write('~~~~ second initial state ~~~~\n')
dump_arrays(msg, key, buf)
sys.stdout.write('~~~~ second call ~~~~\n')
ntv_fnct(msg, key, buf, len(msg), len(key))
dump_arrays(msg, key, buf)

Running this python code shows this result

~~~~ first initial state ~~~~
msg: [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]
key: [-41, -61, -81, -101]
buf: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
~~~~ first call ~~~~
msg: [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]
key: [-41, -61, -81, -101]
buf: [-42, -56, -70, -124, -2, -16, -110, -36, -122, -104]
~~~~ second initial state ~~~~
msg: [-42, -56, -70, -124, -2, -16, -110, -36, -122, -104]
key: [-41, -61, -81, -101]
buf: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
~~~~ second call ~~~~
msg: [-42, -56, -70, -124, -2, -16, -110, -36, -122, -104]
key: [-41, -61, -81, -101]
buf: [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]

Note that I replaced these two lines in your rust code

let key_len_i8: i8 = key_len.try_into().unwrap();
*buf.offset(i) = *msg.offset(i) ^ (*key.offset(i) % key_len_i8);

by these

let key_len_isize: isize = key_len.try_into().unwrap();
*buf.offset(i) = *msg.offset(i) ^ (*key.offset(i % key_len_isize));

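because the offset on `key` was going out of bounds. The `% key_len` was applied to the key byte `*key.offset(i)` instead of to the index `i`. As a result, as soon as `msg_len > key_len`, the code read past the end of `key`.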

Alternatively, instead of dealing with pointer arithmetic and unsigned/signed conversions, you could build ordinary slices from the raw pointers once and then work with them safely:

/// Applies a repeating-key XOR cipher: `buf[i] = msg[i] ^ key[i % key_len]`
/// for every `i` in `0..msg_len`.
///
/// Exported unmangled with the C calling convention (`extern "C"`) so it can be
/// loaded and called through Python `ctypes`; without `extern "C"` the function
/// would use the unspecified Rust ABI.
///
/// Caller contract (cannot be checked here):
/// - `msg_ptr` and `buf_ptr` must each point to at least `msg_len` valid `i8`s,
///   and `key_ptr` to at least `key_len` valid `i8`s;
/// - the `buf` region must not overlap `msg` or `key`.
///
/// If `msg_len` or `key_len` is zero, or any pointer is null, the function
/// returns without touching `buf`. This avoids a panic on `% 0` (a panic must not
/// unwind across the FFI boundary). It also avoids building a slice from a null
/// pointer, which `from_raw_parts` forbids even for length 0.
///
/// NOTE(review): this dereferences caller-supplied raw pointers yet is declared
/// as a safe `fn` (kept so existing Rust callers still compile); declaring it
/// `pub unsafe extern "C" fn` would state the contract in the type system.
#[no_mangle]
pub extern "C" fn cipher(
    msg_ptr: *const i8,
    key_ptr: *const i8,
    buf_ptr: *mut i8,
    msg_len: usize,
    key_len: usize,
) {
    if msg_len == 0
        || key_len == 0
        || msg_ptr.is_null()
        || key_ptr.is_null()
        || buf_ptr.is_null()
    {
        return;
    }
    // SAFETY: all three pointers are non-null (checked above); validity for the
    // given lengths and non-overlap of `buf` are the caller's obligations, as
    // documented on this function.
    let msg = unsafe { std::slice::from_raw_parts(msg_ptr, msg_len) };
    let key = unsafe { std::slice::from_raw_parts(key_ptr, key_len) };
    let buf = unsafe { std::slice::from_raw_parts_mut(buf_ptr, msg_len) };
    // `cycle()` repeats the key, which is equivalent to indexing with `i % key_len`.
    for ((out, &m), &k) in buf.iter_mut().zip(msg).zip(key.iter().cycle()) {
        *out = m ^ k;
    }
}

Upvotes: 3

Related Questions