# **************************************************************************
# *
# * Authors: J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se)
# *
# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from os.path import join
import tkinter as tk
from pwem.protocols import ProtUserSubSet
from pwem.objects import SetOfClasses2D
from pwem.viewers import EmPlotter, ClassesView
from pyworkflow.protocol.params import IntParam, FloatParam, LabelParam
from pyworkflow.protocol.constants import STATUS_FINISHED
from pyworkflow.utils.properties import Icon
from pyworkflow.viewer import ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO
from pyworkflow.utils.graph import Graph
from pyworkflow.utils.path import cleanPath
from pyworkflow.gui import Window
from pyworkflow.gui.widgets import HotButton
from pyworkflow.gui.graph import LevelTree
from pyworkflow.gui.canvas import Canvas, Item
from pyworkflow.gui.dialog import askString
from ..utils import SpiderDocFile
from ..protocols import SpiderProtClassifyWard, SpiderProtClassifyDiday
class SpiderViewerClassify(ProtocolViewer):
    """ Common viewer for Spider classification protocols.

    Provides the dendrogram plot and the form entries shared by the Ward
    and Diday viewers; subclasses supply ``visualizeClasses``.
    """
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]

    def _defineParams(self, form):
        form.addSection(label='Visualization')

        dendrogramGroup = form.addGroup('Dendrogram')
        dendrogramGroup.addParam(
            'doShowDendrogram', LabelParam,
            label="Show dendrogram", default=True,
            help='In a dendrogram larger vertical bars signify a '
                 'greater difference between classes. Many small '
                 'differences at the bottom can be eliminated with '
                 'an increase of the "Minimum height" setting.')
        dendrogramGroup.addParam(
            'minHeight', FloatParam, default=0.5,
            label='Minimum height',
            help='The dendrogram will be cut at this level')

        # Kept on the instance so subclasses can append their own
        # class-visualization parameters to the same group.
        self.groupClass = form.addGroup('Classes')
        self.groupClass.addParam(
            'doShowClasses', LabelParam,
            label="Visualize class averages", default=True,
            help='Display class averages')

    def _getVisualizeDict(self):
        """ Map each form label parameter to the method that renders it. """
        visualizers = {
            'doShowDendrogram': self._plotDendrogram,
            'doShowClasses': self.visualizeClasses,
        }
        return visualizers

    def _plotDendrogram(self, e=None):
        """ Plot the protocol's dendrogram, cut at the ``minHeight`` level.

        Returns a one-element list with the plotter.
        """
        plotter = EmPlotter()
        self.plt = plotter.createSubPlot("Dendrogram", "", "")
        self.step = 0.25
        # x position of the last leaf drawn; leaves are laid out
        # left-to-right along the bottom of the plot.
        self.rightMost = 0.0
        self.plotNode(self.protocol.buildDendrogram(), self.minHeight.get())
        self.plt.set_xlim(0., self.rightMost + self.step)
        self.plt.set_ylim(-10, 105)
        return [plotter]

    def plotNode(self, node, minHeight=-1):
        """ Recursively draw *node* and its subtree on ``self.plt``.

        A node is expanded only when its height is above *minHeight* and
        it has more than one child; otherwise it is drawn as a leaf.
        Returns the (x, y) point where *node* was drawn.
        """
        children = node.getChilds()
        height = node.height

        if not (height > minHeight and len(children) > 1):
            # Leaf (or subtree cut below minHeight): next free bottom slot.
            self.rightMost += self.step
            point = (self.rightMost, 0.)
        else:
            # Left child first: the leaf layout depends on this order.
            leftX, leftY = self.plotNode(children[0], minHeight)
            rightX, rightY = self.plotNode(children[1], minHeight)
            self.plt.plot([leftX, leftX, rightX, rightX],
                          [leftY, height, height, rightY], color='b')
            point = ((leftX + rightX) / 2, height)

        caption = "%d(%d)" % (node.index, node.length)
        self.plt.annotate(caption, point, xytext=(0, -5),
                          textcoords='offset points', va='top', ha='center',
                          size='x-small')
        self.plt.plot(point[0], point[1], 'ro')
        return point

    def _createNode(self, canvas, node, y):
        """ Build the canvas item for a tree node (used by LevelTree). """
        node.selected = False
        box = SpiderImageBox(canvas, node, y)
        node.box = box
        return box
