# **************************************************************************
# *
# * Authors:     J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from os.path import join
import tkinter as tk
from pwem.protocols import ProtUserSubSet
from pwem.objects import SetOfClasses2D
from pwem.viewers import EmPlotter, ClassesView
from pyworkflow.protocol.params import IntParam, FloatParam, LabelParam
from pyworkflow.protocol.constants import STATUS_FINISHED
from pyworkflow.utils.properties import Icon
from pyworkflow.viewer import ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO
from pyworkflow.utils.graph import Graph
from pyworkflow.utils.path import cleanPath
from pyworkflow.gui import Window
from pyworkflow.gui.widgets import HotButton
from pyworkflow.gui.graph import LevelTree
from pyworkflow.gui.canvas import Canvas, Item
from pyworkflow.gui.dialog import askString
from ..utils import SpiderDocFile
from ..protocols import SpiderProtClassifyWard, SpiderProtClassifyDiday
[docs]class SpiderViewerClassify(ProtocolViewer):
    """ Visualization of Spider classification results. """
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]
    
    def _defineParams(self, form):
        form.addSection(label='Visualization')
        group1 = form.addGroup('Dendrogram')
        group1.addParam('doShowDendrogram', LabelParam,
                        label="Show dendrogram", default=True,
                        help='In a dendrogram larger vertical bars signify a '
                             'greater difference between classes. Many small '
                             'differences at the bottom can be eliminated with '
                             'an increase of the "Minimum height" setting.')
        group1.addParam('minHeight', FloatParam, default=0.5,
                        label='Minimum height',
                        help='The dendrogram will be cut at this level')
        self.groupClass = form.addGroup('Classes')
        self.groupClass.addParam('doShowClasses', LabelParam,
                                 label="Visualize class averages", default=True,
                                 help='Display class averages')
    def _getVisualizeDict(self):
        return {'doShowDendrogram': self._plotDendrogram,
                'doShowClasses': self.visualizeClasses
                }
    def _plotDendrogram(self, e=None):
        xplotter = EmPlotter()
        self.plt = xplotter.createSubPlot("Dendrogram", "", "")
        self.step = 0.25
        self.rightMost = 0.0  # Used to arrange leaf nodes at the bottom
        
        node = self.protocol.buildDendrogram()
        self.plotNode(node, self.minHeight.get())    
        self.plt.set_xlim(0., self.rightMost + self.step)
        self.plt.set_ylim(-10, 105)
        
        return [xplotter]
    
[docs]    def plotNode(self, node, minHeight=-1):
        childs = node.getChilds()
        h = node.height
        if h > minHeight and len(childs) > 1:
            x1, y1 = self.plotNode(childs[0], minHeight)
            x2, y2 = self.plotNode(childs[1], minHeight)
            xm = (x1 + x2)/2
            x = [x1, x1, x2, x2]
            y = [y1, h, h, y2]
            self.plt.plot(x, y, color='b')
            point = (xm, h)
        else:
            self.rightMost += self.step
            point = (self.rightMost, 0.)
        
        length = node.length
        index = node.index
        self.plt.annotate("%d(%d)" % (index, length), point, xytext=(0, -5),
                          textcoords='offset points', va='top', ha='center',
                          size='x-small')
        self.plt.plot(point[0], point[1], 'ro')
        
        return point 
    
    def _createNode(self, canvas, node, y):
        node.selected = False
        node.box = SpiderImageBox(canvas, node, y)
        return node.box 
            
            
[docs]class SpiderViewerWard(SpiderViewerClassify):
    """ Visualization of Spider - classify Ward protocol results. """
    
    _targets = [SpiderProtClassifyWard]
    _label = "viewer ward"
    
    def _defineParams(self, form):
        SpiderViewerClassify._defineParams(self, form)
        self.groupClass.addParam('maxLevel', IntParam, default=4,
                                 label='Maximum level',
                                 help='Maximum level of classes to show')
