isaac6

Ciphertext stealing in Java

I have a school project that I'm having some trouble with. We were given the encrypt method below and asked to write the matching decrypt method.

public static byte[] encrypt(byte[] plaintext, BlockCipher64 cipher, long IV) {
    if(plaintext.length <= 8)
        throw new IllegalArgumentException("plaintext must be longer than 8 bytes!");

    byte[] ciphertext = new byte[plaintext.length];
    int blocks = plaintext.length / 8;
    if(plaintext.length % 8 != 0) ++blocks;

    long prev = IV;
    for(int block = 0; block < blocks; ++block) {
        prev = cipher.encrypt(prev ^ longAt(plaintext, block * 8));
        storeLongAt(ciphertext, prev, block * 8);
    }

    // copy penultimate to last, then prev to penultimate (ciphertext stealing)
    int lastBlock = (blocks - 1) * 8;
    int secondLastBlock = (blocks - 2) * 8;
    storeLongAt(ciphertext, longAt(ciphertext, secondLastBlock), lastBlock);
    storeLongAt(ciphertext, prev, secondLastBlock);

    return ciphertext;
}
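Here is one reading of what encrypt computes, in block notation. It uses 8-byte blocks P_1 to P_n, where the last block P_n holds r bytes (1 ≤ r ≤ 8) and is zero-padded because longAt reads past the end of the array. E is cipher.encrypt, ‖ is concatenation, 0^{8-r} is 8-r zero bytes, and head_r takes the first r bytes:

$$
\begin{aligned}
C_0 &= IV \\
C_i &= E(C_{i-1} \oplus P_i), \qquad 1 \le i \le n-1 \\
C_n &= E\bigl(C_{n-1} \oplus (P_n \,\|\, 0^{8-r})\bigr) \\
\text{ciphertext} &= C_1 \,\|\, \cdots \,\|\, C_{n-2} \,\|\, C_n \,\|\, \mathrm{head}_r(C_{n-1})
\end{aligned}
$$

For n = 2 the C_1 to C_{n-2} prefix is empty. When the length is a multiple of 8, r = 8 and the last two ciphertext blocks are simply swapped. This appears to match the variant that NIST's SP 800-38A addendum calls CBC-CS3, which always swaps the final two blocks.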

I've gotten the decrypt method to work when the string length is exactly 16 (2 blocks). Other lengths made of full blocks (multiples of 8, such as 24 or 32) don't work. Partial blocks don't work either: if I encrypt a string of length 18 and decrypt it again, the last 2 characters are always wrong. Any help is much appreciated! Here is my decrypt method:

public static byte[] decrypt(byte[] ciphertext, BlockCipher64 cipher, long IV) {
    // code here
    // check for an illegal argument
    if(ciphertext.length <= 8)
        throw new IllegalArgumentException("ciphertext must be longer than 8 bytes!");

    // create a byte[] for the plaintext
    byte[] plaintext = new byte[ciphertext.length];

    // calculate how many blocks there are
    int blocks = plaintext.length / 8;
    if(plaintext.length % 8 != 0) ++blocks;

    // handle the last two blocks (which are special because of ciphertext stealing)
    int lastBlock = (blocks - 1) * 8;
    int secondLastBlock = (blocks - 2) * 8;
    long lBlock = longAt(ciphertext, lastBlock);
    int lBlockSize = ciphertext.length % 8;
    // if the last block is full, swap whole blocks; if it is partial, swap only its bytes
    if (lBlockSize != 0) {
        //get blocks to switch
        byte[] NLB = new byte[lBlockSize];
        for (int i=secondLastBlock, j=0; i<secondLastBlock + lBlockSize; i++, j++) {
            NLB[j] = ciphertext[i];
        }
        byte[] NSLB = new byte[lBlockSize];
        for (int i=lastBlock, j=0; i<lastBlock + lBlockSize; i++, j++) {
            NSLB[j] = ciphertext[i];
        }
        //build ciphertext
        for (int i=secondLastBlock, j=0; i<secondLastBlock + lBlockSize; i++, j++) {
            ciphertext[i] = NSLB[j];
        }
        for (int i=lastBlock, j=0; i<lastBlock + lBlockSize; i++, j++) {
            ciphertext[i] = NLB[j];
        }
    } else {
        storeLongAt(ciphertext, longAt(ciphertext, secondLastBlock), lastBlock);
        storeLongAt(ciphertext, lBlock, secondLastBlock);
    }

    // loop over all other blocks, decrypting and xor'ing, and saving the results
    long prev = IV;
    for(int block = 0; block < blocks; ++block) {
        prev = cipher.decrypt(prev ^ longAt(ciphertext, block * 8));
        storeLongAt(plaintext, prev, block * 8);
    }

    return plaintext;
}

If any other info is needed, let me know. I'm really stumped on this one.

EDIT: Here are the helper methods storeLongAt() and longAt():

public static void storeLongAt(byte[] b, long x, int pos) {
    byte[] c = BlockCipher64.longToBytes(x);
    for(int i = 0; i < 8 && (pos + i) < b.length; ++i) {
        b[pos + i] = c[i];
    }
}

public static long longAt(byte[] b, int pos) {
    return BlockCipher64.bytesToLong(Arrays.copyOfRange(b, pos, pos + 8));
}
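The scheme depends on two edge behaviours of these helpers. longAt zero-pads when it reads past the end of the array, because Arrays.copyOfRange fills missing elements with zero. storeLongAt silently drops bytes that would land past the end. As a result, encrypt zero-pads the final plaintext block, and the last ciphertext slot keeps only the first r bytes of whatever is stored there. The following is a minimal sketch that checks both behaviours. The class name HelperDemo is made up, and java.nio.ByteBuffer's big-endian getLong/putLong stand in for BlockCipher64.bytesToLong/longToBytes, which use the same byte order.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

public class HelperDemo {
    // Same behaviour as the question's longAt: big-endian read, zero-padded past the end.
    static long longAt(byte[] b, int pos) {
        return ByteBuffer.wrap(Arrays.copyOfRange(b, pos, pos + 8)).getLong();
    }

    // Same behaviour as the question's storeLongAt: big-endian write, truncated at the end.
    static void storeLongAt(byte[] b, long x, int pos) {
        byte[] c = ByteBuffer.allocate(8).putLong(x).array();
        for (int i = 0; i < 8 && pos + i < b.length; ++i) b[pos + i] = c[i];
    }

    public static void main(String[] args) {
        byte[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};  // one full block plus 2 bytes

        // Reading the partial block: the 6 missing bytes come back as zeros.
        System.out.printf("%016x%n", longAt(data, 8));  // 090a000000000000

        // Writing a full long into the partial block: only the first 2 bytes are kept.
        storeLongAt(data, 0x1122334455667788L, 8);
        System.out.println(Arrays.toString(data));      // [1, 2, 3, 4, 5, 6, 7, 8, 17, 34]
    }
}
```

These two behaviours explain why decryption cannot simply swap the stored bytes back. The bytes of the penultimate encrypted block that storeLongAt dropped are no longer anywhere in the ciphertext array.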

And the BlockCipher64 interface:

public interface BlockCipher64 {
    long encrypt(long block);
    long decrypt(long block);

    public static byte[] longToBytes(long l) {
        byte[] result = new byte[8];
        for (int i = 7; i >= 0; i--) {
            result[i] = (byte) l;
            l >>= 8;
        }
        return result;
    }

    public static long bytesToLong(byte[] b) {
        long result = 0;
        for (int i = 0; i < 8; i++) {
            result <<= 8;
            result |= ((long) b[i]) & 0xffL; // notice the L
        }
        return result;
    }
}

Answers (1)

Maarten Bodewes

You are swapping the ciphertext bytes instead of the plaintext bytes. Note that, in time, the swapping takes place before the encryption of the last block. If you reverse the process, the swap must happen after decryption of the last block.
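One way to write this advice out as equations follows. The encrypt method stores the full block C_n in the penultimate ciphertext slot and head_r(C_{n-1}) in the last slot. Here the last plaintext block P_n has r bytes and was zero-padded to 8 bytes before encryption, C_0 = IV, and D is cipher.decrypt. Only the first r bytes of C_{n-1} are in the ciphertext. The remaining 8 - r bytes have to be recovered from D(C_n), so C_n must be decrypted first:

$$
\begin{aligned}
D(C_n) &= C_{n-1} \oplus (P_n \,\|\, 0^{8-r}) \\
C_{n-1} &= \mathrm{head}_r(C_{n-1}) \,\|\, \mathrm{tail}_{8-r}\bigl(D(C_n)\bigr) \\
P_n &= \mathrm{head}_r\bigl(D(C_n) \oplus C_{n-1}\bigr) \\
P_{n-1} &= D(C_{n-1}) \oplus C_{n-2} \\
P_i &= D(C_i) \oplus C_{i-1}, \qquad 1 \le i \le n-2
\end{aligned}
$$

The second line holds because the zero padding leaves the last 8 - r bytes of C_{n-1} unchanged in D(C_n). When r = 8 the tail is empty and the last slot already holds all of C_{n-1}.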

Currently you are creating an invalid ciphertext and then trying to decrypt it.
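The following is a minimal sketch of a decrypt that follows this advice, assuming the question's encrypt is exactly CBC with ciphertext stealing as written. It decrypts the last full block first, rebuilds the truncated block from the result, and only then produces the final two plaintext blocks. The class name CtsDemo and the TOY cipher are hypothetical test scaffolding; TOY is not a real or secure cipher. java.nio.ByteBuffer's big-endian getLong/putLong stand in for BlockCipher64.bytesToLong/longToBytes.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Random;

public class CtsDemo {

    // Same two abstract methods as the question's BlockCipher64 interface.
    interface BlockCipher64 {
        long encrypt(long block);
        long decrypt(long block);
    }

    // Hypothetical toy cipher for testing only (NOT secure): XOR, rotate and add, each invertible.
    static final BlockCipher64 TOY = new BlockCipher64() {
        public long encrypt(long x) { return Long.rotateLeft(x ^ 0x0123456789ABCDEFL, 13) + 0x1F2E3D4C5B6A7988L; }
        public long decrypt(long y) { return Long.rotateRight(y - 0x1F2E3D4C5B6A7988L, 13) ^ 0x0123456789ABCDEFL; }
    };

    // Big-endian read of 8 bytes; copyOfRange zero-pads past the end of the array.
    static long longAt(byte[] b, int pos) {
        return ByteBuffer.wrap(Arrays.copyOfRange(b, pos, pos + 8)).getLong();
    }

    // Big-endian write of 8 bytes; bytes that would land past the end of the array are dropped.
    static void storeLongAt(byte[] b, long x, int pos) {
        byte[] c = ByteBuffer.allocate(8).putLong(x).array();
        for (int i = 0; i < 8 && pos + i < b.length; ++i) b[pos + i] = c[i];
    }

    // The encrypt method from the question, unchanged apart from formatting.
    static byte[] encrypt(byte[] plaintext, BlockCipher64 cipher, long IV) {
        if (plaintext.length <= 8)
            throw new IllegalArgumentException("plaintext must be longer than 8 bytes!");
        byte[] ciphertext = new byte[plaintext.length];
        int blocks = plaintext.length / 8;
        if (plaintext.length % 8 != 0) ++blocks;
        long prev = IV;
        for (int block = 0; block < blocks; ++block) {
            prev = cipher.encrypt(prev ^ longAt(plaintext, block * 8));
            storeLongAt(ciphertext, prev, block * 8);
        }
        int lastBlock = (blocks - 1) * 8;
        int secondLastBlock = (blocks - 2) * 8;
        storeLongAt(ciphertext, longAt(ciphertext, secondLastBlock), lastBlock);
        storeLongAt(ciphertext, prev, secondLastBlock);
        return ciphertext;
    }

    static byte[] decrypt(byte[] ciphertext, BlockCipher64 cipher, long IV) {
        if (ciphertext.length <= 8)
            throw new IllegalArgumentException("ciphertext must be longer than 8 bytes!");
        byte[] plaintext = new byte[ciphertext.length];   // the input array is never modified
        int blocks = (ciphertext.length + 7) / 8;
        int lastBlock = (blocks - 1) * 8;
        int secondLastBlock = (blocks - 2) * 8;
        int r = ciphertext.length - lastBlock;             // bytes in the final slot, 1..8

        // 1. Plain CBC for every block before the stolen pair: P_i = D(C_i) ^ C_{i-1}.
        long prev = IV;
        for (int block = 0; block < blocks - 2; ++block) {
            long c = longAt(ciphertext, block * 8);
            storeLongAt(plaintext, cipher.decrypt(c) ^ prev, block * 8);
            prev = c;                                      // chain on ciphertext, not plaintext
        }

        // 2. The penultimate slot holds the full C_n: D(C_n) = C_{n-1} ^ (P_n || zeros).
        long dn = cipher.decrypt(longAt(ciphertext, secondLastBlock));

        // 3. Rebuild C_{n-1}: its first r bytes sit in the last slot, the rest are the tail of D(C_n).
        long headMask = -1L << (8 * (8 - r));              // selects the first (most significant) r bytes
        long cn1 = (longAt(ciphertext, lastBlock) & headMask) | (dn & ~headMask);

        // 4. P_n is the first r bytes of D(C_n) ^ C_{n-1}; storeLongAt keeps only those r bytes.
        storeLongAt(plaintext, dn ^ cn1, lastBlock);
        // 5. P_{n-1} = D(C_{n-1}) ^ C_{n-2}, where prev is C_{n-2} (or IV when there are only 2 blocks).
        storeLongAt(plaintext, cipher.decrypt(cn1) ^ prev, secondLastBlock);
        return plaintext;
    }

    public static void main(String[] args) {
        Random rnd = new Random(1);
        for (int len = 9; len <= 40; ++len) {
            byte[] pt = new byte[len];
            rnd.nextBytes(pt);
            byte[] back = decrypt(encrypt(pt, TOY, 0x5A5AL), TOY, 0x5A5AL);
            System.out.println(len + ": " + (Arrays.equals(pt, back) ? "ok" : "MISMATCH"));
        }
    }
}
```

This version differs from the question's decrypt in two further ways. First, it never writes into the caller's ciphertext array. Second, it chains the earlier blocks as D(C_i) XOR C_{i-1}. The question's loop instead computes cipher.decrypt(prev ^ C_i) and then chains on that result, which is not CBC decryption. A toy cipher that is purely XOR-based could hide that second bug, which is why TOY also rotates and adds.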
