serious_luffy

Reputation: 419

Implementing Knuth-Morris-Pratt (KMP) algorithm for string matching with Python

I am following the Cormen, Leiserson, Rivest, and Stein (CLRS) book and came across the KMP algorithm for string matching. I implemented it in Python, transcribing the pseudocode as-is.

However, it doesn't seem to work. Where is my mistake?

The code is given below:

def kmp_matcher(t,p):
    n=len(t)
    m=len(p)
    # pi=[0]*n;
    pi = compute_prefix_function(p)
    q=-1
    for i in range(n):
        while(q>0 and p[q]!=t[i]):
            q=pi[q]
        if(p[q]==t[i]):
            q=q+1
        if(q==m):
            print "pattern occurs with shift "+str(i-m)
            q=pi[q]


def compute_prefix_function(p):
    m=len(p)
    pi =range(m)
    pi[1]=0
    k=0
    for q in range(2,m):
        while(k>0 and p[k]!=p[q]):
            k=pi[k]
        if(p[k]==p[q]):
            k=k+1
        pi[q]=k
    return pi

t = 'brownfoxlazydog'
p = 'lazy'
kmp_matcher(t,p)

Upvotes: 5

Views: 9707

Answers (4)

JB-Franco

Reputation: 256

KMP stands for Knuth-Morris-Pratt; it is a linear-time string-matching algorithm.

Note that Python strings are zero-based, while the pseudocode in the book indexes strings starting from 1.

We can work around this by inserting a single padding character (here a space) at the beginning of both strings.

This has four consequences:

  1. The lengths of both the text and the pattern grow by 1, so the loop ranges do NOT need a +1 on the upper bound (note that Python's range excludes the end value);
  2. To avoid out-of-range accesses, check the values of k+1 and q+1 BEFORE using them as indices;
  3. Since m is now one larger, kmp_matcher has to test q==m-1 before reporting a match;
  4. For the same reason, the correct shift is i-(m-1).

So the corrected code, based on your original question and starting from the CLRS pseudocode as you requested, is the following.

(Note: I have included a text that contains the pattern, and some debug output that helped me find the logical errors.)

def compute_prefix_function(P):
   m     = len(P)
   pi    = [None] * m
   pi[1] = 0
   k     = 0

   for q in range(2, m):
      print ("q=", q, "\n")
      print ("k=", k, "\n")

      if ((k+1) < m):
         while (k > 0 and P[k+1] != P[q]):
            print ("entered while: \n")
            print ("k: ", k, "\tP[k+1]: ", P[k+1], "\tq: ", q, "\tP[q]: ", P[q])
            k = pi[k]

         if P[k+1] == P[q]:
            k = k+1
            print ("Entered if: \n")
            print ("k: ", k, "\tP[k]: ", P[k], "\tq: ", q, "\tP[q]: ", P[q])
      pi[q] = k
      print ("Outside while or if: \n")
      print ("pi[", q, "] = ", k, "\n")
      print ("---next---")
   print ("---end for---")
   return pi

def kmp_matcher(T, P):
   n  = len(T)
   m  = len(P)
   pi = compute_prefix_function(P)
   q  = 0

   for i in range(1, n):
      print ("i=", i, "\n")
      print ("q=", q, "\n")
      print ("m=", m, "\n")

      if ((q+1) < m):
         while (q > 0 and P[q+1] != T[i]):
            q = pi[q]
         if P[q+1] == T[i]:
            q = q+1
         if q == m-1:
            print ("Pattern occurs with shift", i-(m-1))
            q = pi[q]
      print("---next---")
   print("---end for---")


txt = " bacbababaabcbab"
ptn = " ababaab"
kmp_matcher(txt, ptn)
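
With the debug output stripped, the same padded, 1-indexed approach can be condensed into the sketch below. It is only an illustration: the name `kmp_shifts` and the choice to return a list are my own, and here `m` and `n` are taken as the real lengths (without the padding character), so the test becomes `q == m` and the shift `i - m`.

```python
def compute_prefix_function(P):
    """P is padded: P[0] is a dummy character, the real pattern is P[1:]."""
    m = len(P) - 1               # real pattern length (assumption: one pad char)
    pi = [0] * (m + 1)           # pi[0] is unused, pi[1] = 0 as in CLRS
    k = 0
    for q in range(2, m + 1):
        while k > 0 and P[k + 1] != P[q]:
            k = pi[k]
        if P[k + 1] == P[q]:
            k += 1
        pi[q] = k
    return pi


def kmp_shifts(T, P):
    """Return all shifts of P in T; both strings carry one leading pad char."""
    n, m = len(T) - 1, len(P) - 1
    if m == 0:                   # empty pattern: nothing to search for
        return []
    pi = compute_prefix_function(P)
    shifts = []
    q = 0                        # number of pattern characters matched so far
    for i in range(1, n + 1):
        while q > 0 and P[q + 1] != T[i]:
            q = pi[q]            # fall back to the next shorter border
        if P[q + 1] == T[i]:
            q += 1
        if q == m:               # whole pattern matched, ending at T[i]
            shifts.append(i - m)
            q = pi[q]            # keep going to find overlapping matches
    return shifts


print(kmp_shifts(" bacbababaabcbab", " ababaab"))  # [4]
```

Returning the shifts instead of printing them makes the function easier to test, for example against overlapping matches such as `kmp_shifts(" aaaa", " aa")`.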

(so this would be the correct accepted answer...)

Hope that helps.

Upvotes: 1

proprius

Reputation: 522

You might want to try out my code:

def recursive_find_match(i, j, pattern, pattern_track):

    if pattern[i] == pattern[j]:
        pattern_track.append(i+1)
        return {"append":pattern_track, "i": i+1, "j": j+1}
    elif pattern[i] != pattern[j] and i == 0:
        pattern_track.append(i)
        return {"append":pattern_track, "i": i, "j": j+1}

    else:
        i = pattern_track[i-1]
        return recursive_find_match(i, j, pattern, pattern_track)

def kmp(str_, pattern):

    len_str = len(str_)
    len_pattern = len(pattern)
    pattern_track = []

    if len_pattern == 0:
        return
    elif len_pattern == 1:
        pattern_track = [0]
    else:   
        pattern_track = [0]
        i = 0
        j = 1

        while j < len_pattern:
            data = recursive_find_match(i, j, pattern, pattern_track)

            i = data["i"]
            j = data["j"]
            pattern_track = data["append"]

    index_str = 0
    index_pattern = 0
    match_from = -1

    while index_str < len_str:
        if index_pattern == len_pattern:
            break
        if str_[index_str] == pattern[index_pattern]:
            if index_pattern == 0:
                match_from = index_str

            index_pattern += 1
            index_str += 1
        else:
            if index_pattern == 0:
                index_str += 1
            else:
                index_pattern = pattern_track[index_pattern-1]
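                match_from = index_str - index_pattern

    # Report the start index of the first match, or -1 if there is none.
    if index_pattern == len_pattern:
        return match_from
    return -1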
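
The recursive helper above builds a failure table (`pattern_track`) for the pattern. As a rough sketch, the same table can also be built iteratively; the function name `build_failure_table` is hypothetical and not part of the code above.

```python
def build_failure_table(pattern):
    """table[j] = length of the longest proper prefix of pattern[:j+1]
    that is also a suffix of it (0-indexed)."""
    table = [0] * len(pattern)
    i = 0                                  # length of the current border
    for j in range(1, len(pattern)):
        while i > 0 and pattern[i] != pattern[j]:
            i = table[i - 1]               # fall back to a shorter border
        if pattern[i] == pattern[j]:
            i += 1
        table[j] = i
    return table


print(build_failure_table("ababaab"))  # [0, 0, 1, 2, 3, 1, 2]
```

The loop avoids the dictionary that the recursive version passes back and forth, and it cannot hit Python's recursion limit on long patterns.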

Upvotes: 3

Rohanil

Reputation: 1887

Try this:

def kmp_matcher(t, d):
    n=len(t)
    m=len(d)

    pi = compute_prefix_function(d)
    q = 0
    i = 0
    while i < n:
        if d[q]==t[i]:
            q=q+1
            i = i + 1
        else:
            if q != 0:
                q = pi[q-1]
            else:
                i = i + 1
        if q == m:
            print("pattern occurs with shift " + str(i - q))
            q = pi[q-1]

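def compute_prefix_function(p):
    m = len(p)
    pi = [0] * m  # pi[k] = length of the longest proper prefix of p[:k+1] that is also its suffix
    k = 1
    l = 0
    while k < m:
        if p[k] == p[l]:
            # the current border can be extended by one character
            l = l + 1
            pi[k] = l
            k = k + 1
        else:
            if l != 0:
                # fall back to the next shorter border
                l = pi[l-1]
            else:
                pi[k] = 0
                k = k + 1
    return pi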

t = 'brownfoxlazydog'
p = 'lazy'
kmp_matcher(t, p)
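
One way to sanity-check the shifts a matcher like this reports is to compare them with a naive scan. The sketch below uses only `str.find`; the helper name `naive_shifts` is made up for illustration.

```python
def naive_shifts(text, pattern):
    """Return every index at which pattern starts in text (overlaps included)."""
    shifts = []
    if not pattern:                 # treat the empty pattern as "no matches"
        return shifts
    start = text.find(pattern)
    while start != -1:
        shifts.append(start)
        # resume one position later so overlapping matches are found too
        start = text.find(pattern, start + 1)
    return shifts


print(naive_shifts('brownfoxlazydog', 'lazy'))  # [8]
```

If a KMP implementation reports a different set of shifts on the same input, the prefix function is the first place to look.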

Upvotes: 2

sstreamer

Reputation: 101

This is a class I wrote based on the CLRS KMP algorithm, which contains what you are after. Note that only DNA "characters" are accepted here.

class KmpMatcher(object):
    def __init__(self, pattern, string, stringName):
        self.motif = pattern.upper()
        self.seq = string.upper()
        self.header = stringName
        self.prefix = []
        self.validBases = ['A', 'T', 'G', 'C', 'N']

    #Matches the motif pattern against itself.
    def computePrefix(self):
        #Initialize prefix array
        self.fillPrefixList()
        k = 0

        for pos in range(1, len(self.motif)):
            #Check valid nt
            if(self.motif[pos] not in self.validBases):
                self.invalidMotif()

            #Unique base in motif
            while(k > 0 and self.motif[k] != self.motif[pos]):
                #Fall back with prefix[k-1]; prefix[k] can equal k and loop forever (e.g. motif "AAT")
                k = self.prefix[k-1]
            #repeat in motif
            if(self.motif[k] == self.motif[pos]):
                k += 1

            self.prefix[pos] = k

    #Initialize the prefix list and set first element to 0
    def fillPrefixList(self):
        self.prefix = [None] * len(self.motif)
        self.prefix[0] = 0

    #An implementation of the Knuth-Morris-Pratt algorithm for linear time string matching
    def kmpSearch(self):
        #Compute prefix array
        self.computePrefix()
        #Number of characters matched
        match = 0
        found = False

        for pos in range(0, len(self.seq)):
            #Check valid nt
            if(self.seq[pos] not in self.validBases):
                self.invalidSequence()

            #Next character is not a match
            while(match > 0 and self.motif[match] != self.seq[pos]):
                match = self.prefix[match-1]
            #A character match has been found
            if(self.motif[match] == self.seq[pos]):
                match += 1
            #Motif found
            if(match == len(self.motif)):
                print(self.header)
                #Positions are reported 1-based and inclusive (start:end)
                print("Match found at position: " + str(pos-match+2) + ':' + str(pos+1))
                found = True
                match = self.prefix[match-1]

        if(found == False):
            print("Sorry '" + self.motif + "'" + " was not found in " + str(self.header))

    #An invalid character in the motif message to the user
    def invalidMotif(self):
        print("Error: motif contains invalid DNA nucleotides")
        exit()

    #An invalid character in the sequence message to the user
    def invalidSequence(self):
        print("Error: " + str(self.header) + " sequence contains invalid DNA nucleotides")
        exit()

Upvotes: 4
