Ozcan

Reputation: 726

Python scipy - Line fit with the constraint that all points are above the line

I have the image below, and I would like to fit a line to the white points in it. However, there is a constraint: all points must lie above the line, as in the last image. The script below fits the line when no constraint is given. Can anybody help me refactor this code so that all points lie above the fitted line?

[Image: input image]

import numpy as np
import matplotlib.pyplot as plt
import cv2
from scipy import optimize

def fit_line(img):

    def func_linear(x, x0, y0, k):
        # y = k*(x - x0) + y0
        return k * (x - x0) + y0

    # np.where returns (row, column, ...) indices of the non-zero pixels;
    # for a color image the third (channel) array is ignored here
    points = np.where(img > 0)
    points = np.array([points[1], points[0]]).T

    x = points[:, 0]  # column index
    y = points[:, 1]  # row index (grows downward)

    p0 = [1, 1, 1]
    p, e = optimize.curve_fit(func_linear, x, y, p0)

    # cv2.line expects plain integer pixel coordinates
    pt1 = (int(np.min(x)), int(func_linear(np.min(x), *p)))
    pt2 = (int(np.max(x)), int(func_linear(np.max(x), *p)))

    cv2.line(img, pt1, pt2, (255, 0, 0), 3)


img = cv2.imread("toy_2.png")

fit_line(img)

plt.imshow(img)
plt.show()

[Image: fitted line, no constraint]

[Image: desired result, with constraint]
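Because the meaning of "above" depends on the image coordinate system, here is a minimal sketch of the point-extraction step on a small hypothetical mask (standing in for toy_2.png, which is not available here). It uses the same stacking as the question and shows that x is the column, y is the row, and that y grows downward, so the point drawn highest on screen has the smallest y.

```python
import numpy as np

# Hypothetical 10x10 mask standing in for toy_2.png (three white pixels).
mask = np.zeros((10, 10), dtype=np.uint8)
mask[2, 1] = 255  # row 2, column 1
mask[5, 4] = 255
mask[3, 8] = 255

# For a 2-D array, np.where returns (row_indices, column_indices) in row-major order.
points = np.where(mask > 0)
# Same stacking as the question: column -> x, row -> y.
points = np.array([points[1], points[0]]).T
x = points[:, 0]
y = points[:, 1]

print(points.tolist())  # [[1, 2], [8, 3], [4, 5]]

# y grows downward, so the point drawn highest on screen has the SMALLEST y.
top = points[np.argmin(y)]
print(top.tolist())  # [1, 2]
```

So "all points above the line" in the displayed image means y_i <= f(x_i) for every point.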

Upvotes: 0

Views: 142

Answers (1)

u1234x1234

Reputation: 2510

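You can fit the line without the constraint and then shift it along the y axis until all the points are above it. Because image rows grow downward, a point is displayed above the line when y <= f(x), so increase the intercept y0 (p[1]) by the largest residual y - f(x):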

p[1] += np.max(y - func_linear(x, *p))
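A minimal, self-contained sketch of this shift on hypothetical coordinates (no image needed). It assumes the two-parameter form y = k*x + b: the question's k*(x - x0) + y0 is the same line with b = y0 - k*x0, and x0 and y0 cannot be determined separately, so dropping one of them does not change the fitted line.

```python
import numpy as np
from scipy import optimize

def line(x, k, b):
    # y = k*x + b, same line as the question's k*(x - x0) + y0 with b = y0 - k*x0
    return k * x + b

# Hypothetical pixel coordinates (x = column, y = row; y grows downward).
x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=float)
y = np.array([1, 3, 2, 5, 4, 6, 9, 7, 8, 10], dtype=float)

# Unconstrained least-squares fit.
p, _ = optimize.curve_fit(line, x, y)

# Shift the intercept by the largest residual so that y_i <= line(x_i) for all i:
# every point is on or above the line as displayed (smaller y = higher on screen).
p[1] += np.max(y - line(x, *p))

residuals = y - line(x, *p)
print(residuals.max())  # approximately 0: the point with the largest residual now sits on the line
```

Adding a small epsilon to the shift, as in the full example, moves that last point strictly above the line.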

[Image: result, with the shifted line below all the points]

Full example:

import numpy as np
import matplotlib.pyplot as plt
import cv2
from scipy import optimize


def fit_line(img):
    def func_linear(x, x0, y0, k):
        # y = k*(x - x0) + y0
        return k * (x - x0) + y0

    # column -> x, row -> y (rows grow downward)
    points = np.where(img > 0)
    points = np.array([points[1], points[0]]).T

    x = points[:, 0]
    y = points[:, 1]

    p0 = [1, 1, 1]
    p, _ = optimize.curve_fit(func_linear, x, y, p0)

    # shift the line along the y axis so that y <= f(x) for every point;
    # the small epsilon makes the inequality strict
    p[1] += np.max(y - func_linear(x, *p)) + 1e-6

    # cv2.line expects plain integer pixel coordinates
    pt1 = (int(np.min(x)), int(func_linear(np.min(x), *p)))
    pt2 = (int(np.max(x)), int(func_linear(np.max(x), *p)))

    cv2.line(img, pt1, pt2, (255, 0, 0), 3)


img = cv2.imread("toy_2.png")

fit_line(img)
plt.imshow(img)
plt.show()

You could also add a penalty term to the fit, but then it is harder to guarantee that every point lies strictly above the line. See the similar question: How do I put a constraint on SciPy curve fit?
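One caveat of the shift: it keeps the slope of the unconstrained fit, so the result satisfies the constraint but is not necessarily the least-squares line among all lines that satisfy it. If that matters, a hard inequality constraint can be passed to scipy.optimize.minimize with method="SLSQP" instead of using curve_fit. The sketch below uses that swapped-in technique (it is not part of the answer's approach), on hypothetical coordinates, with the two-parameter form y = k*x + b because the question's x0 and y0 are redundant. The shifted line serves as a feasible starting point.

```python
import numpy as np
from scipy import optimize

def line(x, k, b):
    # Two-parameter form of the question's model: y = k*x + b
    return k * x + b

# Hypothetical pixel coordinates (x = column, y = row; y grows downward).
x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=float)
y = np.array([1, 3, 2, 5, 4, 6, 9, 7, 8, 10], dtype=float)

def sse(p):
    # Least-squares objective: sum of squared vertical residuals
    return np.sum((y - line(x, *p)) ** 2)

# Feasible starting point: unconstrained fit shifted as in the answer.
p_shift, _ = optimize.curve_fit(line, x, y)
p_shift[1] += np.max(y - line(x, *p_shift))

# Hard constraint line(x_i) - y_i >= 0 for every point ("ineq" means fun(p) >= 0),
# i.e. every point is on or above the line as displayed.
cons = {"type": "ineq", "fun": lambda p: line(x, *p) - y}
res = optimize.minimize(sse, p_shift, method="SLSQP", constraints=[cons])
p_con = res.x

print("shifted SSE:    ", sse(p_shift))
print("constrained SSE:", sse(p_con))
```

Because both the objective and the constraints are convex in (k, b), the constrained SSE should be no larger than the SSE of the shifted line.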

Upvotes: 1
