# **************************************************************************
# *
# * Authors:     Daniel Marchán Torres (da.marchan@cnb.csic.es)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from pyworkflow.viewer import Viewer, DESKTOP_TKINTER, WEB_DJANGO
from pyworkflow.protocol.params import LabelParam
from pwem.viewers import EmProtocolViewer, ObjectView, ClassesView
from xmipp3.protocols.protocol_cl2d_clustering import XmippProtCL2DClustering
import matplotlib.pyplot as plt
import os
[docs]class XmippCL2DClusteringViewer(EmProtocolViewer):
    """ This viewer is intended to visualize the selection made by the Xmipp - clustering 2d classes protocol.
    """
    _label = 'viewer Clustering 2D Classes'
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]
    _targets = [XmippProtCL2DClustering]
    def _defineParams(self, form):
        form.addSection(label='Visualization')
        form.addParam('visualizeOutput', LabelParam,
                      label="Visualize output",
                      help="Visualize the aggregated 2D classes, 2D averages or both.")
        form.addParam('visualizeCluster', LabelParam,
                      label="Visualize a 2D representation of the clustering",
                      help="This will show a 2D representation of the clustering operation. "
                              "The clustering operation is done on a multidimensional space, meaning that this"
                              "2D representation (using TSNE) might not capture the real difference between some clusters.")
        form.addParam('visualizeClusterImages', LabelParam,
                      label="Visualize the clusters distribution",
                      help="Visualize the clusters images.")
    def _getVisualizeDict(self):
        return {
                 'visualizeOutput': self._visualizeOutputs,
                 'visualizeCluster': self._visualizeCluster,
                 'visualizeClusterImages': self._visualizeClusterImages
                }
    def _visualizeOutputs(self, e=None):
        outputList = []
        for objName in ["outputClasses", "outputAverages"]:
            if self.protocol.hasAttribute(objName):
                outputList.append(objName)
        return self._visualizeMultipleOutputs(outputList)
    def _visualizeCluster(self, e=None):
        if os.path.exists(self.protocol.getClusterPlot()):# Load the image
            image = plt.imread(self.protocol.getClusterPlot())
            # Get the image dimensions (height, width)
            height, width, _ = image.shape
            # Convert pixels to inches for the figure size (assuming 100 DPI)
            dpi = 100
            figsize = (width / dpi, height / dpi)
            # Create the figure with the calculated size
            plt.figure(figsize=figsize)
            # Display the image without axes
            fig = plt.imshow(image)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            # Show the image
            plt.show()
    def _visualizeClusterImages(self, e=None):
        if os.path.exists(self.protocol.getClusterImagesPlot()):
            # Load the image
            image = plt.imread(self.protocol.getClusterImagesPlot())
            # Get the image dimensions (height, width)
            height, width, _ = image.shape
            # Convert pixels to inches for the figure size (assuming 100 DPI)
            dpi = 100
            figsize = (width / dpi, height / dpi)
            # Create the figure with the calculated size
            plt.figure(figsize=figsize)
            # Display the image without axes
            fig = plt.imshow(image)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            # Show the image
            plt.show()
    def _visualizeMultipleOutputs(self, objList):
        views = []
        classView = "outputClasses"
        if objList:
            for objName in objList:
                if self.protocol.hasAttribute(objName):
                    outputSet = getattr(self.protocol, objName)
                    outputId = outputSet.strId()
                    outputFn = outputSet.getFileName()
                    if objName == classView:
                        views.append(ClassesView(self._project, outputId, outputFn))
                    else:
                        views.append(ObjectView(self._project, outputId, outputFn))
        else:
            self.infoMessage('%s does not have output %s'
                             % (self.protocol.getObjLabel(),
                                getStringIfActive(self.protocol)),
                             title='Info message').show()
        return views 
[docs]def getStringIfActive(prot):
    return 'yet.' if prot.isActive() else '.'