class SpiderViewerWard(SpiderViewerClassify):
    """ Visualization of Spider - classify Ward protocol results. """
    _targets = [SpiderProtClassifyWard]
    _label = "viewer ward"

    def _defineParams(self, form):
        SpiderViewerClassify._defineParams(self, form)
        self.groupClass.addParam('maxLevel', IntParam, default=4,
                                 label='Maximum level',
                                 help='Maximum level of classes to show')

    def visualizeClasses(self, e=None):
        """ Open a window that shows the dendrogram as a tree of class boxes.

        The window has buttons to create a subset of particles or of
        classes from the currently selected nodes.
        Returns a one-element list with the window.
        """
        node = self.protocol.buildDendrogram(writeAverages=False)
        self.graph = Graph(root=node)

        self.win = Window("Select classes", self.formWindow,
                          minsize=(1000, 600))
        root = self.win.root
        canvas = Canvas(root)
        canvas.grid(row=0, column=0, sticky='nsew')
        root.grid_columnconfigure(0, weight=1)
        root.grid_rowconfigure(0, weight=1)

        self.buttonframe = tk.Frame(root)
        self.buttonframe.grid(row=2, column=0, columnspan=2)
        self.win.createCloseButton(self.buttonframe).grid(row=0, column=0,
                                                          sticky='n',
                                                          padx=5, pady=5)
        saveparticlesbtn = HotButton(self.buttonframe, "Particles",
                                     Icon.PLUS_CIRCLE,
                                     command=self._askCreateParticles)
        saveparticlesbtn.grid(row=0, column=1, sticky='n', padx=5, pady=5)
        btn = HotButton(self.buttonframe, "Classes", Icon.PLUS_CIRCLE,
                        command=self._askCreateClasses)
        btn.grid(row=0, column=2, sticky='n', padx=5, pady=5)

        lt = LevelTree(self.graph)
        lt.DY = 135  # TODO: change in percent of the image size
        lt.setCanvas(canvas)
        lt.paint(self._createNode, maxLevel=self.maxLevel.get() - 1)
        canvas.updateScrollRegion()
        return [self.win]

    def _askCreateParticles(self):
        """ Ask to create a subset with the particles of the selected nodes. """
        self._askCreateSubset('Particles', self.getSelectedNodesCount(2))

    def _askCreateClasses(self):
        """ Ask to create a subset with the selected nodes as classes. """
        self._askCreateSubset('Classes', self.getSelectedNodesCount(1))

    def _askCreateSubset(self, output, size):
        """ Confirm with the user and create a subset of type *output*.

        :param output: 'Particles' or 'Classes'; dispatches to the
            matching ``save<output>`` method.
        :param size: number of elements in the subset (shown in the prompt).
        """
        if self._selectionOverlap():
            self.win.showError("Classes could not overlap in the tree.")
            return
        s = '' if size == 1 else 's'
        headerLabel = ('Are you sure you want to create a new set of '
                       '%s with %s element%s?' % (output, size, s))
        runname = askString('Question', 'Run name:', self.win.getRoot(), 30,
                            defaultValue='ProtUserSubSet',
                            headerLabel=headerLabel)
        if runname:
            createFunc = getattr(self, 'save' + output)
            createFunc(runname)

    def _createSubsetProtocol(self, createOutputFunc, label=None):
        """ Create a subset of classes or particles.

        A ProtUserSubSet protocol is created with ``self.protocol`` as
        input, *createOutputFunc(prot)* fills its outputs, and the
        protocol is stored as finished. Any error is shown in the window.
        """
        try:
            project = self.getProject()
            prot = project.newProtocol(ProtUserSubSet)
            prot.setObjLabel(label)
            prot.inputObject.set(self.protocol)
            project._setupProtocol(prot)
            prot.makePathsAndClean()
            createOutputFunc(prot)
            # Outputs are produced here, in-process, so the protocol is
            # stored directly as finished instead of being launched.
            prot.setStatus(STATUS_FINISHED)
            project._storeProtocol(prot)
        except Exception as ex:
            # UI boundary: report the failure instead of crashing the viewer.
            self.win.showError(str(ex))

    def getSelectedNodesCount(self, depth):
        """ Count the current selection.

        :param depth: 1 to count the selected nodes (classes); any other
            value to count the images contained in the selected nodes.
        """
        selected = self._selectedNodes()
        if depth == 1:
            return len(selected)
        return sum(len(node.imageList) for node in selected)

    def _selectionOverlap(self):
        """ Check if selected classes do overlap (share any image). """
        allImages = set()
        for n in self._selectedNodes():
            for i in n.imageList:
                if i in allImages:
                    return True
                allImages.add(i)
        return False

    def _selectedNodes(self):
        """ Return the graph nodes whose ``selected`` flag is set. """
        return [node for node in self.graph.getNodes() if node.selected]

    def saveClasses(self, runname=None):
        """ Store selected classes as a new SetOfClasses2D subset. """
        def createClasses(prot):
            classes = prot._createSetOfClasses2D(
                self.protocol.inputParticles.get(), suffix='Selection')
            self.protocol._fillClassesFromNodes(classes, self._selectedNodes())
            prot._defineOutputs(outputClasses=classes)

        self._createSubsetProtocol(createClasses, runname)

    def saveParticles(self, runname=None):
        """ Store particles from selected classes as a new particle subset. """
        def createParticles(prot):
            inputParticles = self.protocol.inputParticles.get()
            particles = prot._createSetOfParticles(suffix='Selection')
            particles.copyInfo(inputParticles)
            self.protocol._fillParticlesFromNodes(inputParticles,
                                                  particles,
                                                  self._selectedNodes())
            prot._defineOutputs(outputParticles=particles)

        self._createSubsetProtocol(createParticles, runname)
