# Source code for continuousflex.viewers.viewer_image_synthesize

# **************************************************************************
# * Authors:  Mohamad Harastani          (mohamad.harastani@upmc.fr)
# *           Rémi Vuillemot             (remi.vuillemot@upmc.fr)
# * IMPMC, UPMC Sorbonne University
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# **************************************************************************


from os.path import basename

from pwem.emlib import MetaData, MDL_ORDER
from pyworkflow.protocol.params import StringParam, LabelParam
from pyworkflow.viewer import (ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO)
from pyworkflow.utils import replaceBaseExt, replaceExt

from continuousflex.protocols.data import Point, Data
from continuousflex.viewers.nma_plotter import FlexNmaPlotter
from continuousflex.protocols import FlexProtSynthesizeImages
import xmipp3
import pwem.emlib.metadata as md
from pwem.viewers import ObjectView
from continuousflex.protocols.protocol_image_synthesize import NMA_YES
import matplotlib.pyplot as plt


class FlexProtSynthesizeImageViewer(ProtocolViewer):
    """ Visualization of results from synthesized images """
    _label = 'viewer synthetic images'
    _targets = [FlexProtSynthesizeImages]
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]

    # Number of leading modes that are never used and therefore not stored in
    # MDL_NMA; loadData pads every point with this many zeros. Presumably this
    # makes point-data indices line up with row positions in the modes
    # metadata used by _doViewRawDeformation -- TODO confirm.
    _UNUSED_LEADING_MODES = 6

    def __init__(self, **kwargs):
        ProtocolViewer.__init__(self, **kwargs)
        # Lazily loaded Data object, see getData().
        self._data = None

    def getData(self):
        """Return the amplitude Data, loading it on first access."""
        if self._data is None:
            self._data = self.loadData()
        return self._data

    def _defineParams(self, form):
        """Declare the visualization form: NMA amplitude plots, rigid-body
        histograms and the generated images."""
        form.addSection(label='Visualization')
        form.addParam('displayRawDeformation', StringParam, default='7 8',
                      condition=self.protocol.confVar.get() == NMA_YES,
                      label='Display the computed normal-mode amplitudes',
                      help='Type 7 to see the histogram of amplitudes along mode 7; \n'
                           'type 8 to see the histogram of amplitudes along mode 8, etc.\n'
                           'Type 7 8 to see the 2D plot of amplitudes along modes 7 and 8.\n'
                           'Type 7 8 9 to see the 3D plot of amplitudes along modes 7, 8 and 9 '
                           '(at most 3 modes can be displayed at once).'
                      )
        form.addParam('displayHists', LabelParam,
                      label="Display shift and angle histograms",
                      help="Display shift and angle histograms")
        form.addParam('displayVolumes', LabelParam,
                      label="Display images",
                      help="Display the images that are generated")

    def _getVisualizeDict(self):
        """Map each form parameter name to the method that visualizes it."""
        return {'displayRawDeformation': self._viewRawDeformation,
                'displayHists': self._viewHists,
                'displayVolumes': self._viewVolumes,
                }

    def _viewVolumes(self, paramName):
        """Return an ObjectView of the protocol's outputImages set."""
        volumes = self.protocol.outputImages
        return [ObjectView(self._project, volumes.strId(), volumes.getFileName())]

    def _viewHists(self, paramName):
        """Show histograms of the generated shifts (X, Y) and Euler angles
        (rot, tilt, psi) read from the protocol's extra GroundTruth.xmd."""
        mdVolumes = md.MetaData(self.protocol._getExtraPath('GroundTruth.xmd'))
        # (subplot row, subplot column, panel title, metadata label)
        panels = [
            (0, 0, 'Shift X', md.MDL_SHIFT_X),
            (0, 1, 'Shift Y', md.MDL_SHIFT_Y),
            (1, 0, 'Rot', md.MDL_ANGLE_ROT),
            (1, 1, 'Tilt', md.MDL_ANGLE_TILT),
            (1, 2, 'Psi', md.MDL_ANGLE_PSI),
        ]
        columns = [[] for _ in panels]
        for objId in mdVolumes:
            for column, (_, _, _, label) in zip(columns, panels):
                column.append(mdVolumes.getValue(label, objId))

        fig, ax = plt.subplots(2, 3)
        fig.suptitle('Histogram of generated rigid-body parameters')
        for (row, col, title, _), column in zip(panels, columns):
            ax[row, col].hist(column, bins=25)
            ax[row, col].set_title(title)
        # Only 5 of the 6 grid cells are used; hide the empty one.
        ax[0, 2].axis('off')
        plt.show()

    def _viewRawDeformation(self, paramName):
        """Plot the amplitudes of the modes typed in displayRawDeformation."""
        components = self.displayRawDeformation.get()
        return self._doViewRawDeformation(components)

    def _doViewRawDeformation(self, components):
        """Plot normal-mode amplitudes for 1, 2 or 3 modes.

        :param components: whitespace-separated mode numbers, e.g. "7 8".
        :returns: a list with one plotter, an empty list when no mode is
            given, or a list with one error view for invalid input.
        """
        try:
            modeNumbers = [int(token) for token in components.split()]
        except ValueError:
            return [self.errorMessage("Invalid mode(s) *%s*\n." % components,
                                      title="Invalid input")]
        dim = len(modeNumbers)
        if dim == 0:
            return []
        if dim > 3:
            # Only 1D, 2D and 3D plots are supported below.
            return [self.errorMessage("At most 3 modes can be displayed, got %d.\n" % dim,
                                      title="Invalid input")]

        # Read the modes metadata once (instead of once per requested mode)
        # and map each mode number (MDL_ORDER) to its row position; the first
        # occurrence wins, as with a linear search that stops at a match.
        modes_md = MetaData(replaceExt(self.protocol.inputModes.get().getFileName(), 'xmd'))
        rowOfMode = {}
        for i, objId in enumerate(modes_md):
            rowOfMode.setdefault(modes_md.getValue(MDL_ORDER, objId), i)

        missingList = [str(m) for m in modeNumbers if m not in rowOfMode]
        if missingList:
            return [self.errorMessage("Invalid mode(s) *%s*\n." % (', '.join(missingList)),
                                      title="Invalid input")]

        modeList = [rowOfMode[m] for m in modeNumbers]
        baseList = ['Mode %d' % m for m in modeNumbers]

        # Actually plot
        data = self.getData()
        plotter = FlexNmaPlotter(data=data)
        data.XIND = modeList[0]
        if dim == 1:
            plotter.plotArray1D("Histogram of normal-mode amplitudes: %s" % baseList[0],
                                "Amplitude", "Number of images")
        elif dim == 2:
            data.YIND = modeList[1]
            plotter.plotArray2D_xy("Normal-mode amplitudes: %s vs %s" % tuple(baseList),
                                   *baseList)
        else:
            data.YIND = modeList[1]
            data.ZIND = modeList[2]
            plotter.plotArray3D_xyz("Normal-mode amplitudes: %s %s %s" % tuple(baseList),
                                    *baseList)
        return [plotter]

    def loadData(self):
        """Build a Data object with one Point per row of GroundTruth.xmd.

        Each point holds the row's MDL_NMA values, prefixed with
        _UNUSED_LEADING_MODES zeros for the first never-used modes, and a
        weight of 0.0.
        """
        mdVolumes = md.MetaData(self.protocol._getExtraPath('GroundTruth.xmd'))
        data = Data()
        padding = [0] * self._UNUSED_LEADING_MODES
        for objId in mdVolumes:
            pointData = padding + list(mdVolumes.getValue(md.MDL_NMA, objId))
            data.addPoint(Point(pointId=objId, data=pointData, weight=0.0))
        return data