Péter Leéh

Reputation: 2119

Is there a way to write a math formula on a matplotlib plot dynamically?

I'm writing a template for my lab work in Python. In short, it plots data points and fits a predefined model with scipy's curve_fit. I usually fit polynomials or exponential curves. I managed to print the fitted parameters dynamically on the plot, but I still have to type in the matching equation by hand every time. Is there an elegant way to do this dynamically? I've read about sympy, but so far I haven't managed to make it work.

Here's the code:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from datetime import datetime

#two example functions
def f(x, p0, p1):
    return p0 * x + p1

def g(x, p0, p1):
    return p0 * np.exp(x * p1)

#example data
xval = np.array([0,1,2,3,4,5,6])
yval = np.array([0, 2,3.95,5.8,8.1, 10.2, 12.4])

#curve fitting
popt, pcov = curve_fit(f, xval, yval)

plt.rcParams.update({'font.size': 12})
plt.figure(figsize=(9,7))
plt.plot(xval, yval,'ko', label = 'Data points', markersize = 7)
plt.title('TITLE', fontsize = 15)
plt.grid()
plt.plot(xval, f(xval, *popt),'r-', label = 'Fit')
#printing the params on plot
for idx in range(len(popt)):
    plt.text(0.8,0.05+0.05*(idx+1), 'p'+str(idx)+' = {0:.5f}'.format(popt[idx]), transform=plt.gca().transAxes)

#manually writing the equation, that's what I want to print dynamically
#(raw string, so the backslash in \cdot reaches matplotlib's mathtext unchanged)
plt.text(0.8, 0.05, r'$y = p0 \cdot x + p1 $', transform=plt.gca().transAxes)

plt.text(0.86, 1.01, datetime.today().strftime('%Y.%m.%d.'), transform=plt.gca().transAxes)
plt.text(0 ,1.01, 'NAME', transform=plt.gca().transAxes)
plt.ylabel('Y axis title')
plt.xlabel('X axis title')
plt.legend()
plt.show()
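The parameter loop above places one `plt.text` call per parameter at hand-computed vertical offsets. A minimal sketch of an alternative, assuming a single multi-line label is acceptable: build the whole text with `enumerate` and pass it to one `plt.text` call. `format_params` is a hypothetical helper name, and the numbers below are placeholders, not fit results.

```python
def format_params(popt, digits=5):
    """Return one 'pN = value' line per fitted parameter, joined by newlines."""
    return "\n".join(f"p{idx} = {value:.{digits}f}" for idx, value in enumerate(popt))

# Placeholder values standing in for the popt array returned by curve_fit
label = format_params([2.05, -0.1])
print(label)
# p0 = 2.05000
# p1 = -0.10000
```

In the plot this would become something like `plt.text(0.8, 0.1, label, transform=plt.gca().transAxes)`; matplotlib renders the newlines as separate lines, so the offsets no longer need to be computed per parameter.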

The expected result is:

If I fit with a different function, say g(x, p0, p1), which returns p0 * np.exp(x * p1), then its formula should be printed on the plot automatically, just like the one in the example code:

plt.text(0.8, 0.05, r'$y = p0 \cdot x + p1 $', transform=plt.gca().transAxes)

but without my having to type it in by hand.
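If the fitted numbers should appear inside the equation instead of the symbols p0 and p1, one hedged option is a `str.format` template per model, filled from `popt`. This still means typing the template once per model; it only makes the values dynamic. Doubled braces `{{ }}` produce the literal braces LaTeX needs for the exponent. `G_TEMPLATE` and the values are illustrative assumptions.

```python
# Raw string, so \cdot is passed to matplotlib's mathtext unchanged.
# {{ and }} become literal braces; {0:.3f} and {1:.3f} are filled from popt.
G_TEMPLATE = r'$y = {0:.3f} \cdot e^{{{1:.3f} \cdot x}}$'

popt = [1.5, 0.25]  # placeholder values, not a real fit
label = G_TEMPLATE.format(*popt)
print(label)  # $y = 1.500 \cdot e^{0.250 \cdot x}$
```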

I really appreciate any suggestions.

Upvotes: 1

Views: 1850

Answers (2)

Péter Leéh

Reputation: 2119

I actually managed to find a solution (without sympy, though). I still have to type in the formulas manually, but they are selected automatically: I use a dictionary that maps each function's name to its LaTeX string.
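A variation on the dictionary idea, sketched under the assumption that each model function is defined once in the script: store the LaTeX string on the function object itself with a small decorator, so the formula cannot drift out of sync with a separate name-keyed lookup. `with_formula` is a hypothetical helper, not part of matplotlib or SciPy. The decorator returns the original function unchanged, so `curve_fit` still sees the same `(x, p0, p1)` signature.

```python
import numpy as np

def with_formula(tex):
    """Hypothetical decorator: attach a LaTeX label to a model function as .formula."""
    def decorate(func):
        func.formula = tex
        return func  # the function itself is not wrapped or modified otherwise
    return decorate

@with_formula(r'$y = p0 \cdot x + p1$')
def f(x, p0, p1):
    return p0 * x + p1

@with_formula(r'$y = p0 \cdot e^{p1 \cdot x}$')
def g(x, p0, p1):
    return p0 * np.exp(x * p1)

print(g.formula)  # $y = p0 \cdot e^{p1 \cdot x}$
```

The label call then becomes `plt.text(0.7, 0.05, func.formula, transform=plt.gca().transAxes)` for whichever function was fitted.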

Here's the code:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from datetime import datetime

fun_dict = {}

#three example functions
def f(x, p0, p1):
    return p0 * x + p1

def g(x, p0, p1):
    return p0 * np.exp(x * p1)

def h(x, p0, p1, p2):
    return p0 * x ** 2 + p1 * x + p2

#raw strings, so the backslashes reach matplotlib's mathtext unchanged
f_string = r'$y = p0 \cdot x + p1 $'
fun_dict['f'] = f_string
g_string = r'$y = p0 \cdot e^{p1 \cdot x} $'
fun_dict['g'] = g_string
h_string = r'$y = p0 \cdot x^2 + p1 \cdot x + p2$'
fun_dict['h'] = h_string
#example data
xval = np.array([0,1,2,3,4,5,6])
yval = np.array([0, 2,3.95,5.8,8.1, 10.2, 12.4])


def get_fun(func):
    popt, _ = curve_fit(func, xval, yval)
    return popt, fun_dict[str(func.__name__)], func

popt, str_name, func = get_fun(h)



plt.rcParams.update({'font.size': 12})
plt.figure(figsize=(9,7))
plt.plot(xval, yval,'ko', label = 'Data points', markersize = 7)
plt.title('TITLE', fontsize = 15)
plt.grid()
plt.plot(xval, func(xval, *popt),'r-', label = 'Fit')
for idx in range(len(popt)):
    plt.text(0.8,0.05+0.05*(idx+1), 'p'+str(idx)+' = {0:.5f}'.format(popt[idx]), transform=plt.gca().transAxes)

plt.text(0.7, 0.05, str_name, transform=plt.gca().transAxes)

plt.text(0.86, 1.01, datetime.today().strftime('%Y.%m.%d.'), transform=plt.gca().transAxes)
plt.text(0 ,1.01, 'NAME', transform=plt.gca().transAxes)
plt.ylabel('Y axis title')
plt.xlabel('X axis title')
plt.legend()
plt.show()

Upvotes: 0

Grzegorz Bokota

Reputation: 1804

I think you could use the sympy package. It lets you define symbolic variables, build expressions from them, and then evaluate those expressions numerically. I'm not sure what the impact on performance is.
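The point of this approach is that one sympy expression can serve both as the model and as the plot label. `str(expr)` gives Python-style text such as `p0*x + p1`, while `sympy.latex(expr)` produces LaTeX (for example, `p0` becomes `p_{0}`) that matplotlib's mathtext can render. A minimal sketch of the label side only:

```python
import sympy

x, p0, p1 = sympy.symbols("x p0 p1")
g = p0 * sympy.exp(p1 * x)

print(str(g))                      # Python-style text, e.g. for debugging
label = f"$y = {sympy.latex(g)}$"  # LaTeX wrapped in $...$ for matplotlib's mathtext
print(label)
```

`label` can then be passed straight to `plt.text`.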

Here is your code with changes:

import numpy as np
import sympy
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from datetime import datetime


#two example functions
x, p0, p1 = sympy.var("x p0 p1")

f = p0 * x + p1 

g = p0 * sympy.exp(x*p1)

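def partial_fun(sympy_expr):
    #wrap the sympy expression in a numeric function with the (x, p0, p1) signature curve_fit expects
    def res_fun(X, P0, P1):
        #np.float was removed in NumPy 1.24; the builtin float is the replacement
        return np.array([sympy_expr.evalf(subs={x: x_, p0: P0, p1: P1}) for x_ in X], dtype=float)

    return res_fun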

#example data
xval = np.array([0,1,2,3,4,5,6])
yval = np.array([0, 2,3.95,5.8,8.1, 10.2, 12.4])

#curve fitting
popt, pcov = curve_fit(partial_fun(f), xval, yval)

plt.rcParams.update({'font.size': 12})
plt.figure(figsize=(9,7))
plt.plot(xval, yval,'ko', label = 'Data points', markersize = 7)
plt.title('TITLE', fontsize = 15)
plt.grid()
plt.plot(xval, partial_fun(f)(xval, *popt),'r-', label = 'Fit')
#printing the params on plot
for idx in range(len(popt)):
    plt.text(0.8,0.05+0.05*(idx+1), 'p'+str(idx)+' = {0:.5f}'.format(popt[idx]), transform=plt.gca().transAxes)

#the equation is now generated from the sympy expression itself; sympy.latex gives mathtext-friendly output
plt.text(0.8, 0.05, f'$y = {sympy.latex(f)}$', transform=plt.gca().transAxes)

plt.text(0.86, 1.01, datetime.today().strftime('%Y.%m.%d.'), transform=plt.gca().transAxes)
plt.text(0 ,1.01, 'NAME', transform=plt.gca().transAxes)
plt.ylabel('Y axis title')
plt.xlabel('X axis title')
plt.legend()
plt.show()

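On the performance question: `evalf` with `subs` runs sympy's symbolic evaluation once per data point, which gets slow for large arrays. `sympy.lambdify` turns the expression into a NumPy function once. Because lambdify names the generated function's arguments after the symbols, the result has a signature like `(x, p0, p1, p2)`, which is what `curve_fit` inspects to count the parameters when no initial guess is passed (note that curve_fit's initial-guess keyword is also called `p0`, unrelated to the symbol). A sketch, assuming parameters are passed in the listed order; `make_model` is a hypothetical helper:

```python
import inspect

import numpy as np
import sympy

def make_model(expr, x, params):
    """Hypothetical helper: turn a sympy expression into a vectorized f(x, *params)."""
    return sympy.lambdify([x, *params], expr, modules="numpy")

x, p0, p1, p2 = sympy.symbols("x p0 p1 p2")
h = p0 * x**2 + p1 * x + p2  # any number of parameters works, unlike the fixed (X, P0, P1)

h_num = make_model(h, x, [p0, p1, p2])
print(list(inspect.signature(h_num).parameters))        # ['x', 'p0', 'p1', 'p2']
print(h_num(np.array([0.0, 1.0, 2.0]), 1.0, 2.0, 3.0))  # [ 3.  6. 11.]
```

With this, the fit would read `popt, pcov = curve_fit(h_num, xval, yval)`, and the same `h` still provides the label through `sympy.latex(h)`.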
Upvotes: 1
