user7426532

Reputation: 103

Fit straight line on semi-log scale with Matplotlib

I have been struggling with fitting a straight line on a semi-log plot made with Matplotlib and Python 3. I have seen many examples of log-log scale figures, but none of the solutions I tried worked (using numpy). The line always ends up being crooked somewhere.

The following is what I have thus far:

import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

base_path = os.path.dirname(os.path.realpath(__file__))

fig = plt.figure()
ax = fig.add_subplot(111)

# Plot data.
location = os.path.join(base_path, "data.csv")
data = np.genfromtxt(location, delimiter=',', names=['year', 'bw'])
ax.plot(data['year'], data['bw'])

# Fit test.
x = data['year']
y = data['bw']
y_ln = np.log10(y)

n = data.shape[0]
A = np.array(([[x[j], 1] for j in range(n)]))
B = np.array(y_ln[0:n])
B = np.array(y[0:n])

X = np.linalg.lstsq(A, B)[0]
a = X[0]
b = X[1]

fit = a * x + b

p = np.polyfit(x, np.log(y), 1)
ax.semilogy(x, p[0] * x + p[1], 'g--')

ax.set_yscale('log')

The associated data.csv file looks as follows:

2016, 68.41987090116676
2017, 88.9788618486191
2018, 90.94850458504749
2019, 113.20946182004333
2020, 115.71547492850719

The figure I obtain looks as follows; the fitted line is crooked.

[Figure: semi-log plot with a fit that should be a straight line.]

Feedback and suggestions are very much appreciated.

Upvotes: 4

Views: 4384

Answers (1)

ImportanceOfBeingErnest

Reputation: 339170

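If you fit a line to the logarithm of the data, you need to invert that transformation when plotting the fit. That is, if you fit a line to np.log(y), you need to plot np.exp(fit_result). Otherwise the plotted values are logarithms of y, and the log axis compresses them a second time, which bends the line.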

# Fit test.
x = data['year']
y = data['bw']

p = np.polyfit(x, np.log(y), 1)
ax.semilogy(x, np.exp(p[0] * x + p[1]), 'g--')
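
To see why the inversion matters without looking at a plot, here is a minimal numerical sketch. It assumes the five data points from the question are typed in directly, and the variable names (`log_space_fit`, `data_space_fit`, `curvature_inverted`, `curvature_raw`) are made up for illustration. A curve is straight on a log axis only if the logarithm of its values is linear in x, so the second differences of that logarithm should be zero.

```python
import numpy as np

# Data from the question, typed in directly (assumption: no CSV file needed).
x = np.array([2016, 2017, 2018, 2019, 2020], dtype=float)
y = np.array([68.41987090116676, 88.9788618486191, 90.94850458504749,
              113.20946182004333, 115.71547492850719])

# Fit a straight line to ln(y).
slope, intercept = np.polyfit(x, np.log(y), 1)

log_space_fit = slope * x + intercept    # these are values of ln(y), not y
data_space_fit = np.exp(log_space_fit)   # back in the units of y

# A log axis effectively draws log(value) against x. For equally spaced x,
# the second differences of log(value) vanish exactly when the curve is straight.
curvature_inverted = np.diff(np.log(data_space_fit), n=2)
curvature_raw = np.diff(np.log(log_space_fit), n=2)

print("with np.exp:   ", curvature_inverted)
print("without np.exp:", curvature_raw)
```

The first array should be zero up to floating-point noise, while the second should be consistently negative, which is the bend seen in the question's figure.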

Complete example:

import io
import matplotlib.pyplot as plt
import numpy as np

u = u"""2016, 68.41987090116676
2017, 88.9788618486191
2018, 90.94850458504749
2019, 113.20946182004333
2020, 115.71547492850719"""

data = np.genfromtxt(io.StringIO(u), delimiter=',', names=['year', 'bw'])

fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(data['year'], data['bw'])

# Fit test.
x = data['year']
y = data['bw']

p = np.polyfit(x, np.log(y), 1)
ax.semilogy(x, np.exp(p[0] * x + p[1]), 'g--')

ax.set_yscale('log')

plt.show()
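
The question also attempts the fit with np.linalg.lstsq and np.log10. The same rule applies there: whichever logarithm is used for the fit has to be undone with the matching inverse. The following is a sketch of that variant, not necessarily what the asker intended; it assumes the data is typed in directly, and the names `A`, `B`, `a`, `b` and `fit` mirror the question's code.

```python
import numpy as np

# Data from the question, typed in directly (assumption: no CSV file needed).
x = np.array([2016, 2017, 2018, 2019, 2020], dtype=float)
y = np.array([68.41987090116676, 88.9788618486191, 90.94850458504749,
              113.20946182004333, 115.71547492850719])

# Design matrix with columns [x, 1] for the model log10(y) = a * x + b.
A = np.column_stack([x, np.ones_like(x)])
B = np.log10(y)  # fit against the base-10 logarithm of the data

# rcond=None selects NumPy's current default cutoff and avoids the
# FutureWarning that older NumPy versions emit when it is omitted.
a, b = np.linalg.lstsq(A, B, rcond=None)[0]

# Invert log10 with a power of 10 (np.exp would be the wrong inverse here).
fit = 10 ** (a * x + b)

print(fit)
```

Passing `fit` to `ax.semilogy(x, fit, 'g--')` should then give the same straight line as the np.polyfit version above, since both describe the same exponential model.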

Upvotes: 3
