# Source code for tomo.protocols.protocol_import_tomomasks

# *
# * Authors:     Scipion Team
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion-users@lists.sourceforge.net'
# *
# **************************************************************************
from pwem.emlib.image import ImageHandler
from pyworkflow import BETA
from pyworkflow.protocol import PointerParam
from pyworkflow.utils import yellowStr
from pyworkflow.utils.path import removeBaseExt
from .protocol_base import ProtTomoImportFiles
from ..constants import ERR_NO_TOMOMASKS_GEN, ERR_NON_MATCHING_TOMOS
from ..objects import TomoMask, SetOfTomoMasks


class ProtImportTomomasks(ProtTomoImportFiles):
    """Protocol to import a set of tomomasks (segmentations) to the project.

    Each imported mask file is paired with one tomogram of ``inputTomos``:
    the tomogram's tsId (or, when the tomograms carry no tsId, the basename
    of its file) must be contained in the mask name returned by
    ``getTomoMaskName``. Masks whose dimensions differ from those of the
    input set of tomograms are discarded and reported as warnings.
    """
    _outputClassName = 'SetOfTomoMasks'
    _label = 'import tomomasks (segmentations)'
    _devStatus = BETA

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Accumulated warning text, printed at the end of importStep.
        # NOTE(review): this is a plain str, not a pyworkflow object, so it is
        # presumably not persisted; _summary may not see it when the protocol
        # is reloaded from the project database — confirm.
        self.warnMsg = ''
        # keys = filenames of matching tomomasks, values = matching Tomogram
        self.matchingTomoMaskDict = None

    def _defineParams(self, form):
        # Add the base class import parameters to this protocol's form,
        # calling the base implementation on self rather than on a throwaway
        # ProtTomoImportFiles instance.
        ProtTomoImportFiles._defineImportParams(self, form)
        form.addParam('inputTomos', PointerParam,
                      pointerClass='SetOfTomograms',
                      label='Tomograms',
                      help='Select the tomograms to be assigned to the input tomo masks.')

    # --------------------------- STEPS functions -----------------------------
    def _insertAllSteps(self):
        self._insertFunctionStep(self.importStep)

    def importStep(self):
        """Pair the input files with the tomograms and register the output.

        Raises:
            Exception: ERR_NON_MATCHING_TOMOS (from _checkMatchingFiles) if no
                file matches any tomogram, or ERR_NO_TOMOMASKS_GEN if the
                generated set is empty (falsy) after the dimension check.
        """
        self._checkMatchingFiles()
        tomoMaskSet = self._genOutputSetOfTomoMasks()
        if self.warnMsg:
            print(yellowStr('WARNING!') + '\n' + self.warnMsg)
        if not tomoMaskSet:
            raise Exception(ERR_NO_TOMOMASKS_GEN)
        self._defineOutputs(outputTomoMasks=tomoMaskSet)

    # --------------------------- INFO functions ------------------------------
    def _getTomMessage(self):
        return "Tomomasks %s" % self.getObjectTag('outputTomoMasks')

    def _summary(self):
        """Return the summary lines; best effort, errors are only printed."""
        summary = []
        try:
            if self.isFinished():
                summary.append("%s imported from:\n%s" % (self._getTomMessage(), self.getPattern()))
                if self.warnMsg:
                    summary.append('Some tomograms or tomomasks were excluded. Check the log for more details.')
        except Exception as e:
            # Report the problem but still return whatever was collected,
            # instead of implicitly returning None.
            print(e)
        return summary

    def _validate(self):
        errors = []
        try:
            next(self.iterFiles())
        except StopIteration:
            errors.append('No files matching the pattern %s were found.' % self.getPattern())
        return errors

    # --------------------------- UTILS functions ------------------------------
    def _checkMatchingFiles(self):
        """Pair each input file with the tomogram it belongs to.

        A file matches a tomogram when the tomogram's key (its tsId, or its
        file basename if the first tomogram has no valid tsId) is a substring
        of the file's tomomask name. Every tomogram is tested; when several
        match, the longest key wins (so 'TS_1' does not take a mask named
        after 'TS_10'), and on ties the first tomogram in the set wins.

        Side effects: stores the pairing (file -> tomogram) in
        self.matchingTomoMaskDict and appends warnings about unmatched
        tomograms and unmatched files to self.warnMsg.

        Raises:
            Exception: ERR_NON_MATCHING_TOMOS if no file matched any tomogram.
        """
        inTomoSet = self.inputTomos.get()
        # Keep the tomograms in a list, cloned in case the set iterator reuses
        # the same object, so a match is resolved by list position instead of
        # assuming that object ids run 1..N in iteration order.
        tomos = [tomo.clone() for tomo in inTomoSet]

        if tomos and self._tomoHasValidTsId(tomos[0]):
            # Look for the tsId of each tomogram to be contained in the tomomask names
            keys = [tomo.getTsId() for tomo in tomos]
        else:
            # The same, but considering the tomogram file basename instead of the tsId
            keys = [removeBaseExt(tomo.getFileName()) for tomo in tomos]

        matchingTomoMaskDict = {}
        matchedKeys = set()
        nonMatchingTomoMaskNames = []
        for file, _ in self.iterFiles():
            tomoMaskName = self.getTomoMaskName(file)
            # Empty keys are skipped: '' is a substring of every name.
            candidates = [i for i, key in enumerate(keys) if key and key in tomoMaskName]
            if candidates:
                best = max(candidates, key=lambda i: len(keys[i]))
                matchingTomoMaskDict[file] = tomos[best]
                matchedKeys.add(keys[best])
            else:
                nonMatchingTomoMaskNames.append(tomoMaskName)

        if not matchingTomoMaskDict:
            raise Exception(ERR_NON_MATCHING_TOMOS)

        # Tomograms whose key was not assigned to any tomomask
        nonMatchingTomogramIds = sorted(set(keys) - matchedKeys)
        if nonMatchingTomogramIds:
            self.warnMsg += self._formatWarnList(
                '[%i] tomograms did not match to any of the tomomasks introduced:',
                nonMatchingTomogramIds)

        # The same for the non-matching tomomasks
        if nonMatchingTomoMaskNames:
            self.warnMsg += self._formatWarnList(
                '[%i] tomomasks did not match to any of the tomograms introduced:',
                nonMatchingTomoMaskNames)

        self.matchingTomoMaskDict = matchingTomoMaskDict

    @staticmethod
    def _formatWarnList(headTemplate, items):
        """Return a yellow header (headTemplate % len(items)) followed by one
        '\\t- item' line per entry and a trailing blank separator line."""
        headMsg = yellowStr(headTemplate % len(items))
        body = ''.join('\t- %s\n' % item for item in items)
        return '%s\n%s\n\n' % (headMsg, body)

    @staticmethod
    def _tomoHasValidTsId(tomo):
        """Return True if tomo has a truthy '_tsId' attribute."""
        return bool(getattr(tomo, '_tsId', None))

    def _genOutputSetOfTomoMasks(self):
        """Build the output SetOfTomoMasks from self.matchingTomoMaskDict.

        Only masks whose (x, y, z) dimensions equal those of the input set of
        tomograms are added; the others are listed in self.warnMsg.

        Returns:
            The created SetOfTomoMasks (possibly empty).
        """
        ih = ImageHandler()
        nonMatchingDims = []  # (tomomask name, (x, y, z)) of discarded masks
        tomoMaskSet = SetOfTomoMasks.create(self._getPath(), template='tomomasks%s.sqlite', suffix='annotated')
        inTomoSet = self.inputTomos.get()
        xt, yt, zt = inTomoSet.getDimensions()
        sRate = inTomoSet.getSamplingRate()
        tomoMaskSet.setSamplingRate(sRate)

        counter = 1
        for tomomaskFile, tomoObj in self.matchingTomoMaskDict.items():
            x, y, z, _ = ih.getDimensions(tomomaskFile)
            if (xt, yt, zt) != (x, y, z):
                nonMatchingDims.append((self.getTomoMaskName(tomomaskFile), (x, y, z)))
                continue
            tomoMask = TomoMask()
            tomoMask.setSamplingRate(sRate)
            # NOTE(review): counter is used as the location index although
            # each mask lives in its own file — confirm this index is the
            # intended one for single-volume files.
            tomoMask.setLocation(counter, tomomaskFile)
            tomoMask.setVolName(tomoObj.getFileName())
            if self._tomoHasValidTsId(tomoObj):
                tomoMask.setTsId(tomoObj.getTsId())
            tomoMaskSet.append(tomoMask)
            counter += 1

        if nonMatchingDims:
            msgNonMatchingDimsMasks = yellowStr('[%i] tomomasks have different dimensions than the ones from the '
                                                'introduced set of tomograms (x, y, z) = (%i, %i, %i):' %
                                                (len(nonMatchingDims), xt, yt, zt))
            for maskName, dims in nonMatchingDims:
                msgNonMatchingDimsMasks += '\n\t- %s (x, y, z) = (%i, %i, %i)\n' % (maskName, *dims)
            self.warnMsg += msgNonMatchingDimsMasks + '\n\n'

        return tomoMaskSet

    @staticmethod
    def getTomoMaskName(maskFileName):
        """Return the file basename without extension and without the
        '_materials' suffix (added by the membrane annotator)."""
        return removeBaseExt(maskFileName).replace('_materials', '')