Ron Thompson

Reputation: 1096

Implementing shifting and Hessenberg into an already functioning (slow) QR algorithm

TLDR: help me make the function at the bottom do what I think it should do already.

So after an interview didn't go smoothly, I decided to write my own library so that I would be better prepared if I'm ever asked about matrix multiplication again.

I know that doing it in Java is going to sacrifice some performance (at least I think I do). Also, I know that there are libraries already made for this purpose. The point is to learn enough to be able to off the cuff handle virtually any matrix algorithm I may need to, and to learn where to find the answers to the ones I don't. And I'm doing it in Java to flesh out my understanding of the language.

I've got most of what I set out to do in the library accomplished, except the final part: eigenvalues.

I've implemented a fairly basic QR decomposition method, and it seems to work (read: compare a bunch of random matrices in my library with the output of trusted calculators).

The problem is that it's orders of magnitude too slow. For a 50x50 matrix it takes almost a minute to get the eigenvalues.

Right now that's because I've set the iterations to 500 to handle a pathological case I came across when testing it. I'd like to reduce the number of iterations to something like 10, which is enough for almost every matrix to converge.

I looked around the web and found a book talking about advanced QR methods. Basically, from what I understand of it, if you first convert the matrix to upper Hessenberg form, it should converge orders of magnitude faster. Further, if you implement shifting, it should converge quadratically.

I read and implemented, I think, those two algorithms. They do make the convergence faster; the problem is that now it fails the pathological test case from before, silently spitting out the wrong eigenvalues.

Edit: I've tested the Hessenberg function, and it SEEMS to work on arbitrary matrices; as in, it spits out the same numbers that MATLAB and Wolfram do. However, when I add it in as the first step, the number of iterations goes UP. I'm asking on Math Stack Exchange about the underlying algorithm right now, but the shifting is really the part that's killing me.

I don't understand why. The book says, paraphrasing through my limited math understanding, that if a matrix A is put through a Hessenberg transformation H(A), it should still have the same eigenvalues. And the same with shifting. But when I implement either or both algorithms, the eigenvalues change.

My problem is either I've implemented the algorithm incorrectly, or I'm misunderstanding the math behind it.

The paper I was talking about for reference: http://people.inf.ethz.ch/arbenz/ewp/Lnotes/chapter3.pdf

Edit: The repository link (with the rest of the code that this code depends on): https://github.com/rwthompsonii/matrix-java

Edit: relevant function as per the rules:

public static Complex[] eigenvalues(SquareMatrix A) {
    Complex[] e = new Complex[A.getRows()];

    QRDecomposition qr = new QRDecomposition();

    qr.iterations = 0;
    int total_iter = 0;
    int num_eigen_found = 0;
    SquareMatrix QRIterator = new SquareMatrix(A);

    //in general, QR decomposition will converge faster from an upper
    //Hessenberg matrix.  so, first things first, we bring QRIterator to that form
    //QRIterator = new SquareMatrix(qr.hessenberg(QRIterator));

    int max = MAX_ITERATIONS;
    //double lastElement;
    //SquareMatrix ScaledIdentity;
    do {

        System.out.println("Pre-decompose: QRIterator (Iteration#" + (qr.iterations + 1) + "):\n" + QRIterator);
        if (QRIterator.getRows() == 1) {
            //very last 1x1 element in matrix
            e[num_eigen_found++] = new Complex(
                    QRIterator.getMatrix()[0][0]
            );
            break;
        } else {

            /*lastElement = QRIterator.getMatrix()[QRIterator.getRows() - 1][QRIterator.getColumns() - 1];
            ScaledIdentity = new SquareMatrix(Matrix.IdentityMatrix(QRIterator.getRows()).scale(lastElement));
            try {
                QRIterator = new SquareMatrix(QRIterator.subtract(ScaledIdentity));
            } catch (DimensionMismatchException ex) {
                System.out.println("Unexpected exception during QRIterator -= I*alpha, bailing.");
                System.exit(-1);

            }*/
            qr.decompose(QRIterator);
        }
        try {
            QRIterator = new SquareMatrix(qr.R.mult(qr.Q)/*.add(ScaledIdentity)*/);

        } catch (DimensionMismatchException ex) {
            System.out.println("An unexpected exception occurred during QRIterator = R*Q, bailing.");
            System.exit(-1);
        }
        qr.iterations++;

        //testing indicates that MAX_ITERATIONS iterations should be more than sufficient to converge, if it's going to at all
        if (qr.iterations == max || Math.abs(QRIterator.getMatrix()[QRIterator.getRows() - 1][QRIterator.getColumns() - 2]) < CONVERGENCE_CHECK) {
            System.out.println("QRIterator (at max iteration or converged) (Iteration#" + (qr.iterations + 1) + "):\n" + QRIterator);
            if (Math.abs(QRIterator.getMatrix()[QRIterator.getRows() - 1][QRIterator.getColumns() - 2]) < CONVERGENCE_CHECK) {
                //then the value at M[n][n] is an eigenvalue and it is real
                e[num_eigen_found++] = new Complex(
                        QRIterator.getMatrix()[QRIterator.getRows() - 1][QRIterator.getColumns() - 1]
                );

                //System.out.println("e[" + (num_eigen_found - 1) + "]:\t" + e[num_eigen_found - 1] + "\nQRIterator before deflation:\n" + QRIterator);
                double[][] deflatedMatrix = deflate(QRIterator.getMatrix(), 1);
                QRIterator = new SquareMatrix(deflatedMatrix);

                //System.out.println("\nQRIterator after deflation:\n" + QRIterator);
                total_iter += qr.iterations;
                qr.iterations = 0;  //reset the iterations counter to find the next eigenvalue
            } else {
                //this is a 2x2 matrix with either real or complex roots.  need to find them.
                //characteristic equation of 2x2 array => E^2 - (w + z)E + (wz - xy) = 0 where E = eigenvalue (possibly pair, possibly singular, possibly real, possibly complex)
                // and the matrix {{w, x}, {y, z}} is the input array, the task is to calculate the root(s) of that equation
                //that is a quadratic equation => (root = (-b +- sqrt(b^2  - 4ac))/2a)
                //discriminant b^2 - 4ac will determine behavior of roots => positive means 2 real roots, 0 means 1 repeated real root, negative means conjugate pair of complex roots

                //first, get the wxyz from the (possibly bigger) matrix
                int n = QRIterator.getRows();
                double w = QRIterator.getMatrix()[n - 2][n - 2];
                double x = QRIterator.getMatrix()[n - 2][n - 1];
                double y = QRIterator.getMatrix()[n - 1][n - 2];
                double z = QRIterator.getMatrix()[n - 1][n - 1];

                //a not used since it's = 1
                double b = -(w + z);
                double c = (w * z - x * y);

                //calculate discriminant of quadratic equation
                double determ = b * b - 4 * c;

                if (determ >= 0) {
                    //one or two real roots 
                    double sqrt_determ_real = Math.sqrt(determ);
                    e[num_eigen_found++] = new Complex((-b + sqrt_determ_real) / 2.0);
                    e[num_eigen_found++] = new Complex((-b - sqrt_determ_real) / 2.0);
                    //in the zero discriminant case that's simply going to add the same eigenvalue to the list twice.  I'm ok with that for now.
                } else if (determ < 0) {
                    //conjugate pair of complex roots
                    double sqrt_determ_imag = Math.sqrt(-determ);
                    e[num_eigen_found++] = new Complex(-b / 2.0, sqrt_determ_imag / 2.0);
                    e[num_eigen_found++] = new Complex(-b / 2.0, -sqrt_determ_imag / 2.0);
                }

                if (QRIterator.getRows() > 2) {
                    total_iter += qr.iterations;
                    qr.iterations = 0;  //reset the iterations counter to find the next eigenvalue
                    double[][] deflatedMatrix = deflate(QRIterator.getMatrix(), 2);
                    QRIterator = new SquareMatrix(deflatedMatrix);
                } 
            }
        }
        //QRIterator = new SquareMatrix(qr.hessenberg(QRIterator));

    } while (qr.iterations < max);

    //used for debugging here
    /*System.out.println("Finished iterating.  Iterations:\t" + total_iter
     + "\nFinal value of qr.Q:\n" + qr.Q + "\nFinal value of qr.R:\n" + qr.R
     + "\nFinal value of QRIterator:\n" + QRIterator
     + "\nOriginal SquareMatrix A:\n" + A);
     */
    return e;
}

Edit: in before snarkiness, there's a whole lot of crap in there that I need to clean up, like a bunch of print statements that I'm using for debugging, mainly because I don't want to step through 500 iterations to see the values. I'd like to get it working and then clean it up to meet my own fairly decent standards of readability. I know there's some refactoring that needs to be done, the function is simply too long as it is. But first, it needs to work. Help a guy out?

Upvotes: 0

Views: 1080

Answers (1)

Patrick McLaren

Reputation: 988

I think a good first step would be to scrap the Rayleigh quotient shift method, the Wilkinson shift, and the double shift method, until you can get the straightforward reduction to Hessenberg form working, i.e. Algorithm 3.3, pg. 61 of your first reference.

Since this reduction takes a fixed number of steps, it's a simple place to start. Calculate your Householder reflectors P_k, where k goes from 1 to n-2, following the definition on page 59, which involves the vector norm and rho. Note that in the real case, one commonly sets rho = -sign(x_1); see the text.
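
As a rough illustration of the reduction described above, here is a minimal sketch. It assumes plain double[][] arrays rather than the SquareMatrix class from the question, and the names HessenbergSketch, toHessenberg and trace are made up for this example. It is not a transcription of Algorithm 3.3, just one straightforward way to apply each reflector from both sides.

```java
import java.util.Arrays;

// Sketch only: plain double[][] arrays stand in for the question's SquareMatrix class.
class HessenbergSketch {

    /** Returns H = P^T A P in upper Hessenberg form; the input is left untouched. */
    public static double[][] toHessenberg(double[][] a) {
        int n = a.length;
        double[][] h = new double[n][];
        for (int i = 0; i < n; i++) {
            h[i] = a[i].clone();
        }
        for (int k = 0; k < n - 2; k++) {
            // x is column k below the diagonal, i.e. rows k+1 .. n-1.
            double norm = 0.0;
            for (int i = k + 1; i < n; i++) {
                norm += h[i][k] * h[i][k];
            }
            norm = Math.sqrt(norm);
            if (norm == 0.0) {
                continue; // nothing to eliminate in this column
            }
            // u = x - rho * ||x|| * e_1 with rho = -sign(x_1), which avoids cancellation.
            double[] u = new double[n];
            for (int i = k + 1; i < n; i++) {
                u[i] = h[i][k];
            }
            u[k + 1] += (h[k + 1][k] >= 0.0) ? norm : -norm;
            double uu = 0.0;
            for (int i = k + 1; i < n; i++) {
                uu += u[i] * u[i];
            }
            // Apply the reflector P = I - 2 u u^T / (u^T u) from the left: H <- P H.
            for (int j = 0; j < n; j++) {
                double dot = 0.0;
                for (int i = k + 1; i < n; i++) {
                    dot += u[i] * h[i][j];
                }
                double f = 2.0 * dot / uu;
                for (int i = k + 1; i < n; i++) {
                    h[i][j] -= f * u[i];
                }
            }
            // Apply it from the right as well: H <- H P. Both sides make it a similarity transform.
            for (int i = 0; i < n; i++) {
                double dot = 0.0;
                for (int j = k + 1; j < n; j++) {
                    dot += h[i][j] * u[j];
                }
                double f = 2.0 * dot / uu;
                for (int j = k + 1; j < n; j++) {
                    h[i][j] -= f * u[j];
                }
            }
        }
        return h;
    }

    /** Sum of the diagonal; unchanged by a similarity transform, so it is a cheap sanity check. */
    public static double trace(double[][] m) {
        double t = 0.0;
        for (int i = 0; i < m.length; i++) {
            t += m[i][i];
        }
        return t;
    }

    public static void main(String[] args) {
        double[][] a = {{4, 1, 2, 3}, {1, 3, 0, 1}, {2, 0, 2, 1}, {3, 1, 1, 1}};
        double[][] h = toHessenberg(a);
        for (double[] row : h) {
            System.out.println(Arrays.toString(row));
        }
        System.out.println("trace before = " + trace(a) + ", after = " + trace(h));
    }
}
```

The entries below the first subdiagonal come out as tiny rounding residues rather than exact zeros. One thing worth checking in your own version: if a reflector is applied on only one side, the result is no longer similar to A, and the eigenvalues (and the trace) will change.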

Work out a simple 2x2 example on paper (or 3x3 if you're keen), and print your Java calculations, step by step, to ensure they're in correspondence with your hand-written work. Also, check you're not running into floating point issues, etc.
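
To make the paper check concrete, here is one hand-worked case, sketched under my own choice of matrix. For the Hessenberg step specifically, 3x3 is the smallest size where a reflector has anything to do. Take A = {{1,2,3},{3,4,5},{4,6,7}}. Then x = (3, 4), ||x|| = 5, rho = -1, u = (8, 4), and the 2x2 reflector is {{-0.6,-0.8},{-0.8,0.6}}. By hand, P*A = {{1,2,3},{-5,-7.2,-8.6},{0,0.4,0.2}} and P*A*P = {{1,-3.6,0.2},{-5,11.2,0.6},{0,-0.4,-0.2}}. The class and method names below are invented for the example.

```java
import java.util.Arrays;

// Sketch only: checks a hand-worked 3x3 Householder step; all names here are illustrative.
class HandCheck3x3 {

    public static final double[][] A = {{1, 2, 3}, {3, 4, 5}, {4, 6, 7}};

    // Reflector for x = (3, 4): ||x|| = 5, rho = -sign(3) = -1, u = x - rho*||x||*e_1 = (8, 4),
    // so I - 2 u u^T / (u^T u) = [[-0.6, -0.8], [-0.8, 0.6]], embedded below a leading 1.
    public static final double[][] P = {{1, 0, 0}, {0, -0.6, -0.8}, {0, -0.8, 0.6}};

    // H = P A P worked out by hand (P is symmetric and orthogonal, so P^T = P).
    public static final double[][] EXPECTED = {{1, -3.6, 0.2}, {-5, 11.2, 0.6}, {0, -0.4, -0.2}};

    /** Plain triple-loop matrix product. */
    public static double[][] multiply(double[][] a, double[][] b) {
        int rows = a.length;
        int inner = b.length;
        int cols = b[0].length;
        double[][] c = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                for (int k = 0; k < inner; k++) {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return c;
    }

    /** Largest entrywise absolute difference between two equally sized matrices. */
    public static double maxAbsDiff(double[][] a, double[][] b) {
        double max = 0.0;
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[i].length; j++) {
                max = Math.max(max, Math.abs(a[i][j] - b[i][j]));
            }
        }
        return max;
    }

    public static void main(String[] args) {
        double[][] pa = multiply(P, A);  // left multiplication only
        double[][] h = multiply(pa, P);  // then the right multiplication
        System.out.println("P*A   = " + Arrays.deepToString(pa));
        System.out.println("P*A*P = " + Arrays.deepToString(h));
        System.out.println("max |computed - hand| = " + maxAbsDiff(h, EXPECTED));
    }
}
```

Both A and P*A*P have trace 12 and determinant 2, which is a quick way to confirm on paper that the transformation kept the eigenvalues. Printing the same intermediate products from your library for this matrix should show where it first diverges from the hand values, if it does.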

Finally, don't be upset if your code doesn't find the eigenvalues of an arbitrary 50x50 matrix at the drop of a hat. Most implementations of these reductions are heavily optimized, and rarely verbatim translations of the pseudocode that you'll find in the literature. Check out other libraries, such as Sage or NumPy if you are into Python; or perhaps Boost for C++. These might be more fun to work with, for this kind of problem.

Upvotes: 1
