Source code for continuousflex.viewers.viewer_heteroflow

# **************************************************************************
# * Authors:    Mohamad Harastani            (mohamad.harastani@upmc.fr)
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
"""
This module implements the viewer that visualizes the results of the
HeteroFlow protocol (FlexProtHeteroFlow).
"""
from continuousflex.protocols.protocol_heteroflow import FlexProtHeteroFlow
from pwem.viewers import EmProtocolViewer
from pyworkflow.protocol.params import LabelParam, IntParam, EnumParam, StringParam
from pyworkflow.viewer import ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO
from pwem.viewers import ObjectView
import numpy as np
import matplotlib.pyplot as plt
from continuousflex.protocols.utilities.OF_plots import plot_quiver_3d, plot_quiver_2d
from continuousflex.protocols.utilities.spider_files3 import open_volume, open_image
import pyworkflow.protocol.params as params
from pyworkflow.utils.process import runJob
from pyworkflow.utils.path import makePath


# Indices of the projection-plane choices offered by the ``displayFlow2``
# combo box; they follow the order of that parameter's ``choices`` list.
XY, XZ, ZY = range(3)


class FlexHeteroFlowViewer(EmProtocolViewer):
    """ Visualization of results from the HeteroFlow protocol """
    _label = 'viewer heteroflow'
    _targets = [FlexProtHeteroFlow]
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]

    def __init__(self, **kwargs):
        ProtocolViewer.__init__(self, **kwargs)
        self._data = None

    def _defineParams(self, form):
        """Declare the viewer form: optical-flow plots, warped volumes and
        histograms of the similarity measures computed by the protocol."""
        form.addSection(label='Visualization')
        group = form.addGroup('Optical flows')
        group.addParam('FlowNumber', IntParam, default=1,
                       label='Optical flow for volume number')
        group.addParam('DownSample', IntParam, default=2,
                       expertLevel=params.LEVEL_ADVANCED,
                       label='Downsample the 3D quiver plot',
                       help='Loading the 3D quiver plot can take time, we can downsample it by this number')
        group.addParam('displayFlow', LabelParam,
                       label="Display 3D optical flow",
                       help="Display the calculated optical flow of volume of specified number vs the reference")
        # The angles quoted in these labels must match the ones _viewFlow2
        # actually uses. With euler_matrix, (rot, tilt, psi) = (90, 90, 0)
        # keeps the -z/-x components (x-z plane) and (0, 90, 0) keeps the
        # -z/y components (z-y plane).
        group.addParam('displayFlow2', EnumParam,
                       choices=[' x-y plane: Euler( 0, 0, 0)',
                                ' x-z plane: Euler(90,90, 0)',
                                ' z-y plane: Euler( 0,90, 0)'],
                       default=XY,
                       display=params.EnumParam.DISPLAY_COMBO,
                       label='Display projected 3D-to-2D optical flow',
                       help="Display the projected 3D-2D optical flow on one of the planes")
        group.addParam('RotTiltPsi', StringParam, default=None,
                       expertLevel=params.LEVEL_ADVANCED,
                       label='rot tilt psi',
                       help='Project the 3D optical flow using specific Euler angles')
        form.addParam('displayVolumes', LabelParam,
                      label="Display warped volumes",
                      help="Display the volumes that are generated by applying the calculated optical flow of each of "
                           "the input volumes on the reference")
        form.addParam('displayHistCC', LabelParam,
                      label="Histogram of normalized cross correlation",
                      help="Histogram of the normalized cross correlation between the input volumes and "
                           "the warped reference")
        form.addParam('displayHistmsd', LabelParam,
                      label="Histogram of mean square distance",
                      help="Histogram of the mean square distance between the input volumes and "
                           "the warped reference")
        form.addParam('displayHistmad', LabelParam,
                      label="Histogram of normalized mean absolute distance",
                      help="Histogram of the mean absolute distance between the input volumes and "
                           "the warped reference")

    def _getVisualizeDict(self):
        """Map each form parameter to the method that displays it."""
        return {'displayVolumes': self._viewVolumes,
                'displayHistCC': self._viewParam,
                'displayHistmsd': self._viewParam,
                'displayHistmad': self._viewParam,
                'displayFlow': self._viewFlow,
                'displayFlow2': self._viewFlow2,
                'RotTiltPsi': self._viewFlow2,
                }

    def _viewVolumes(self, paramName):
        """Return a view of the protocol's WarpedRefByFlows output set."""
        volumes = self.protocol.WarpedRefByFlows
        return [ObjectView(self._project, volumes.strId(), volumes.getFileName())]

    def _viewParam(self, paramName):
        """Plot a histogram of one column of the protocol's cc_msd_mad.txt.

        The column is selected by ``paramName``: 0 for the normalized cross
        correlation, 1 for the normalized mean square distance and 2 for the
        normalized mean absolute distance. Unknown names are ignored.
        """
        specs = {
            'displayHistCC': (
                0,
                'Histogram of normalized cross correlation between warped\n reference by '
                'optical flows (estimated volumes) and the input volumes',
                'Cross correlation'),
            'displayHistmsd': (
                1,
                'Histogram of normalized mean square distance between warped\n reference by '
                'optical flows (estimated volumes) and the input volumes',
                'normalized mean square distance'),
            'displayHistmad': (
                2,
                'Histogram of normalized mean absolute distance between warped\n reference by '
                'optical flows (estimated volumes) and the input volumes',
                'normalized mean absolute distance'),
        }
        if paramName not in specs:
            return
        column, title, xlabel = specs[paramName]
        datamat_fn = self.protocol._getExtraPath('cc_msd_mad.txt')
        # ndmin=2 keeps a 2D table even when the file has a single row
        # (one input volume); otherwise datamat[:, column] raises IndexError.
        datamat = np.loadtxt(datamat_fn, delimiter=' ', ndmin=2)
        plt.figure()
        plt.hist(datamat[:, column])
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel('Number of volumes')
        plt.show()

    def _viewFlow(self, paramName):
        """Show a 3D quiver plot of the optical flow of volume FlowNumber."""
        flowNumber = self.FlowNumber.get()
        flow = self.read_optical_flow_by_number(flowNumber)
        title = '3D optical flow for input volume number %d' % flowNumber
        plot_quiver_3d(flow, downsample=self.DownSample.get(), title=title)

    def _viewFlow2(self, paramName):
        """Project the 3D optical flow of volume FlowNumber to 2D and plot it.

        Each flow component volume is projected with xmipp_phantom_project
        using the selected Euler angles. The projected (x, y, z) vectors are
        then rotated with euler_matrix, and the first two rotated components
        form the 2D field. The field is rescaled so that its largest vector
        magnitude equals the largest magnitude of the 3D flow.
        """
        flowNumber = self.FlowNumber.get()
        # Resolve (and validate) the angles before doing any expensive work.
        rot, tilt, psi, title = self._getProjectionAngles(paramName, flowNumber)
        flow3D = self.read_optical_flow_by_number(flowNumber)

        makePath(self.protocol._getTmpPath())
        projNames = ('proj_x.spi', 'proj_y.spi', 'proj_z.spi')
        projections = [
            self._projectVolume(flowPath, self.protocol._getTmpPath(projName), rot, tilt, psi)
            for flowPath, projName in zip(self._getFlowPaths(flowNumber), projNames)]
        # Shape (3, rows, cols): one projected image per flow component.
        p = np.array(projections, dtype=float)
        _, rows, cols = p.shape

        T = self.euler_matrix(rot, tilt, psi)
        pn = np.matmul(T, p.reshape(3, rows * cols)).reshape(3, rows, cols)
        # Keep the two in-plane components as a (rows, cols, 2) vector field.
        flow2D = np.stack([pn[0], pn[1]], axis=-1)

        # Scale flow2D so its largest vector has the magnitude of the largest
        # 3D flow vector. Skip when the projected field is identically zero,
        # which would otherwise divide by zero and fill the plot with NaNs.
        max_3D = np.max(np.sqrt(np.sum(flow3D * flow3D, axis=0)))
        max_2D = np.max(np.sqrt(np.sum(flow2D * flow2D, axis=-1)))
        if max_2D > 0:
            flow2D = (max_3D / max_2D) * flow2D
        plot_quiver_2d(flow2D, title=title)

    def _getProjectionAngles(self, paramName, flowNumber):
        """Return ``(rot, tilt, psi, title)`` for the requested projection.

        The plane selected in displayFlow2 provides the default angles. When
        the handler was triggered from the RotTiltPsi field, the three angles
        typed there are used instead.
        """
        plane = self.displayFlow2.get()
        rot, tilt, psi = 0, 0, 0
        planeName = 'XY'
        if plane == XZ:
            rot, tilt, planeName = 90, 90, 'XZ'
        elif plane == ZY:
            tilt, planeName = 90, 'ZY'
        title = 'Projected optical flow of input volume number %d on %s plane' % (flowNumber, planeName)
        if paramName == 'RotTiltPsi':
            rot, tilt, psi = self._parseRotTiltPsi()
            title = 'Projected optical flow of input volume number %d \n using Euler angles' \
                    ' (%.1f, %.1f, %.1f)' % (flowNumber, rot, tilt, psi)
        return rot, tilt, psi, title

    def _parseRotTiltPsi(self):
        """Parse the RotTiltPsi field into three floats ``(rot, tilt, psi)``.

        Raises ValueError with a readable message when the field is empty or
        does not hold exactly three whitespace-separated numbers.
        """
        text = self.RotTiltPsi.get()
        try:
            rot, tilt, psi = map(float, (text or '').split())
        except ValueError as err:
            raise ValueError("'rot tilt psi' must contain exactly three numbers "
                             "separated by spaces, got: %r" % (text,)) from err
        return rot, tilt, psi

    def _projectVolume(self, inputPath, outputPath, rot, tilt, psi):
        """Project one flow-component volume with xmipp_phantom_project at the
        given Euler angles and return the resulting 2D image."""
        args = '-i %s -o %s --angles %s %s %s' % (inputPath, outputPath, rot, tilt, psi)
        runJob(None, 'xmipp_phantom_project', args)
        return open_image(outputPath)

    def _getFlowPaths(self, num):
        """Return the (x, y, z) component file paths of optical flow ``num``.

        ``num`` may be an int or a string; it is zero-padded to 6 digits.
        """
        prefix = self.protocol._getExtraPath() + '/optical_flows/' + str(num).zfill(6)
        return (prefix + '_opflowx.spi',
                prefix + '_opflowy.spi',
                prefix + '_opflowz.spi')

    def read_optical_flow_by_number(self, num):
        """Read the optical flow of input volume ``num`` as a (3, ...) array."""
        path_flowx, path_flowy, path_flowz = self._getFlowPaths(num)
        return self.read_optical_flow(path_flowx, path_flowy, path_flowz)

    def read_optical_flow(self, path_flowx, path_flowy, path_flowz):
        """Read the three flow component volumes and stack them.

        Returns a float array of shape ``(3,) + volume.shape``, where index 0,
        1 and 2 hold the x, y and z components respectively.
        """
        x = open_volume(path_flowx)
        y = open_volume(path_flowy)
        z = open_volume(path_flowz)
        return np.array([x, y, z], dtype=float)

    def euler_matrix(self, rot, tilt, psi):
        """Return the 3x3 rotation matrix for Euler angles given in degrees.

        The angles are negated before use (t1=-psi, t2=-tilt, t3=-rot). The
        matrix is applied to the projected flow vectors in _viewFlow2, whose
        first two rotated components become the displayed 2D field.
        """
        from math import sin, cos, radians
        t1 = -radians(psi)
        t2 = -radians(tilt)
        t3 = -radians(rot)
        a11 = cos(t1) * cos(t2) * cos(t3) - sin(t1) * sin(t3)
        a12 = -cos(t3) * sin(t1) - cos(t1) * cos(t2) * sin(t3)
        a13 = cos(t1) * sin(t2)
        a21 = cos(t1) * sin(t3) + cos(t2) * cos(t3) * sin(t1)
        a22 = cos(t1) * cos(t3) - cos(t2) * sin(t1) * sin(t3)
        a23 = sin(t1) * sin(t2)
        a31 = -cos(t3) * sin(t2)
        a32 = sin(t2) * sin(t3)
        a33 = cos(t2)
        T = np.array([[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]])
        return T