[docs]    def visualizeClasses(self, e=None):
        classTemplate = "class_%03d"
        averages = '%03d@' + self.protocol._getFileName('averages')
        
        def getInfo2(level, classNo):
            return classTemplate % classNo, averages % classNo
        
        node = self.protocol.buildDendrogram(writeAverages=False)
        
        g = Graph(root=node)
        self.graph = g
               
        self.win = Window("Select classes", self.formWindow, minsize=(1000, 600))
        root = self.win.root
        canvas = Canvas(root)
        canvas.grid(row=0, column=0, sticky='nsew')
        root.grid_columnconfigure(0, weight=1)
        root.grid_rowconfigure(0, weight=1) 
        
        self.buttonframe = tk.Frame(root)
        self.buttonframe.grid(row=2, column=0, columnspan=2)  
        self.win.createCloseButton(self.buttonframe).grid(row=0, column=0,
                                                          sticky='n',
                                                          padx=5, pady=5)
        saveparticlesbtn = HotButton(self.buttonframe, "Particles",
                                     Icon.PLUS_CIRCLE,
                                     command=self._askCreateParticles)
        saveparticlesbtn.grid(row=0, column=1, sticky='n', padx=5, pady=5)  
        btn = HotButton(self.buttonframe, "Classes", Icon.PLUS_CIRCLE,
                        command=self._askCreateClasses)
        btn.grid(row=0, column=2, sticky='n', padx=5, pady=5)
            
        lt = LevelTree(g)
        lt.DY = 135  # TODO: change in percent of the image size
        lt.setCanvas(canvas)
        lt.paint(self._createNode, maxLevel=self.maxLevel.get()-1)
        canvas.updateScrollRegion()
        
        return [self.win] 
    def _askCreateParticles(self):
        self._askCreateSubset('Particles', self.getSelectedNodesCount(2))
    def _askCreateClasses(self):
        self._askCreateSubset('Classes', self.getSelectedNodesCount(1))
    def _askCreateSubset(self, output, size):
        if self._selectionOverlap():
            self.win.showError("Classes could not overlap in the tree.")
            return
        s = '' if size == 1 else 's'
        headerLabel = 'Are you sure you want to create a new set of ' \
                      
' %s with %s element%s?' % (output, size, s)
        runname = askString('Question', 'Run name:', self.win.getRoot(), 30,
                            defaultValue='ProtUserSubSet',
                            headerLabel=headerLabel)
        if runname:
            createFunc = getattr(self, 'save' + output)
            createFunc(runname)
    def _createSubsetProtocol(self, createOutputFunc, label=None):
        """ Create a subset of classes or particles. """
        try:
            project = self.getProject()
            prot = project.newProtocol(ProtUserSubSet)
            prot.setObjLabel(label)
            prot.inputObject.set(self.protocol)
            project._setupProtocol(prot)
            prot.makePathsAndClean()
            createOutputFunc(prot)
            prot.setStatus(STATUS_FINISHED)
            project._storeProtocol(prot)
            # self.project.launchProtocol(prot, wait=True)
        except Exception as ex:
            self.win.showError(str(ex))
        
[docs]    def getSelectedNodesCount(self, depth):
        if depth == 1:
            return len([node for node in self.graph.getNodes() if node.selected])
        else:
            count = 0
            for node in self.graph.getNodes():
                if node.selected:
                    count += len(node.imageList)
            return count 
    def _selectionOverlap(self):
        """ Check if selected classes do overlap. """
        allImages = set()
        for n in self._selectedNodes():
            for i in n.imageList:
                if i in allImages:
                    return True
                allImages.add(i)
        return False
    def _selectedNodes(self):
        return [node for node in self.graph.getNodes() if node.selected]
[docs]    def saveClasses(self, runname=None):
        """ Store selected classes. """
        def createClasses(prot):
            classes = prot._createSetOfClasses2D(self.protocol.inputParticles.get(),
                                                 suffix='Selection')
            self.protocol._fillClassesFromNodes(classes, self._selectedNodes())
            prot._defineOutputs(outputClasses=classes)
        self._createSubsetProtocol(createClasses, runname) 
            
[docs]    def saveParticles(self, runname=None):
        """ Store particles from selected classes. """
        def createParticles(prot):
            inputParticles = self.protocol.inputParticles.get()
            particles = prot._createSetOfParticles(suffix='Selection')
            particles.copyInfo(inputParticles)
            self.protocol._fillParticlesFromNodes(inputParticles,
                                                  particles,
                                                  self._selectedNodes())
            prot._defineOutputs(outputParticles=particles)
        self._createSubsetProtocol(createParticles, runname)  
