Source code for xmipp3.viewers.viewer_consensus_classes3D

# **************************************************************************
# *
# * Authors:     J.M. De la Rosa Trevin (jmdelarosa@cnb.csic.es)
# *              Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************

"""
This module implement the wrappers aroung Xmipp CL2D protocol
visualization program.
"""
from pyworkflow.viewer import (ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO)
from pyworkflow.protocol.params import StringParam, LabelParam, IntParam
from pyworkflow.protocol.params import GE, GT, LE, LT

from pwem.viewers import TableView, ObjectView
from pwem.viewers.showj import *

from xmipp3.protocols.protocol_consensus_classes3D import XmippProtConsensusClasses3D

from scipy.cluster import hierarchy

import matplotlib.pyplot as plt
import numpy as np


[docs]class XmippConsensusClasses3DViewer(ProtocolViewer): """ Visualization of results from the consensus classes 3D protocol """ _label = 'viewer consensus classes 3D' _targets = [XmippProtConsensusClasses3D] _environments = [DESKTOP_TKINTER, WEB_DJANGO] def __init__(self, **kwargs): ProtocolViewer.__init__(self, **kwargs) def _defineParams(self, form): # Particles section form.addSection(label='Classes') form.addParam('displayClasses', IntParam, label='Custom number of classes', validators=[GE(1)], default=1, help='Open a GUI to visualize the class of the given iteration. Warning: It takes time to open') form.addParam('displayInitialClasses', LabelParam, label='Initial classes') form.addParam('displayManualClasses', LabelParam, label='Manual number of classes') form.addParam('displayOriginClasses', LabelParam, label='Number of classes by proximity to origin') form.addParam('displayAngleClasses', LabelParam, label='Number of classes by angle') form.addParam('displayPllClasses', LabelParam, label='Number of classes by profile likelihood') # Graph section form.addSection(label='Graphs') form.addParam('displayDendrogram', LabelParam, label='Dendrogram', help='Open a GUI to visualize the dendogram each number of clusters') form.addParam('displayDendrogramLog', LabelParam, label='Dendrogram (log)', help='Open a GUI to visualize the dendogram each number' ' of clusters with a logarithmic "y" axis') form.addParam('displayObjectiveFunction', LabelParam, label='Objective Function', help='Open a GUI to visualize the objective function for each number of clusters') form.addParam('displayObjectiveFunctionLog', LabelParam, label='Objective Function (log)', help='Open a GUI to visualize the objective function' ' for each number of clusters with a logarithmic "y" axis') # Reference classification section (only if exists) if self._readReferenceClassificationSizes() is not None: form.addSection(label='Reference classification consensus') form.addParam('displayReferenceClassificationSizePercentiles', LabelParam, label='Reference classification consensus size percentiles', help='Open a GUI to visualize a table with the most' ' common percentiles of the sizes of a random classification consensus') def _getVisualizeDict(self): return { 'displayClasses': self._visualizeClasses, 'displayInitialClasses': self._visualizeInitialClasses, 'displayManualClasses': self._visualizeManualClasses, 'displayOriginClasses': self._visualizeOriginClasses, 'displayAngleClasses': self._visualizeAngleClasses, 'displayPllClasses': self._visualizePllClasses, 'displayDendrogram': self._visualizeDendrogram, 'displayDendrogramLog': self._visualizeDendrogramLog, 'displayObjectiveFunction': self._visualizeObjectiveFunction, 'displayObjectiveFunctionLog': self._visualizeObjectiveFunctionLog, 'displayReferenceClassificationSizePercentiles': self._visualizeReferenceClassificationSizes, } # --------------------------- UTILS functions ------------------------------ def _getInputParticles(self): return self.protocol._getInputParticles() def _getInputClassifications(self): return self.protocol._getInputClassifications() def _readIntersections(self): return self.protocol._readIntersections() def _readClustering(self, iter): return self.protocol._readClustering(iter) def _readClusteringCount(self): return self.protocol._readClusteringCount() def _readAllClusterings(self): return self.protocol._readAllClusterings() def _readObjectiveValues(self): return self.protocol._readObjectiveValues() def _readElbows(self): return self.protocol._readElbows() def _readReferenceClassificationSizes(self): return self.protocol._readReferenceClassificationSizes() def _readReferenceClassificationRelativeSizes(self): return self.protocol._readReferenceClassificationRelativeSizes() def _visualizeClasses(self, e): # Read data count = self.displayClasses.get() particles = self._getInputParticles() clustering = self._readClustering(count) randomConsensusSizes = self._readReferenceClassificationSizes() randomConsensusRelativeSizes = self._readReferenceClassificationRelativeSizes() # Create the output classes and show them outputClasses = self.protocol._createOutputClasses3D( particles, clustering, 'viewer_'+str(count), randomConsensusSizes, randomConsensusRelativeSizes ) return self._showSetOfClasses3D(outputClasses) def _visualizeInitialClasses(self, e): return self._showSetOfClasses3D(self.protocol.outputClasses_initial) def _visualizeManualClasses(self, e): return self._showSetOfClasses3D(self.protocol.outputClasses_manual) def _visualizeOriginClasses(self, e): return self._showSetOfClasses3D(self.protocol.outputClasses_origin) def _visualizeAngleClasses(self, e): return self._showSetOfClasses3D(self.protocol.outputClasses_angle) def _visualizePllClasses(self, e): return self._showSetOfClasses3D(self.protocol.outputClasses_pll) def _visualizeDendrogram(self, e): clusterings = self._readAllClusterings() objectiveFunction = self._readObjectiveValues() fig, ax = plt.subplots() self._plotDendrogram(fig, ax, list(reversed(clusterings)), list(reversed(objectiveFunction))) return [fig] def _visualizeDendrogramLog(self, e): clusterings = self._readAllClusterings() objectiveFunction = self._readObjectiveValues() fig, ax = plt.subplots() self._plotDendrogram(fig, ax, list(reversed(clusterings)), list(reversed(objectiveFunction))) ax.set_yscale('log') return [fig] def _visualizeObjectiveFunction(self, e): objectiveFunction = self._readObjectiveValues() elbows = self._readElbows() fig, ax = plt.subplots() self._plotObjectiveFunctionAndElbows(fig, ax, objectiveFunction, elbows) return [fig] def _visualizeObjectiveFunctionLog(self, e): objectiveFunction = self._readObjectiveValues() elbows = self._readElbows() fig, ax = plt.subplots() self._plotObjectiveFunctionAndElbows(fig, ax, objectiveFunction[:-1], elbows) ax.set_yscale('log') return [fig] def _visualizeReferenceClassificationSizes(self, e): randomConsensusSizes = self._readReferenceClassificationSizes() randomConsensusRelativeSizes = self._readReferenceClassificationRelativeSizes() # Calculate the percentiles percentiles = [1, 5, 10, 25, 50, 75, 90, 95, 99] randomConsensusSizePercentiles = np.percentile(randomConsensusSizes, percentiles) randomConsensusRelativeSizePercentiles = np.percentile(randomConsensusRelativeSizes, percentiles) # Create a table with the data data = list(zip(percentiles, randomConsensusSizePercentiles, randomConsensusRelativeSizePercentiles)) header = ['Percentile (%)', 'Size percentile', 'Relative size percentile'] title = 'Reference classification consensus size percentiles' return [TableView(header, data, title=title)] def _showSetOfClasses3D(self, classes): labels = 'enabled id _size _representative._filename _xmipp_classIntersectionSizePValue _xmipp_classIntersectionRelativeSizePValue' labelRender = '_representative._filename' return [ObjectView( self._project, classes.strId(), classes.getFileName(), viewParams={ORDER: labels, VISIBLE: labels, RENDER: labelRender, SORT_BY: '_size desc', MODE: MODE_MD})] def _plotDendrogram(self, fig, ax, clusterings, obValues): """ Plot the dendrogram from the objective functions of the merge between the groups of images """ # Initialize required values linkageMatrix = np.zeros((len(clusterings)-1, 4)) clsIds = np.arange(len(clusterings)) nIntersections = len(clusterings[0]) # Loop over each iteration of clustering for i in range(len(clusterings)-1): # Find the sets that were merged clsMerged = [] assert(len(clusterings[i]) == len(clsIds)) for cls, clsId in zip(clusterings[i], clsIds): if cls not in clusterings[i+1]: clsMerged.append(clsId) assert(len(clsMerged) == 2) clsMerged = np.array(clsMerged) # Find original number of sets within new set nCls = list(map(lambda x : linkageMatrix[x,3] if (x >= 0) else 1, clsMerged - nIntersections)) # Create linkage matrix linkageMatrix[i, 0:2] = clsMerged # ids of merged sets linkageMatrix[i, 2] = obValues[i+1] # objective function as distance linkageMatrix[i, 3] = sum(nCls) # total number of original sets # Change set ids to reflect new set of clusters for id in clsMerged: clsIds = np.delete(clsIds, np.argwhere(clsIds == id)) clsIds = np.append(clsIds, len(clusterings)+i) # Plot resulting dendrogram hierarchy.dendrogram(linkageMatrix, ax=ax) ax.set_title('Dendrogram') ax.set_xlabel('sets ids') ax.set_ylabel('objective function') ax.set_ylim([np.min(linkageMatrix[:, 2]), np.max(linkageMatrix[:, 2])]) def _plotObjectiveFunctionAndElbows(self, fig, ax, obValues, elbows): """ Plots the objective function and elbows """ # Shorthands for some variables numClusters = np.arange(1, 1+len(obValues)) # Plot obValues vs numClusters ax.plot(numClusters, obValues) # Show elbows as scatter points for key, value in elbows.items(): value -= 1 # Value uses 1-based indexing label = key + ': ' + str(numClusters[value]) ax.scatter([numClusters[value]], [obValues[value]], label=label, color='green') # Configure the figure ax.legend() ax.set_xlabel('Number of clusters') ax.set_ylabel('Objective values') ax.set_title('Objective values for each number of clusters') ax.set_ylim([np.min(obValues), np.max(obValues)])