Amanda.py
Amanda.py

Reputation: 113

Plotting a 2-D xarray in a PyQt window

I am trying to use the FigureCanvas class to embed a Matplotlib plot in a PyQt window. Since the data is already stored in an xarray DataArray, I want to use xarray's DataArray.plot() method to generate the plot and add it to the figure. However, when I do this, a histogram is plotted instead of the pcolormesh I expect from the documentation. Here is the class:


class PlotCanvas(FigureCanvas):
    """Matplotlib canvas widget that renders an xarray DataArray.

    ``DataArray.plot`` picks the plot type from the array's dimensionality
    (1-D: line, 2-D: pcolormesh, anything else: histogram), so pass a 2-D
    cut of the data to get an image plot.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure and attach it to this canvas.

        Args:
            parent: Optional Qt parent widget.
            width: Figure width in inches.
            height: Figure height in inches.
            dpi: Figure resolution in dots per inch.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super(PlotCanvas, self).__init__(self.fig)
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        self.data = xr.DataArray()
        # Created by plot(); None until the first plot.
        self.axes = None

    def update_xyt(self, x, y, t):
        """Redraw the current data with a custom x label, y label and title.

        The labels are applied *after* redrawing because ``DataArray.plot``
        sets its own labels and title from the array's coordinates, which
        would otherwise overwrite these.
        """
        self.x_label = x
        self.y_label = y
        self.title = t
        self.plot(self.data)
        self.axes.set_title(self.title)
        self.axes.set_xlabel(self.x_label)
        self.axes.set_ylabel(self.y_label)
        self.draw()

    def plot(self, data):
        """Replace the figure contents with a plot of *data* and redraw.

        The whole figure is cleared first so repeated calls do not pile new
        axes (and, for 2-D data, new colorbars) on top of the old ones.
        """
        self.data = data
        self.fig.clear()
        self.axes = self.fig.add_subplot(111)
        self.data.plot(ax=self.axes)
        self.draw()

I override the plot() method so that I can pass in a DataArray and plot it onto the figure's axes. Here is how the DataArray is created and plotted:

        # Build a synthetic 51 x 51 x 51 volume sampled on [-1, 1] along each axis.
        x = np.linspace(-1, 1, 51)
        y = np.linspace(-1, 1, 51)
        z = np.linspace(-1, 1, 51)
        xyz = np.meshgrid(x, y, z, indexing='ij')
        d = np.sin(np.pi * np.exp(-1 * (xyz[0]**2 + xyz[1]**2 + xyz[2]**2))) * np.cos(np.pi / 2 * xyz[1])
        obj.xar = xr.DataArray(d, coords={"slit": x, 'perp': y, "energy": z}, dims=["slit", "perp", "energy"])
        # Selecting a single "perp" value (nearest to 0) drops that dimension,
        # leaving a 2-D (slit, energy) array to plot.
        obj.cut = obj.xar.sel({"perp": 0}, method='nearest')
        # NOTE(review): obj.data is not defined in this snippet; confirm whether
        # obj.xar was intended here.
        obj.fm_pyqtplot = PGImageTool(obj.data, layout=1)
        obj.fm_pyplot = PlotCanvas()
        obj.fm_pyplot.plot(obj.cut)

Upvotes: 0

Views: 187

Answers (1)

Amanda.py
Amanda.py

Reputation: 113

While I was trying to create a runnable version of this issue with its own window, I found that it actually works. My issue was that I was resetting the canvas when the widget was initialized in my larger application. As a result, when plot() was called it was passed the entire 3-D DataArray instead of just a 2-D cut of it. That is why I got a histogram instead of a pcolormesh. DataArray.plot() picks the plot type from the number of dimensions: 1-D gives a line plot, 2-D gives a pcolormesh, and anything else gives a histogram. I figured I would finish this post so that I could share the solution.

import sys

from PyQt5.Qt import Qt
from PyQt5.QtWidgets import QSizePolicy
from PyQt5.QtWidgets import QApplication, QMainWindow, QHBoxLayout
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import xarray as xr
import numpy as np


class PlotCanvas(FigureCanvas):
    """Matplotlib canvas widget that renders an xarray DataArray.

    ``DataArray.plot`` picks the plot type from the array's dimensionality
    (1-D: line, 2-D: pcolormesh, anything else: histogram), so pass a 2-D
    cut of the data to get an image plot.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure and attach it to this canvas.

        Args:
            parent: Optional Qt parent widget.
            width: Figure width in inches.
            height: Figure height in inches.
            dpi: Figure resolution in dots per inch.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super(PlotCanvas, self).__init__(self.fig)
        self.setParent(parent)
        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        self.data = xr.DataArray()
        # Created by plot(); None until the first plot.
        self.axes = None

    def update_xyt(self, x, y, t):
        """Redraw the current data with a custom x label, y label and title.

        The labels are applied *after* redrawing because ``DataArray.plot``
        sets its own labels and title from the array's coordinates, which
        would otherwise overwrite these.
        """
        self.x_label = x
        self.y_label = y
        self.title = t
        self.plot(self.data)
        self.axes.set_title(self.title)
        self.axes.set_xlabel(self.x_label)
        self.axes.set_ylabel(self.y_label)
        self.draw()

    def plot(self, data):
        """Replace the figure contents with a plot of *data* and redraw.

        The whole figure is cleared first so repeated calls do not pile new
        axes (and, for 2-D data, new colorbars) on top of the old ones.
        """
        self.data = data
        self.fig.clear()
        self.axes = self.fig.add_subplot(111)
        self.data.plot(ax=self.axes)
        self.draw()


class MainWindow(QMainWindow):
    """Demo window that plots a 2-D cut of a synthetic 3-D DataArray."""

    def __init__(self):
        super(MainWindow, self).__init__()
        # Synthetic 51 x 51 x 51 volume sampled on [-1, 1] along each axis.
        x = np.linspace(-1, 1, 51)
        y = np.linspace(-1, 1, 51)
        z = np.linspace(-1, 1, 51)
        xyz = np.meshgrid(x, y, z, indexing='ij')
        d = (np.sin(np.pi * np.exp(-1 * (xyz[0] ** 2 + xyz[1] ** 2 + xyz[2] ** 2)))
             * np.cos(np.pi / 2 * xyz[1]))
        self.xar = xr.DataArray(d, coords={"slit": x, 'perp': y, "energy": z},
                                dims=["slit", "perp", "energy"])
        # Selecting a single "perp" value drops that dimension, leaving a 2-D
        # (slit, energy) array, so DataArray.plot draws a pcolormesh rather
        # than a histogram.
        self.cut = self.xar.sel({"perp": 0}, method='nearest')
        self.fm_pyplot = PlotCanvas()
        self.fm_pyplot.plot(self.cut)
        # QMainWindow manages its own layout; content goes in as the central
        # widget rather than through self.layout().addWidget().
        self.setCentralWidget(self.fm_pyplot)


class App(QApplication):
    """Application object that creates and shows the main window."""

    def __init__(self, sys_argv):
        # High-DPI scaling must be enabled before the QApplication instance is
        # constructed; setting it afterwards has no effect.
        QApplication.setAttribute(Qt.AA_EnableHighDpiScaling)
        super(App, self).__init__(sys_argv)
        self.mainWindow = MainWindow()
        self.mainWindow.setWindowTitle("Main Window")
        self.mainWindow.show()

def main():
    """Start the Qt event loop and exit the process with its return code."""
    application = App(sys.argv)
    exit_code = application.exec_()
    sys.exit(exit_code)


# Launch the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()

Upvotes: 1

Related Questions