Reputation:
I am implementing an RSA Encrypted Socket Connection in Java, for doing that i use two classes the first is the Connection Abstract class which represents the real Socket Connection and the Second is the ConnectionCallback which is a class called when the Connection class receives data. When data is received by the Connection class, the data gets Decrypted using a before shared public key coming from the connected endpoint (There can only be 1 connected endpoint).
ByteArray class:
package connection.data;
public class ByteArray {
private byte[] bytes;
public ByteArray(byte[] bytes){
this.bytes = bytes;
}
public ByteArray(){
}
public void add(byte[] data) {
if(this.bytes == null) this.bytes = new byte[0];
this.bytes = joinArrays(this.bytes, data);
}
private byte[] joinArrays(byte[] array1, byte[] array2) {
byte[] array = new byte[array1.length + array2.length];
System.arraycopy(array1, 0, array, 0, array1.length);
System.arraycopy(array2, 0, array, array1.length, array2.length);
return array;
}
public byte[] getBytes(){
return this.bytes;
}
}
Connection class:
package connection;
import connection.data.ByteArray;
import connection.protocols.ProtectedConnectionProtocol;
import crypto.CryptoUtils;
import crypto.algorithm.asymmetric.rsa.RSAAlgorithm;
import protocol.connection.ConnectionProtocol;
import util.function.Callback;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.PublicKey;
import java.util.Base64;
public abstract class Connection implements Runnable {
private DataInputStream in;
private DataOutputStream out;
ConnectionProtocol protocol;
private Callback callback;
private boolean isConnected = false;
public Connection() throws Exception {
this.protocol = new ProtectedConnectionProtocol(new RSAAlgorithm(1024));
this.callback = new ConnectionCallback(this);
}
public Connection(ConnectionProtocol connectionProtocol, Callback callback) throws Exception {
this.protocol = connectionProtocol;
this.callback = callback;
}
@Override
public void run() {
while(isConnected){
try {
ByteArray data = new ByteArray();
while(this.in.available() > 0){
data.add(this.read());
}
if(data.getBytes() != null){
callback.run(data);
}
} catch (Exception e){
e.printStackTrace();
break;
}
}
}
protected void openConnection(InputStream in, OutputStream out) throws Exception{
this.in = new DataInputStream(in);
this.out = new DataOutputStream(out);
this.isConnected = true;
new Thread(this).start();
this.write(CryptoUtils.encode(((PublicKey) this.protocol.getPublicKey()).getEncoded()));
}
private void write(byte[] data) throws Exception{
System.out.println(new String(data,"UTF-8"));
this.out.write(data);
this.out.flush();
}
private byte[] read() throws Exception{
byte[] bytes = new byte[8192];
int read = this.in.read(bytes);
if (read <= 0) return new byte[0]; // or return null, or something, read might be -1 when there was no data.
byte[] readBytes = new byte[read];
System.arraycopy(bytes, 0, readBytes, 0, read);
return bytes;
}
}
ConnectionCallback class:
package connection;
import connection.data.ByteArray;
import crypto.CryptoUtils;
import util.function.Callback;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;
public class ConnectionCallback implements Callback {
private Connection connection;
public ConnectionCallback(Connection connection){
this.connection = connection;
}
@Override
public void run(Object data) throws Exception {
ByteArray bytes = (ByteArray) data;
byte[] dataToBytes = CryptoUtils.decode(bytes.getBytes());
if(this.connection.protocol.getSharedKey() == null){
X509EncodedKeySpec spec = new X509EncodedKeySpec(dataToBytes);
KeyFactory kf = KeyFactory.getInstance("RSA");
PublicKey publicKey = kf.generatePublic(spec);
this.connection.protocol.setSharedKey(publicKey);
} else {
//this.so = StrongboxObject.parse(new String(bytes.getBytes()));
}
}
}
RSAlgorithm class:
package crypto.algorithm.asymmetric.rsa;
import crypto.CryptoUtils;
import crypto.algorithm.asymmetric.AssimetricalAlgorithm;
import javax.crypto.Cipher;
import java.security.*;
import java.util.Base64;
public class RSAAlgorithm extends AssimetricalAlgorithm {
private KeyPairGenerator keyGen;
public RSAAlgorithm(int keyLength) throws Exception {
super();
this.keyGen = KeyPairGenerator.getInstance("RSA");
this.keyGen.initialize(keyLength);
this.generateKeys();
}
@Override
public void generateKeys() {
KeyPair pair = this.keyGen.generateKeyPair();
super.setPublicKey(pair.getPublic());
super.setPrivateKey(pair.getPrivate());
}
@Override
public byte[] encrypt(byte[] message) {
try {
super.cipher.init(Cipher.ENCRYPT_MODE, (PublicKey) super.getSharedKey());
return CryptoUtils.encode(super.cipher.doFinal(message));
} catch (Exception e) {
e.printStackTrace();
}
return new byte[0];
}
@Override
public byte[] decrypt(byte[] message) {
message = CryptoUtils.decode(message);
try {
super.cipher.init(Cipher.DECRYPT_MODE, (PrivateKey) super.getPrivateKey());
return super.cipher.doFinal(message);
} catch (Exception e) {
e.printStackTrace();
}
return new byte[0];
}
}
ProtectedConnectionProtocol class:
package connection.protocols;
import protocol.connection.ConnectionProtocol;
import crypto.algorithm.asymmetric.AssimetricalAlgorithm;
public class ProtectedConnectionProtocol extends ConnectionProtocol {
private AssimetricalAlgorithm algorithm;
public ProtectedConnectionProtocol(AssimetricalAlgorithm algorithm){
this.algorithm = algorithm;
}
@Override
public Object getPublicKey() {
return this.algorithm.getPublicKey();
}
@Override
public Object getPrivateKey() {
return this.algorithm.getPrivateKey();
}
@Override
public Object getSharedKey() {
return this.algorithm.getSharedKey();
}
@Override
public void setSharedKey(Object sharedKey){
this.algorithm.setSharedKey(sharedKey);
}
@Override
public byte[] decrypt(byte[] message) {
return this.algorithm.decrypt(message);
}
@Override
public byte[] encrypt(byte[] message) {
return this.algorithm.encrypt(message);
}
}
CryptoUtils class:
package crypto;
import java.util.Base64;
public class CryptoUtils {
public static byte[] encode(byte[] data){
return Base64.getEncoder().encode(data);
}
public static byte[] decode(byte[] data){
return Base64.getDecoder().decode(data);
}
}
UPDATE of 05/09/2019:
Code update same Exception:
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCcrbJGHqpJdhDbVoZCJ0bucb8YnvcVWx9HIUfJOgmAKIuTmw1VUCk85ztqDq0VP2k6IP2bSD5MegR10FtqGtGEQrv+m0eNgbvE3O7czUzvedb5wKbA8eiSPbcX8JElobOhrolOb8JQRQzWAschBNp4MDljlu+0KZQHtZa6pPYJ0wIDAQAB
java.lang.IllegalArgumentException: Illegal base64 character 0
at java.base/java.util.Base64$Decoder.decode0(Base64.java:743)
at java.base/java.util.Base64$Decoder.decode(Base64.java:535)
at crypto.CryptoUtils.decode(CryptoUtils.java:12)
at connection.ConnectionCallback.run(ConnectionCallback.java:21)
at connection.Connection.run(Connection.java:42)
at java.base/java.lang.Thread.run(Thread.java:834)
Please help me i am exasperated with this and have only 2 more days of Bounty, i prefer to give my Bounty to someone who helped me finding the solution to this problem than to lose it.
Upvotes: 2
Views: 645
Reputation: 3675
This is probably caused by your read method:
private byte[] read() throws Exception{
byte[] bytes = new byte[8192];
this.in.read(bytes);
return bytes;
}
You are always reading into array of 8192 bytes, even if there isn't enough bytes in input stream. this.in.read(bytes)
returns amount of bytes read, you should use that value and only use that amount of bytes from this array, ignoring the rest - as rest of array will be just 0
, so when you try to decode base64 from it you will get java.lang.IllegalArgumentException: Illegal base64 character 0
So when reading your bytes you can just copy them to new array:
private byte[] read() throws Exception{
byte[] bytes = new byte[8192];
int read = this.in.read(bytes);
if (read <= 0) return new byte[0]; // or return null, or something, read might be -1 when there was no data.
byte[] readBytes = new byte[read]
System.arraycopy(bytes, 0, readBytes, 0, read)
return readBytes;
}
Note that reading like that is actually pretty bad idea for performance, as you are allocating a lot of stuff for each read. More advanced libraries like netty have own byte buffers with separate read/write positions and just store everything in single self-resizing array of bytes, but first make it work, and if you will have any issues with performance then remember that this is one of places you might find a solution.
Also in your ByteArray you are coping both arrays into same spot:
for(int i = 0; i < this.bytes.length; i++){
bytes1[i] = this.bytes[i];
}
for(int i = 0; i < data.length; i++){
bytes1[i] = data[i]; // this loop starts from 0 too
}
you need to use i + this.bytes.length
in second one. (and it's better to use System.arrayCopy)
public byte[] joinArrays(byte[] array1, byte[] array2) {
byte[] array = new byte[array1.length + array2.length];
System.arraycopy(array1, 0, array, 0, array1.length);
System.arraycopy(array2, 0, array, array1.length, array2.length);
return array;
}
And then just:
public void add(byte[] data) {
if(this.bytes == null) this.bytes = new byte[0];
this.bytes = joinArrays(this.bytes, data);
}
Also like in that other answer - it might be good idea to change flush method to just set field to null, or even better, just remove that method as I don't see it being used, and you could just create new instance of this object anyways.
Upvotes: 1
Reputation: 2358
I looked into your code and figured out that the problem is with the add()
method in the ByteArray
class. Let me show you, (See the comments)
Original : ByteArray
public void add(byte[] data){
if(this.bytes == null)
this.bytes = new byte[data.length];
byte[] bytes1 = new byte[this.bytes.length + data.length];
for(int i = 0; i < this.bytes.length; i++){
bytes1[i] = this.bytes[i]; // when this.bytes is null you are adding data.length amount of 0, which is not something you want i guess. This prevents the base64 decoder to decode
}
for(int i = 0; i < data.length; i++){
bytes1[i] = data[i];
}
this.bytes = bytes1;
}
Solution: ByteArray
public void add(byte[] data){
if(this.bytes == null) {
this.bytes = data; // just store it because the field is null
} else {
byte[] bytes1 = new byte[this.bytes.length + data.length];
for (int i = 0; i < this.bytes.length; i++) {
bytes1[i] = this.bytes[i];
}
for (int i = 0; i < data.length; i++) {
bytes1[i] = data[i];
}
this.bytes = bytes1;
}
}
public void flush(){
this.bytes = null; // Important
}
EDIT
After observing the codes that reads bytes in Connection class I found that it's reading unnecessary 0 bytes at the end. So I come up with the following workaround,
Refactor: Connection
...
public abstract class Connection implements Runnable {
...
@Override
public void run() {
while(isConnected){
try {
ByteArray data = new ByteArray();
while(this.in.available() > 0){
byte[] read = this.read();
if (read != null) {
data.add(read);
}
}
if(data.getBytes() != null){
callback.run(data);
}
} catch (Exception e){
e.printStackTrace();
break;
}
}
}
...
private byte[] read() throws Exception{
byte[] bytes = new byte[this.in.available()];
int read = this.in.read(bytes);
if (read <= 0) return null; // or return null, or something, read might be -1 when there was no data.
return bytes; // just returning the read bytes is fine. you don't need to copy.
}
}
Upvotes: 1