class ImageBox(Item):
    """ Canvas item that shows an image, optionally with a text caption.

    Copied from pw/gui/canvas.py (original note: still depends on
    xmippLib). When *imgPath* is None a placeholder image is used.
    """
    def __init__(self, canvas, imgPath, x=0, y=0, text=None):
        Item.__init__(self, canvas, x, y)
        # Imported here rather than at module level, as in the original.
        from pyworkflow.gui import getImage
        from pwem.viewers.filehandlers import getImageFromPath

        self.image = (getImage('no-image.gif') if imgPath is None
                      else getImageFromPath(imgPath))

        if text is None:
            # Plain image, no caption and no click handling.
            self.id = self.canvas.create_image(x, y, image=self.image)
            return

        # Captioned image: a Label embedded in the canvas, clickable.
        self.label = tk.Label(canvas, image=self.image, text=text,
                              compound=tk.TOP, bg='gray')
        self.id = self.canvas.create_window(x, y, window=self.label)
        self.label.bind('<Button-1>', self._onClick)

    def setSelected(self, value):
        """ Selection highlight is intentionally not shown. """
        pass

    def _onClick(self, e=None):
        """ Click hook; no-op here, overridden by subclasses. """
        pass
class SpiderImageBox(ImageBox):
    """ Clickable image box representing one dendrogram node (class).

    Clicking toggles ``node.selected`` and updates the box highlight.
    """
    def __init__(self, canvas, node, y):
        ImageBox.__init__(self, canvas, node.path, text=node.getName(), y=y)
        # _onClick needs the node; store it here instead of relying on
        # the caller to attach it after construction.
        self.node = node

    def _onClick(self, e=None):
        """ Toggle selection of the node; nodes without an image are ignored. """
        if self.node.path is None:
            return
        # On click change the selection state
        self.node.selected = not self.node.selected
        if self.node.selected:
            self.label.config(bd=2, bg='green')
        else:
            self.label.config(bd=0, bg='grey')
class SpiderViewerDiday(SpiderViewerClassify):
    """ Visualization of Spider - classify Diday protocol results. """
    _targets = [SpiderProtClassifyDiday]
    _label = "viewer diday"

    def _defineParams(self, form):
        SpiderViewerClassify._defineParams(self, form)
        self.groupClass.addParam('numberOfClasses', IntParam, default=4,
                                 label='Number of classes',
                                 help='Desired number of classes.')

    def visualizeClasses(self, e=None):
        """ Compute and display class averages for the chosen class count.

        Runs the ``mda/classavg.msa`` Spider template, reads the per-class
        doc files it writes, and builds a temporary SetOfClasses2D
        (``classes2D.sqlite`` in the protocol tmp dir) that is shown in a
        ClassesView.
        """
        prot = self.protocol
        numberOfClasses = self.numberOfClasses.get()
        classDir = prot.getClassDir()
        classAvg = 'classavg'
        classVar = 'classvar'
        classDoc = 'docclass'
        params = {'[class_dir]': classDir,
                  '[desired-classes]': numberOfClasses,
                  '[particles]': prot._params['particles'] + '@******',
                  '[class_doc]': join(classDir, classDoc + '***'),
                  '[class_avg]': join(classDir, classAvg + '***'),
                  '[class_var]': join(classDir, classVar + '***'),
                  }
        prot.runTemplate('mda/classavg.msa', prot.getExt(), params)

        particles = prot.inputParticles.get()
        particles.load()
        sampling = particles.getSamplingRate()

        setFn = prot._getTmpPath('classes2D.sqlite')
        cleanPath(setFn)
        classes2D = SetOfClasses2D(filename=setFn)
        classes2D.setImages(particles)

        # Map each particle index (first doc column) to its class number.
        # NOTE(review): the '.stk' extension is hard-coded here while the
        # template is run with prot.getExt(); confirm they always match.
        classDict = {}
        for classId in range(1, numberOfClasses + 1):
            docClass = prot._getPath(classDir, classDoc + '%03d.stk' % classId)
            doc = SpiderDocFile(docClass)
            try:
                for values in doc.iterValues():
                    classDict[int(values[0])] = classId
            finally:
                doc.close()

        def updateItem(particle, index):
            particle.setClassId(classDict[index])

        def updateClass(cls):
            rep = cls.getRepresentative()
            rep.setSamplingRate(sampling)
            avgFn = prot._getPath(classDir,
                                  classAvg + '%03d.stk' % cls.getObjId())
            rep.setLocation(1, avgFn)

        # Particles are paired, in iteration order, with indexes 1..N.
        particlesRange = range(1, particles.getSize() + 1)
        classes2D.classifyItems(updateItemCallback=updateItem,
                                updateClassCallback=updateClass,
                                itemDataIterator=iter(particlesRange))
        classes2D.write()
        classes2D.close()

        return [ClassesView(self.getProject(), prot.strId(),
                            classes2D.getFileName(), particles.strId())]