Gregorick

Reputation: 61

Tridiagonal Matrix Algorithm (TDMA) aka Thomas Algorithm, using Python with NumPy arrays

I found an implementation of the Thomas algorithm (TDMA) in MATLAB:

function x = TDMAsolver(a,b,c,d)
    %a, b, c are the column vectors for the compressed tridiagonal matrix, d is the right vector
    n = length(b); % n is the number of rows

    % Modify the first-row coefficients
    c(1) = c(1) / b(1);    % Division by zero risk.
    d(1) = d(1) / b(1);    % Division by zero would imply a singular matrix.

    for i = 2:n-1
        temp = b(i) - a(i) * c(i-1);
        c(i) = c(i) / temp;
        d(i) = (d(i) - a(i) * d(i-1))/temp;
    end

    d(n) = (d(n) - a(n) * d(n-1))/( b(n) - a(n) * c(n-1));

    % Now back substitute.
    x(n) = d(n);
    for i = n-1:-1:1
        x(i) = d(i) - c(i) * x(i + 1);
    end
end
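For reference, these are the recurrences the MATLAB function above implements for the system a_i x_{i-1} + b_i x_i + c_i x_{i+1} = d_i. The indices are 1-based as in the MATLAB code, primes mark the overwritten c and d values, and a_1 and c_n are never used. The notation is an editorial restatement of the code, not something taken from a specific reference.

```latex
% Forward sweep (first row, middle rows, last row)
c'_1 = \frac{c_1}{b_1}, \qquad d'_1 = \frac{d_1}{b_1}

c'_i = \frac{c_i}{b_i - a_i c'_{i-1}}, \qquad i = 2, \dots, n-1

d'_i = \frac{d_i - a_i d'_{i-1}}{b_i - a_i c'_{i-1}}, \qquad i = 2, \dots, n

% Back substitution
x_n = d'_n, \qquad x_i = d'_i - c'_i \, x_{i+1}, \qquad i = n-1, \dots, 1
```

Every row is eliminated exactly once. The MATLAB loop stops at n-1 because the last row needs only d'_n, which is computed separately after the loop.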

I need it in Python using NumPy arrays. Here is my first attempt at the algorithm in Python:

import numpy

aa = (0.,8.,9.,3.,4.)
bb = (4.,5.,9.,4.,7.)
cc = (9.,4.,5.,7.,0.)
dd = (8.,4.,5.,9.,6.)

ary = numpy.array

a = ary(aa)
b = ary(bb)
c = ary(cc)
d = ary(dd)

n = len(b)  ## n is the number of rows

## Modify the first-row coefficients
c[0] = c[0]/ b[0]    ## risk of Division by zero.
d[0] = d[0]/ b[0]

for i in range(1,n,1):
    temp = b[i] - a[i] * c[i-1]
    c[i] = c[i]/temp
    d[i] = (d[i] - a[i] * d[i-1])/temp

d[-1] = (d[-1] - a[-1] * d[-2])/( b[-1] - a[-1] * c[-2])

## Now back substitute.
x = numpy.zeros(5)
x[-1] = d[-1]
for i in range(-2, -n-1, -1):
    x[i] = d[i] - c[i] * x[i + 1]

They give different results, so what am I doing wrong?

Upvotes: 4

Views: 32202

Answers (4)

Justin Clark

Reputation: 64

I made this since none of the online implementations for Python actually work. I've tested it against built-in matrix inversion and the results match.

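Here a = lower diagonal, b = main diagonal, c = upper diagonal, and d = right-hand side vector. Note that a and c have length n-1 here, unlike the padded length-n arrays in the question.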

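import numpy as np

def TDMA(a, b, c, d):
    # a: lower diagonal (length n-1), b: main diagonal (length n),
    # c: upper diagonal (length n-1), d: right-hand side (length n)
    n = len(d)
    w = np.zeros(n - 1, float)  # modified upper-diagonal coefficients
    g = np.zeros(n, float)      # modified right-hand side
    p = np.zeros(n, float)      # solution vector

    w[0] = c[0] / b[0]
    g[0] = d[0] / b[0]

    # Forward sweep: eliminate the lower diagonal row by row
    for i in range(1, n - 1):
        w[i] = c[i] / (b[i] - a[i - 1] * w[i - 1])
    for i in range(1, n):
        g[i] = (d[i] - a[i - 1] * g[i - 1]) / (b[i] - a[i - 1] * w[i - 1])

    # Back substitution, from the last row upwards
    p[n - 1] = g[n - 1]
    for i in range(n - 1, 0, -1):
        p[i - 1] = g[i - 1] - w[i - 1] * p[i]
    return p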
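A minimal sketch of the kind of reference check described in this answer, assuming the convention above (a and c of length n-1). dense_tridiagonal is a hypothetical helper name, and np.linalg.solve is used instead of forming the inverse explicitly, which is the usual recommendation when solving a single system. The question's padded arrays are converted by dropping aa[0] and cc[-1].

```python
import numpy as np


def dense_tridiagonal(a, b, c):
    """Assemble the full n x n matrix from the three diagonals.

    Assumes b (main diagonal) has length n, while a (lower) and
    c (upper) have length n - 1.
    """
    return np.diag(a, k=-1) + np.diag(b) + np.diag(c, k=1)


# The question's arrays are padded to length n (aa[0] and cc[-1] are unused),
# so the padding is dropped to match the length n-1 convention.
aa = np.array([0., 8., 9., 3., 4.])
bb = np.array([4., 5., 9., 4., 7.])
cc = np.array([9., 4., 5., 7., 0.])
dd = np.array([8., 4., 5., 9., 6.])

a, b, c = aa[1:], bb, cc[:-1]
A = dense_tridiagonal(a, b, c)

# Reference solution from a general dense solver.
x_ref = np.linalg.solve(A, dd)
print(A)
print(x_ref)
```

A TDMA result for the same a, b, c, dd can then be compared against x_ref with np.allclose.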

For an easy performance boost for large matrices, use numba! This code outperforms np.linalg.inv() in my tests:

import numpy as np
from numba import jit

@jit
def TDMA(a, b, c, d):
    n = len(d)
    w = np.zeros(n - 1, float)
    g = np.zeros(n, float)
    p = np.zeros(n, float)

    w[0] = c[0] / b[0]
    g[0] = d[0] / b[0]

    for i in range(1, n - 1):
        w[i] = c[i] / (b[i] - a[i - 1] * w[i - 1])
    for i in range(1, n):
        g[i] = (d[i] - a[i - 1] * g[i - 1]) / (b[i] - a[i - 1] * w[i - 1])
    p[n - 1] = g[n - 1]
    for i in range(n - 1, 0, -1):
        p[i - 1] = g[i - 1] - w[i - 1] * p[i]
    return p

Upvotes: 4

Fred Foo

Reputation: 363567

There's at least one difference between the two:

for i in range(1,n,1):

in Python iterates from index 1 to the last index n-1, while

for i = 2:n-1

iterates from index 1 (zero-based) to the second-to-last index, since MATLAB has one-based indexing.
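A quick way to see the difference, using n = 5 as in the question's example. The MATLAB range is emulated with Python's range and shifted by one; the variable names are only for illustration.

```python
n = 5  # number of rows in the question's example

# The question's loop: range stops *before* n, so it visits 1 .. n-1,
# i.e. every row after the first, including the last row.
python_loop = list(range(1, n, 1))

# MATLAB's `for i = 2:n-1` includes both ends (2, 3, 4 for n = 5);
# subtracting 1 converts those one-based indices to zero-based ones.
matlab_loop = [i - 1 for i in range(2, n)]

print(python_loop)  # [1, 2, 3, 4]
print(matlab_loop)  # [1, 2, 3]

# The extra index is the last row, which the question's code then
# updates a second time in the `d[-1] = ...` line after the loop.
extra = sorted(set(python_loop) - set(matlab_loop))
print(extra)  # [4]
```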

Upvotes: 2

Mike

Reputation: 56

Writing something like this in pure Python is going to be really slow. You would be much better off using LAPACK to do the numerical heavy lifting and using Python for everything around it. LAPACK is compiled, so it runs much faster than a Python loop; it is also far more highly optimised than most of us could feasibly match.

SciPy provides low-level wrappers for LAPACK so that you can call it from Python very simply. The one you are looking for is dgtsv:

https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.lapack.dgtsv.html#scipy.linalg.lapack.dgtsv
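A minimal sketch of calling that wrapper on the question's data, assuming SciPy is installed. dgtsv expects the sub- and super-diagonals with length n-1, so the padding entries aa[0] and cc[-1] are dropped; the right-hand side is passed as an (n, 1) column. Unlike the Thomas algorithm, dgtsv uses Gaussian elimination with partial pivoting. The cross-check with scipy.linalg.solve_banded is an addition for illustration, not part of the original answer.

```python
import numpy as np
from scipy.linalg import lapack, solve_banded

# The question's data, padded to length n.
aa = np.array([0., 8., 9., 3., 4.])
bb = np.array([4., 5., 9., 4., 7.])
cc = np.array([9., 4., 5., 7., 0.])
dd = np.array([8., 4., 5., 9., 6.])

dl = aa[1:]              # sub-diagonal, length n-1
diag = bb                # main diagonal, length n (LAPACK calls this "d")
du = cc[:-1]             # super-diagonal, length n-1
rhs = dd.reshape(-1, 1)  # right-hand side as an (n, 1) column

# dgtsv returns (du2, d, du, x, info); info == 0 means success,
# info > 0 means an exactly zero pivot was found (singular matrix).
du2, d_out, du_out, x, info = lapack.dgtsv(dl, diag, du, rhs)
if info != 0:
    raise np.linalg.LinAlgError(f"dgtsv failed with info={info}")
x = x.ravel()  # back to a flat length-n vector

# Cross-check with the higher-level solve_banded: in its banded storage,
# row 0 holds the upper diagonal (shifted right), row 1 the main diagonal
# and row 2 the lower diagonal (shifted left).
ab = np.zeros((3, len(bb)))
ab[0, 1:] = du
ab[1, :] = bb
ab[2, :-1] = dl
x_banded = solve_banded((1, 1), ab, dd)

print(x)
print(np.allclose(x, x_banded))
```

By default the overwrite_* flags are off, so the input arrays are left untouched; solve_banded itself dispatches to gtsv for a (1, 1) band, so the two results should agree.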

Upvotes: 1

Ray

Reputation: 4671

In your loop, the MATLAB version iterates over the second through second-to-last elements. To do the same in Python, you want:

for i in range(1,n-1):

(As noted in voithos's comment, this is because range excludes its stop value, so you need to correct for this in addition to the change to 0-based indexing.)
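Applying that change to the question's code gives something like the sketch below. It keeps the question's padded convention (all four arrays of length n, with a[0] and c[-1] unused); tdma_padded is a hypothetical name, and copying c and d is an addition so the caller's arrays are not overwritten. No pivoting is done, so a zero pivot yields inf/nan rather than an error. The result is checked only against a dense np.linalg.solve on the question's data.

```python
import numpy as np


def tdma_padded(a, b, c, d):
    """Thomas algorithm using the question's padded, length-n arrays."""
    n = len(b)
    c = np.array(c, dtype=float)  # copies, so the caller's data survives
    d = np.array(d, dtype=float)

    # Modify the first-row coefficients.
    c[0] = c[0] / b[0]
    d[0] = d[0] / b[0]

    # Rows 1 .. n-2, matching MATLAB's `for i = 2:n-1`.
    for i in range(1, n - 1):
        temp = b[i] - a[i] * c[i - 1]
        c[i] = c[i] / temp
        d[i] = (d[i] - a[i] * d[i - 1]) / temp

    # Last row, eliminated exactly once.
    d[-1] = (d[-1] - a[-1] * d[-2]) / (b[-1] - a[-1] * c[-2])

    # Back substitution (same as the question's negative-index loop).
    x = np.zeros(n)
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


aa = (0., 8., 9., 3., 4.)
bb = (4., 5., 9., 4., 7.)
cc = (9., 4., 5., 7., 0.)
dd = (8., 4., 5., 9., 6.)

x = tdma_padded(aa, bb, cc, dd)

# Dense matrix for comparison (padding entries dropped).
A = np.diag(aa[1:], -1) + np.diag(bb) + np.diag(cc[:-1], 1)
print(x)
print(np.allclose(x, np.linalg.solve(A, dd)))
```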

Upvotes: 0