[docs]class ImageBox(Item):
    # copied from pw/gui/canvas.py, still depends on xmippLib
    def __init__(self, canvas, imgPath, x=0, y=0, text=None):
        Item.__init__(self, canvas, x, y)
        # Create the image
        from pyworkflow.gui import getImage
        from pwem.viewers.filehandlers import getImageFromPath
        if imgPath is None:
            self.image = getImage('no-image.gif')
        else:
            self.image = getImageFromPath(imgPath)
        if text is not None:
            self.label = tk.Label(canvas, image=self.image, text=text,
                                  compound=tk.TOP, bg='gray')
            self.id = self.canvas.create_window(x, y, window=self.label)
            self.label.bind('<Button-1>', self._onClick)
        else:
            self.id = self.canvas.create_image(x, y, image=self.image)
[docs]    def setSelected(self, value):  # Ignore selection highlight
        pass 
    def _onClick(self, e=None):
        pass 
[docs]class SpiderImageBox(ImageBox):
    def __init__(self, canvas, node, y):
        ImageBox.__init__(self, canvas, node.path, text=node.getName(), y=y)
        
    def _onClick(self, e=None):
        if self.node.path is None:
            return
        # On click change the selection state
        self.node.selected = not self.node.selected
        if self.node.selected:
            self.label.config(bd=2, bg='green')
        else:
            self.label.config(bd=0, bg='grey') 
            
            
[docs]class SpiderViewerDiday(SpiderViewerClassify):
    """ Visualization of Spider - classify Diday protocol results. """
    
    _targets = [SpiderProtClassifyDiday]
    _label = "viewer diday"
    
    def _defineParams(self, form):
        SpiderViewerClassify._defineParams(self, form)
        self.groupClass.addParam('numberOfClasses', IntParam, default=4,
                                 label='Number of classes',
                                 help='Desired number of classes.')
[docs]    def visualizeClasses(self, e=None):
        prot = self.protocol
        classDir = prot.getClassDir()
        classAvg = 'classavg'
        classVar = 'classvar'
        classDoc = 'docclass'
        
        params = {'[class_dir]': classDir,
                  '[desired-classes]': self.numberOfClasses.get(),
                  '[particles]': prot._params['particles'] + '@******',
                  '[class_doc]': join(classDir, classDoc + '***'), 
                  '[class_avg]': join(classDir, classAvg + '***'),
                  '[class_var]': join(classDir, classVar + '***'),        
                  }
        
        prot.runTemplate('mda/classavg.msa', prot.getExt(), params)
        particles = prot.inputParticles.get()
        particles.load()
        sampling = particles.getSamplingRate()
        
        setFn = prot._getTmpPath('classes2D.sqlite')
        cleanPath(setFn)
        classes2D = SetOfClasses2D(filename=setFn)
        classes2D.setImages(particles)
        # We need to first create a map between the particles index and
        # the assigned class number
        classDict = {}
        for classId in range(1, self.numberOfClasses.get()+1):
            docClass = prot._getPath(classDir, classDoc + '%03d.stk' % classId)
            doc = SpiderDocFile(docClass)
            for values in doc.iterValues():
                imgIndex = int(values[0])
                classDict[imgIndex] = classId
            doc.close()
        updateItem = lambda p, i: p.setClassId(classDict[i])
        def updateClass(cls):
            rep = cls.getRepresentative()
            rep.setSamplingRate(particles.getSamplingRate())
            avgFn = prot._getPath(classDir,
                                  classAvg + '%03d.stk' % cls.getObjId())
            rep.setLocation(1, avgFn)
        particlesRange = range(1, particles.getSize()+1)
        classes2D.classifyItems(updateItemCallback=updateItem,
                                updateClassCallback=updateClass,
                                itemDataIterator=iter(particlesRange))
        classes2D.write()
        classes2D.close()
        return [ClassesView(self.getProject(), prot.strId(),
                            classes2D.getFileName(), particles.strId())]