Reputation: 23
I'm working on an XOR-cipher app to benchmark a Python implementation against a Rust one, but I don't know how to call the Rust function from Python code. Can anyone help? I need to save the output in `buf`. I'm including my Rust code:
use std::convert::TryInto;
/*
Apply a simple XOR cipher using the specified `key` of size
`key_len`, to the `msg` char/byte array of size `msg_len`.
Writes the ciphertext to the externally allocated buffer `buf`.

The function is `unsafe`: it dereferences the three raw pointers at
every offset in `0..msg_len` without any validation.

Panics if `msg_len` does not fit in an `isize`, if `key_len` does not
fit in an `i8` (i.e. is greater than 127), or if `key_len` is 0
(remainder by zero).
*/
#[no_mangle]
pub unsafe fn cipher(msg: *const i8, key: *const i8, buf: *mut i8, msg_len: usize, key_len: usize)
{
// `offset` takes an `isize`, hence the signed loop counter.
let mut i: isize = 0;
while i < msg_len.try_into().unwrap() {
// Narrowed to `i8` so it can be used as the right-hand side of `%` on a key byte.
let key_len_i8: i8 = key_len.try_into().unwrap();
// NOTE(review): the modulo is applied to the key *byte*, not to the index:
// `key` is read at offset `i` for every `i < msg_len`, so this reads past
// the end of `key` whenever `msg_len > key_len`, and the byte that gets
// XOR-ed is `key[i] % key_len` rather than `key[i % key_len]`.
*buf.offset(i) = *msg.offset(i) ^ (*key.offset(i) % key_len_i8);
i = i + 1;
}
}
Upvotes: 1
Views: 1296
Reputation: 16920
Let's say your Rust code is in a file named rs_cipher.rs.
If you compile it with this command
rustc --crate-type dylib rs_cipher.rs
you obtain a native dynamic library for your system.
(Of course, this can also be done with cargo by setting the crate-type = ["cdylib"] option in the [lib] section of Cargo.toml.)
Here is some Python code that loads this library, finds the function, and calls it (see the comments for these different stages).
import sys
import os
import platform
import ctypes
import traceback
# Demo driver: load the native `rs_cipher` library with ctypes, look up its
# `cipher` function, and call it twice to show that XOR-ing the ciphertext
# with the same key gives the original message back.

# choose application specific library name
lib_name = 'rs_cipher'

# determine system specific library name
system_name = platform.system()
if system_name.startswith('Windows'):
    sys_lib_name = '%s.dll' % lib_name
elif system_name.startswith('Darwin'):
    sys_lib_name = 'lib%s.dylib' % lib_name
else:
    sys_lib_name = 'lib%s.so' % lib_name

# try to load native library from various locations
# (current directory, python file directory, system path...)
ntv_lib = None
load_errors = []
for d in [os.path.curdir,
          os.path.dirname(sys.modules[__name__].__file__),
          '']:
    path = os.path.join(d, sys_lib_name)
    try:
        ntv_lib = ctypes.CDLL(path)
        break
    except OSError as e:
        # Only loader failures are expected here; a bare `except` would also
        # swallow KeyboardInterrupt/SystemExit.  Keep the reason so that it
        # can be reported if every location fails.
        load_errors.append('%s: %s' % (path, e))
if ntv_lib is None:
    sys.stderr.write('cannot load library %s\n' % lib_name)
    for err in load_errors:
        sys.stderr.write('  %s\n' % err)
    sys.exit(1)

# try to find native function
fnct_name = 'cipher'
ntv_fnct = None
try:
    ntv_fnct = ntv_lib[fnct_name]
except AttributeError:
    # ctypes raises AttributeError when the symbol is not exported
    pass
if ntv_fnct is None:
    sys.stderr.write('cannot find function %s in library %s\n' %
                     (fnct_name, lib_name))
    sys.exit(1)

# describe native function prototype
ntv_fnct.restype = None  # no return value
ntv_fnct.argtypes = [ctypes.c_void_p,  # msg
                     ctypes.c_void_p,  # key
                     ctypes.c_void_p,  # buf
                     ctypes.c_size_t,  # msg_len
                     ctypes.c_size_t]  # key_len


def show_arrays(msg, key, buf):
    """Write the content of the three ctypes arrays to stdout, one per line."""
    for name, values in (('msg', msg), ('key', key), ('buf', buf)):
        sys.stdout.write('%s: %s\n' % (name, list(values)))


def run_round(label, msg, key, buf):
    """Show the arrays, call the native function on them, show them again."""
    sys.stdout.write('~~~~ %s initial state ~~~~\n' % label)
    show_arrays(msg, key, buf)
    sys.stdout.write('~~~~ %s call ~~~~\n' % label)
    ntv_fnct(msg, key, buf, len(msg), len(key))
    show_arrays(msg, key, buf)


# use native function
msg = (ctypes.c_int8 * 10)()
key = (ctypes.c_int8 * 4)()
buf = (ctypes.c_int8 * len(msg))()  # output buffer, same size as the message
for i in range(len(msg)):
    msg[i] = 1 + 10 * i
for i in range(len(key)):
    key[i] = ~(20 * (i + 2))

# first round: msg -> ciphertext in buf
run_round('first', msg, key, buf)

# second round: feed the ciphertext back in as the message, with a cleared
# output buffer; the original message is expected to come out again
(msg, buf) = (buf, msg)
for i in range(len(buf)):
    buf[i] = 0
run_round('second', msg, key, buf)
Running this python code shows this result
~~~~ first initial state ~~~~
msg: [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]
key: [-41, -61, -81, -101]
buf: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
~~~~ first call ~~~~
msg: [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]
key: [-41, -61, -81, -101]
buf: [-42, -56, -70, -124, -2, -16, -110, -36, -122, -104]
~~~~ second initial state ~~~~
msg: [-42, -56, -70, -124, -2, -16, -110, -36, -122, -104]
key: [-41, -61, -81, -101]
buf: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
~~~~ second call ~~~~
msg: [-42, -56, -70, -124, -2, -16, -110, -36, -122, -104]
key: [-41, -61, -81, -101]
buf: [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]
Note that I replaced these two lines in your rust code
let key_len_i8: i8 = key_len.try_into().unwrap();
*buf.offset(i) = *msg.offset(i) ^ (*key.offset(i) % key_len_i8);
by these
let key_len_isize: isize = key_len.try_into().unwrap();
*buf.offset(i) = *msg.offset(i) ^ (*key.offset(i % key_len_isize));
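because the offset on key was going out of bounds: the modulo has to be applied to the index i (so that it wraps around inside the key), not to the key byte that is read.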
Maybe, instead of dealing with pointer arithmetic and signed/unsigned conversions, you should build ordinary slices from the raw pointers and use them in the safe way?
/// Applies a repeating-key XOR cipher to a message.
///
/// Reads `msg_len` bytes from `msg_ptr`, XORs byte `i` with key byte
/// `i % key_len`, and writes the `msg_len` resulting bytes to `buf_ptr`.
/// Applying the function again to its own output with the same key gives
/// the original message back.
///
/// The function uses the C ABI (`extern "C"`) and an unmangled name so that
/// it can be looked up and called from a foreign-function interface.
///
/// Degenerate inputs are handled without panicking, since a panic must not
/// unwind into a foreign caller:
/// * if `msg_len` is 0, or `msg_ptr` or `buf_ptr` is null, nothing is written;
/// * if `key_len` is 0 or `key_ptr` is null, the message is copied to the
///   buffer unchanged (XOR with an empty key).
///
/// # Safety
///
/// When non-null, `msg_ptr` must be valid for reading `msg_len` bytes,
/// `key_ptr` must be valid for reading `key_len` bytes, and `buf_ptr` must be
/// valid for writing `msg_len` bytes. The buffer must not overlap the
/// message or the key.
#[no_mangle]
pub unsafe extern "C" fn cipher(
    msg_ptr: *const i8,
    key_ptr: *const i8,
    buf_ptr: *mut i8,
    msg_len: usize,
    key_len: usize,
) {
    // `from_raw_parts` requires non-null pointers even for empty slices, so
    // bail out before building any slice from a null or zero-length input.
    if msg_len == 0 || msg_ptr.is_null() || buf_ptr.is_null() {
        return;
    }

    // SAFETY: both pointers are non-null and the caller guarantees that they
    // reference `msg_len` readable (resp. writable) bytes that do not overlap.
    let msg = unsafe { std::slice::from_raw_parts(msg_ptr, msg_len) };
    let buf = unsafe { std::slice::from_raw_parts_mut(buf_ptr, msg_len) };

    if key_len == 0 || key_ptr.is_null() {
        // An empty key would make `i % key_len` a division by zero.
        buf.copy_from_slice(msg);
        return;
    }

    // SAFETY: `key_ptr` is non-null and the caller guarantees that it
    // references `key_len` readable bytes.
    let key = unsafe { std::slice::from_raw_parts(key_ptr, key_len) };

    // `cycle()` repeats the key as many times as needed to cover the message.
    for ((out, &m), &k) in buf.iter_mut().zip(msg).zip(key.iter().cycle()) {
        *out = m ^ k;
    }
}
Upvotes: 3