# **************************************************************************
# *
# * Authors: J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se)
# *
# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from os.path import basename
from pwem.protocols import ProtClassify2D
from pwem.emlib.image import ImageHandler
from pyworkflow.protocol.params import PointerParam, IntParam
from pyworkflow.utils import removeExt, copyFile
import pyworkflow.utils.graph as graph
from ..utils import SpiderDocFile
from .protocol_base import SpiderProtocol
[docs]class SpiderProtClassify(ProtClassify2D, SpiderProtocol):
""" Base protocol for SPIDER classifications. """
_label = None
def __init__(self, script, classDir, **kwargs):
ProtClassify2D.__init__(self, **kwargs)
SpiderProtocol.__init__(self, **kwargs)
# To avoid showing MPI box due to duplicated init
self.allowMpi = False
self._script = script
self._classDir = classDir
self._params = {'ext': 'stk',
'[class_dir]': self._classDir,
'particles': 'input_particles',
'particlesSel': 'input_particles_sel',
'dendroPs': 'dendrogram',
'dendroDoc': '%s/docdendro' % self._classDir,
'averages': 'averages'
}
[docs] def getClassDir(self):
return self._classDir
[docs] def getNumberOfClasses(self):
return None
# --------------------------- DEFINE param functions ----------------------
def _defineParams(self, form):
self._defineBasicParams(form)
form.addParallelSection(threads=4, mpi=0)
def _defineBasicParams(self, form):
form.addSection(label='Input')
form.addParam('inputParticles', PointerParam,
label="Input particles", important=True,
pointerClass='SetOfParticles',
help='Input images to perform PCA')
form.addParam('pcaFile', PointerParam, pointerClass='PcaFile',
label="IMC/SEQ file",
help='The IMC file contains the coordinates of each image '
'in the reduced-dimension space. '
'The SEQ file contains, for all images, '
'the pixel values under the mask. ')
form.addParam('numberOfFactors', IntParam, default=10,
label='Number of factors',
help='After running, examine the eigenimages and decide '
'which ones to use. \n'
'Typically all but the first few are noisy.')
# --------------------------- INSERT steps functions ----------------------
def _insertAllSteps(self):
pcaFile = self.pcaFile.get().filename.get()
self._insertFunctionStep('convertInput', 'inputParticles',
self._getFileName('particles'),
self._getFileName('particlesSel'))
self._insertFunctionStep('classifyStep', pcaFile,
self.numberOfFactors.get(),
self.getNumberOfClasses())
self._insertFunctionStep('createOutputStep')
# --------------------------- STEPS functions -----------------------------
def _updateParams(self):
pass
[docs] def classifyStep(self, imcFile, numberOfFactors, numberOfClasses):
""" Apply the selected filter to particles.
Create the set of particles.
"""
# Copy file to working directory, it could be also a link
imcLocalFile = basename(imcFile)
copyFile(imcFile, self._getPath(imcLocalFile))
self.info("Copied file '%s' to '%s' " % (imcFile, imcLocalFile))
# Spider automatically add _IMC to the ca-pca result file
# JMRT: I have modify the kmeans.msa script to not add _IMC
# automatically, it can be also used with _SEQ suffix,
# so we will pass the whole cas_file
imcBase = removeExt(imcLocalFile) # .replace('_IMC', '')
imcPrefix = imcBase.replace('_IMC', '').replace('_SEQ', '')
self._params.update({'x27': numberOfFactors,
'x30': self.numberOfThreads.get(),
'[cas_prefix]': imcPrefix,
'[cas_file]': imcBase,
})
self._updateParams()
self.runTemplate(self.getScript(), self.getExt(), self._params)
[docs]class SpiderProtClassifyCluster(SpiderProtClassify):
""" Base for Clustering Spider classification protocols.
"""
def __init__(self, script, classDir, **kwargs):
SpiderProtClassify.__init__(self, script, classDir, **kwargs)
# --------------------------- STEPS functions -----------------------------
[docs] def createOutputStep(self):
self.buildDendrogram(True)
# --------------------------- UTILS functions -----------------------------
def _fillClassesFromNodes(self, classes2D, nodeList):
""" Create the SetOfClasses2D from the images of each node
in the dendrogram.
"""
particles = classes2D.getImages()
sampling = classes2D.getSamplingRate()
# We need to first create a map between the particles index and
# the assigned class number
classDict = {}
nodeDict = {}
classCount = 0
for node in nodeList:
if node.path:
classCount += 1
node.classId = classCount
nodeDict[classCount] = node
for i in node.imageList:
classDict[int(i)] = classCount
def updateItem(p, item):
classId = classDict.get(item, None)
if classId is None:
p._appendItem = False
else:
p.setClassId(classId)
def updateClass(cls):
node = nodeDict[cls.getObjId()]
rep = cls.getRepresentative()
rep.setSamplingRate(sampling)
rep.setLocation(node.avgCount, self.dendroAverages)
particlesRange = range(1, particles.getSize()+1)
classes2D.classifyItems(updateItemCallback=updateItem,
updateClassCallback=updateClass,
itemDataIterator=iter(particlesRange))
def _fillParticlesFromNodes(self, inputParts, outputParts, nodeList):
""" Create the SetOfClasses2D from the images of each node
in the dendrogram.
"""
allImages = set()
for node in nodeList:
if node.path:
for i in node.imageList:
allImages.add(i)
def updateItem(item, index):
item._appendItem = index in allImages
particlesRange = range(1, inputParts.getSize()+1)
outputParts.copyItems(inputParts,
updateItemCallback=updateItem,
itemDataIterator=iter(particlesRange))
[docs] def buildDendrogram(self, writeAverages=False):
""" Parse Spider docfile with the information to build the dendrogram.
Params:
writeAverages: whether to write class averages or not.
"""
dendroFile = self._getFileName('dendroDoc')
# Dendrofile is a docfile with at least 3 data columns (class, height, id)
doc = SpiderDocFile(dendroFile)
values = []
indexes = []
for _, h, i in doc.iterValues():
indexes.append(i)
values.append(h)
doc.close()
self.dendroValues = values
self.dendroIndexes = indexes
self.dendroImages = self._getFileName('particles')
self.dendroAverages = self._getFileName('averages')
self.dendroAverageCount = 0 # Write only the number of needed averages
self.dendroMaxLevel = 10 # FIXME: remove hard coding if working the levels
self.ih = ImageHandler()
return self._buildDendrogram(0, len(values)-1, 1, writeAverages)
[docs] def getImage(self, particleNumber):
return self.ih.read((int(particleNumber), self.dendroImages))
[docs] def addChildNode(self, node, leftIndex, rightIndex, index,
writeAverages, level, searchStop):
child = self._buildDendrogram(leftIndex, rightIndex, index,
writeAverages, level+1, searchStop)
node.addChild(child)
node.extendImageList(child.imageList)
if writeAverages:
node.addImage(child.image)
del child.image # Allow to free child image memory
def _buildDendrogram(self, leftIndex, rightIndex, index,
writeAverages=False, level=0, searchStop=0):
""" This function is recursively called to create the dendrogram
graph (binary tree) and also to write the average image files.
Params:
leftIndex, rightIndex: the indexes within the list where to search.
index: the index of the class average.
writeAverages: flag to select when to write averages
searchStop: this could be 1, means that we will search until the
last element (used for right childs of the dendrogram or,
can be 0, meaning that the last element was already the max
(used for left childs )
From self:
self.dendroValues: the list with the heights of each node
self.dendroImages: image stack filename to read particles
self.dendroAverages: stack name where to write averages
It will search for the max in values list (between minIndex and maxIndex).
Nodes to the left of the max are left childs and the other right childs.
"""
if level < self.dendroMaxLevel:
avgCount = self.dendroAverageCount + 1
self.dendroAverageCount += 1
if rightIndex == leftIndex: # Just only one element
height = self.dendroValues[leftIndex]
node = DendroNode(index, height)
node.extendImageList([self.dendroIndexes[leftIndex]])
node.addImage(self.getImage(node.imageList[0]))
elif rightIndex == leftIndex + 1: # Two elements
height = max(self.dendroValues[leftIndex],
self.dendroValues[rightIndex])
node = DendroNode(index, height)
node.extendImageList([self.dendroIndexes[leftIndex],
self.dendroIndexes[rightIndex]])
node.addImage(self.getImage(node.imageList[0]),
self.getImage(node.imageList[1]))
else: # 3 or more elements
# Find the max value (or height) of the elements
maxValue = self.dendroValues[leftIndex]
maxIndex = 0
# searchStop could be 0 (do not consider last element, coming from
# left child, or 1 (consider also the last one, coming from right)
values = self.dendroValues[leftIndex+1:rightIndex+searchStop]
for i, v in enumerate(values):
if v > maxValue:
maxValue = v
maxIndex = i+1
m = maxIndex + leftIndex
node = DendroNode(index, maxValue)
hasRightChild = m < rightIndex
if maxValue > 0:
nextIndex = 2 * index if hasRightChild else index
self.addChildNode(node, leftIndex, m, nextIndex,
writeAverages, level, 0)
if hasRightChild:
self.addChildNode(node, m+1, rightIndex, 2 * index + 1,
writeAverages, level, 1)
else:
# If the node has a single child, we will remove a node
# just to advance in the level of the tree to get more
# different class averages
if node.getChilds():
child = node.getChilds()[0]
child.image = node.image
child.parents = []
node = child
else:
node.extendImageList(self.dendroIndexes[leftIndex:rightIndex+1])
node.addImage(*[self.getImage(img) for img in node.imageList])
if level < self.dendroMaxLevel:
node.avgCount = avgCount
node.path = '%d@%s' % (node.avgCount, self.dendroAverages)
if writeAverages:
# normalize the sum of images depending on the number of particles
# assigned to this classes
# avgImage = node.image / float(node.getSize())
node.image.inplaceDivide(float(node.getSize()))
self.ih.write(node.image, (node.avgCount, self.dendroAverages))
fn = self._getTmpPath('doc_class%03d.stk' % index)
doc = SpiderDocFile(fn, 'w+')
for i in node.imageList:
doc.writeValues(i)
doc.close()
return node
[docs]class DendroNode(graph.Node):
""" Special type of Node to store dendrogram values. """
def __init__(self, index, height):
graph.Node.__init__(self, 'class_%03d' % index)
self.index = index
self.height = height
self.length = 0
self.path = None
self.selected = False
self.imageList = []
self.image = None
[docs] def getChilds(self):
return [c for c in self._childs if c.path]
[docs] def getSize(self):
""" Return the number of images assigned to this class. """
return len(self.imageList)
[docs] def addImage(self, *images):
""" Add some images to this node. """
for img in images:
if self.image is None:
self.image = img
else:
self.image.inplaceAdd(img)
[docs] def extendImageList(self, imageList):
self.imageList.extend(imageList)
self.length = self.getSize()