import phaser

from phenix.substructure.hyss import structure_from_clusters

from cctbx import miller
from cctbx import xray
from cctbx.array_family import flex
from cctbx import maptbx
from scitbx import matrix
from libtbx import cluster
from libtbx.utils import Sorry

import math
import sys

class sad_data_adaptor(object):
    """
    Convert cctbx.miller_array to phaser format

    The anomalous input array is unpacked into parallel columns (miller,
    fplus, fminus, sigfplus, sigfminus, pplus, pminus).  Rows are ordered as:
    matched Bijvoet pairs, then unpaired "+" reflections, then unpaired "-"
    reflections (the latter stored under the negated index).  pplus / pminus
    flag rows whose F+ / F- observation exists and has a positive amplitude;
    rows where both flags are False are dropped.
    """

    def __init__(self, miller_array):
        """
        Build the phaser columns from `miller_array`, which must be anomalous,
        already mapped to the asymmetric unit, and carry sigmas.
        """
        hkl = miller_array.indices()

        assert miller_array.anomalous_flag()
        assert miller_array.map_to_asu().indices().all_eq(hkl)
        assert miller_array.sigmas() is not None

        mates = miller.match_bijvoet_mates(
            miller_array.space_group_info().type(),
            miller_array.indices(),
            )
        # Keyed by the keyword names split_plus_minus expects
        selections = {
            "sel_pp": mates.pairs_hemisphere_selection("+"),
            "sel_pm": mates.pairs_hemisphere_selection("-"),
            "sel_sp": mates.singles("+"),
            "sel_sm": mates.singles("-"),
            }
        pair_count = selections["sel_pp"].size()
        plus_end = pair_count + selections["sel_sp"].size()

        merged = flex.miller_index()
        merged.reserve(plus_end + selections["sel_sm"].size())
        merged.extend(hkl.select(selections["sel_pp"]))
        merged.extend(hkl.select(selections["sel_sp"]))
        merged.extend(-hkl.select(selections["sel_sm"]))
        self.miller = merged

        self.fplus, self.fminus = self.split_plus_minus(
            data=miller_array.data(), **selections)
        self.sigfplus, self.sigfminus = self.split_plus_minus(
            data=miller_array.sigmas(), **selections)

        row_count = self.miller.size()

        # F+ exists for the pairs and the "+" singles only
        plus_flags = flex.bool()
        plus_flags.reserve(row_count)
        plus_flags.resize(plus_end, True)
        plus_flags.resize(row_count, False)
        self.pplus = plus_flags

        # F- exists for the pairs and the "-" singles only
        minus_flags = flex.bool()
        minus_flags.reserve(row_count)
        minus_flags.resize(pair_count, True)
        minus_flags.resize(plus_end, False)
        minus_flags.resize(row_count, True)
        self.pminus = minus_flags

        self.unflag_non_positive_amplitudes(
            amplitudes=self.fplus, flags=self.pplus)
        self.unflag_non_positive_amplitudes(
            amplitudes=self.fminus, flags=self.pminus)

        keep = self.find_not_double_missing_reflection_indices()
        assert len(keep) == row_count

        if keep.count(True) < row_count:
            columns = ("miller", "fplus", "sigfplus", "pplus",
                       "fminus", "sigfminus", "pminus")
            for column in columns:
                setattr(self, column, getattr(self, column).select(keep))

        self.unit_cell = miller_array.unit_cell().parameters()
        self.space_group_hall = (
            miller_array.space_group_info().type().hall_symbol())

    def split_plus_minus(self, data, sel_pp, sel_pm, sel_sp, sel_sm):
        """
        Split `data` into a ( plus, minus ) pair of equally long columns laid
        out as pairs / "+" singles / "-" singles.  Slots with no observation
        are filled by resize() with its default fill value.
        """
        plus = data.select(sel_pp)
        minus = data.select(sel_pm)

        plus.extend(data.select(sel_sp))
        minus.resize(plus.size())       # pad minus over the "+" singles

        minus.extend(data.select(sel_sm))
        plus.resize(minus.size())       # pad plus over the "-" singles

        return (plus, minus)

    def unflag_non_positive_amplitudes(self, amplitudes, flags):
        """
        Set flags[i] to False (in place) wherever amplitudes[i] <= 0.
        """
        assert len(amplitudes) == len(flags)

        non_positive = [
            position
            for (position, value) in enumerate(amplitudes)
            if value <= 0
            ]

        for position in non_positive:
            flags[position] = False

    def find_not_double_missing_reflection_indices(self):
        """
        Return a flex.bool that is True for rows where pplus or pminus is set.
        """
        assert len(self.pplus) == len(self.pminus)
        flag_pairs = zip(self.pplus, self.pminus)
        return flex.bool([any(pair) for pair in flag_pairs])

    def data(self):
        """
        Return the rows as ( index, F+, F-, SIGF+, SIGF-, P+, P- ) tuples.
        """
        columns = (
            self.miller,
            self.fplus,
            self.fminus,
            self.sigfplus,
            self.sigfminus,
            self.pplus,
            self.pminus,
            )
        return zip(*columns)

    def __str__(self):
        """
        Render a header line followed by one formatted line per row.
        """
        header = "Index Fplus Fminus SIGFplus SIGFminus Pplus Pminus"
        row_format = "%12s %10.2f %10.2f %10.2f %10.2f %4s %4s"
        rows = [row_format % row for row in self.data()]
        return "\n".join([header] + rows)


def default_setter(input):
    """
    No-op setter: accepts an input object and leaves it untouched.

    Returned by phaser_setter when no value was given for a parameter.
    """
    return None


class value_setter(object):
    """
    Callable that applies a fixed setting to an input object.

    Calling the instance with `input` invokes setter( input, setting ); the
    setter's return value is discarded.
    """

    def __init__(self, setter, setting):
        self.setter, self.setting = setter, setting

    def __call__(self, input):
        apply_setting = self.setter
        apply_setting(input, self.setting)


def phaser_setter(method, value):
    """
    Return a callable that configures a phaser.InputEP_SAD object.

    With a value, the result calls the InputEP_SAD method named `method` with
    that value; with value None, the no-op default_setter is returned so the
    phaser default stays in effect.
    """
    if value is not None:
        return value_setter(
            setter=getattr(phaser.InputEP_SAD, method),
            setting=value,
            )

    return default_setter

      
class phaser_llg_and_correlation_score(object):
    """
    A score that contains the LLG and the correlation coefficient

    Instances compare, order and hash by score(), i.e. by the LLG alone; the
    correlation coefficient is carried along but never takes part in
    comparisons.  Comparison works against any object offering a score()
    method.
    """

    def __init__(self, cc, llg):

        self._cc = cc
        self._llg = llg


    def correlation(self):
        """Return the correlation coefficient."""

        return self._cc


    def llg(self):
        """Return the log-likelihood gain."""

        return self._llg


    def score(self):
        """Return the value used for ranking, which is the LLG."""

        return self.llg()


    def __mul__(self, other):
        """Scale the LLG by the number `other`; the correlation is kept."""

        return self.__class__( llg = self.llg() * other, cc = self.correlation() )


    def __rmul__(self, other):

        return self.__mul__( other = other )


    def __truediv__(self, other):
        """Divide the LLG by the number `other`; the correlation is kept."""

        return self.__class__( llg = self.llg() / other, cc = self.correlation() )


    # Python 2 classic division operator
    __div__ = __truediv__


    def __add__(self, other):
        """Add LLGs and correlations of two scores component-wise."""

        return self.__class__(
            llg = self.llg() + other.llg(),
            cc = self.correlation() + other.correlation()
            )


    def __sub__(self, other):
        """Subtract LLGs and correlations of two scores component-wise."""

        return self.__class__(
            llg = self.llg() - other.llg(),
            cc = self.correlation() - other.correlation()
            )


    def _other_score(self, other):
        """Return other.score(), or NotImplemented if `other` has no score()."""

        getter = getattr( other, "score", None )
        return getter() if callable( getter ) else NotImplemented


    # Rich comparisons: Python 3 ignores __cmp__, so without these the
    # operators <, <=, >, >= (and therefore max / min / sorted on score
    # objects) raise TypeError there.

    def __eq__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() == theirs


    def __ne__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() != theirs


    def __lt__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() < theirs


    def __le__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() <= theirs


    def __gt__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() > theirs


    def __ge__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() >= theirs


    def __hash__(self):
        """Hash by score so that equal scores hash equally."""

        return hash( self.score() )


    def __cmp__(self, other):
        """
        Three-way comparison by score ( -1, 0 or 1 ).  Kept for callers that
        invoke it explicitly; written without the cmp builtin, which does not
        exist on Python 3.
        """

        mine = self.score()
        theirs = other.score()
        return ( mine > theirs ) - ( mine < theirs )


    def __str__(self):

        return "( %.3f (llg) %.3f (cc) )" % ( self.llg(), self.correlation() )
    
    
class phaser_llg_score(object):
    """
    A score that contains the LLG

    Instances compare, order and hash by score(), which is the LLG.
    Comparison works against any object offering a score() method.
    """

    def __init__(self, llg):

        self._llg = llg


    def llg(self):
        """Return the log-likelihood gain."""

        return self._llg


    def score(self):
        """Return the value used for ranking, which is the LLG."""

        return self.llg()


    def _other_score(self, other):
        """Return other.score(), or NotImplemented if `other` has no score()."""

        getter = getattr( other, "score", None )
        return getter() if callable( getter ) else NotImplemented


    # Rich comparisons: Python 3 ignores __cmp__, so without these the
    # operators <, <=, >, >= (and therefore max / min / sorted on score
    # objects) raise TypeError there.

    def __eq__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() == theirs


    def __ne__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() != theirs


    def __lt__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() < theirs


    def __le__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() <= theirs


    def __gt__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() > theirs


    def __ge__(self, other):

        theirs = self._other_score( other )
        return theirs if theirs is NotImplemented else self.score() >= theirs


    def __hash__(self):
        """Hash by score so that equal scores hash equally."""

        return hash( self.score() )


    def __cmp__(self, other):
        """
        Three-way comparison by score ( -1, 0 or 1 ).  Kept for callers that
        invoke it explicitly; written without the cmp builtin, which does not
        exist on Python 3.
        """

        mine = self.score()
        theirs = other.score()
        return ( mine > theirs ) - ( mine < theirs )


    def __str__(self):

        return "%.3f (llg)" % self.llg()
    
    
class phaser_hyss_component(object):
    """
    Common functionality for phaser-based objects

    On construction the anomalous data of the given hyss search object are
    converted to phaser format and a base score is calculated.  Both results
    are stored in hyss_search.extra under fixed keys, and are only computed
    when the key is not present yet.
    """

    AMPLITUDE_TYPES = (
        xray.observation_types.amplitude,
        xray.observation_types.reconstructed_amplitude,
        )
    INTENSITY_TYPES = (
        xray.observation_types.intensity,
        )
    TYPES = AMPLITUDE_TYPES + INTENSITY_TYPES

    def __init__(self, hyss_search, out = sys.stdout):
        """
        hyss_search: search object providing f_original, scattering_type,
            wavelength, cb_op_niggli and the `extra` mapping
        out: stream that progress messages are written to

        Raises ValueError if observations (amplitudes or intensities), the
        scattering type or the wavelength are missing.
        """

        if ( not hyss_search.f_original
            or not isinstance( hyss_search.f_original.observation_type(), self.TYPES )
            ):
            raise ValueError( "Missing data: F+ and F- observations" )

        if not hyss_search.scattering_type:
            raise ValueError( "Missing data: scattering type" )

        if not hyss_search.wavelength:
            raise ValueError( "Missing data: wavelength" )

        self._hyss_search = hyss_search
        self.out = out
        self.preprocess_and_store_data()


    def _write_line(self, message):
        """
        Write `message` plus a newline to self.out.  A stream of None falls
        back to sys.stdout, as the print statement this replaces did.
        """

        stream = sys.stdout if self.out is None else self.out
        stream.write( "%s\n" % ( message, ) )


    def phaser_data_key(self):
        """Key in hyss_search.extra holding the converted data."""

        return "PHASER_DATA"


    def base_score_key(self):
        """Key in hyss_search.extra holding the base score."""

        return "BASE_SCORE"


    def preprocess_and_store_data(self):
        """
        Fill hyss_search.extra with the phaser-format data and the base
        score.  Entries that already exist are left alone, so calling this
        repeatedly is harmless.
        """

        data = self._hyss_search.extra

        if self.phaser_data_key() not in data:
            self._write_line( "Phaser scoring: Converting data to phaser format" )
            assert self._hyss_search.f_original
            assert self._hyss_search.f_original.anomalous_flag()

            if isinstance( self._hyss_search.f_original.observation_type(), self.INTENSITY_TYPES ):
                amplitudes = self._hyss_search.f_original.f_sq_as_f()

            else:
                amplitudes = self._hyss_search.f_original

            assert isinstance( amplitudes.observation_type(), self.AMPLITUDE_TYPES )
            niggli = (
                amplitudes.unique_under_symmetry()
                    .change_basis( self._hyss_search.cb_op_niggli )
                    .map_to_asu()
                )
            data[ self.phaser_data_key() ] = sad_data_adaptor( niggli )

        if self.base_score_key() not in data:
            self._write_line( "Phaser scoring: Calculating base score for LLG calculation" )
            data[ self.base_score_key() ] = self.calculate_base_score()
            self._write_line(
                "Phaser scoring: Base score is %.0f" % data[ self.base_score_key() ]
                )


    def calculate_base_score(self):
        """
        Run phaser on a single atom of the search's scattering type placed at
        the origin and return the negated log-likelihood of that run.
        """

        inp = self.get_phaser_base_input()
        # Positional arguments follow phaser's addATOM_FULL signature, which
        # is not visible from this file
        inp.addATOM_FULL(
            "xtal",
            self._hyss_search.scattering_type,
            True,
            ( 0, 0, 0 ),
            0,
            True,
            0,
            False,
            matrix.rec( [ 1, 1, 1, 0, 0, 0 ], [ 3, 2 ] ),
            True,
            True,
            True,
            True,
            "S1"
            )
        inp.setOUTL_REJE( False )
        inp.setLLGC_COMP( False )
        inp.setMACA_PROT( "DEFAULT" )
        inp.setMACS_PROT( "CUSTOM" )
        # Custom macrocycle: only ref_sp and ref_sd are switched on
        inp.addMACS(
            ref_k = False,
            ref_b = False,
            ref_sigma = False,
            ref_xyz = False,
            ref_occ = False,
            ref_bfac = False,
            ref_fdp = False,
            ref_sa = False,
            ref_sb = False,
            ref_sp = True,
            ref_sd = True,
            ref_pk = False,
            ref_pb = False,
            ncyc = 50,
            target = "NOT_ANOM_ONLY",
            minimizer = "BFGS",
            )

        result = self.run_phaser( input = inp )

        return -result.getLogLikelihood()


    def run_phaser(self, input):
        """
        Run phaser.runEP_SAD on `input` and return the result object.

        Raises Sorry when phaser reports failure; the phaser logfile is
        written to the output stream first.
        """

        result = phaser.runEP_SAD( input )

        if result.Failed():
            self._write_line( result.logfile() )
            raise Sorry( "Calculation failed, see error message above" )

        return result


    def get_phaser_base_input(self):
        """
        Return a phaser.InputEP_SAD loaded with the stored data and the
        settings shared by all phaser runs of this component.
        """

        assert self.phaser_data_key() in self._hyss_search.extra
        data = self._hyss_search.extra[ self.phaser_data_key() ]

        inp = phaser.InputEP_SAD()

        inp.setCELL6( data.unit_cell )
        inp.setSPAC_HALL( data.space_group_hall )
        inp.setCRYS_MILLER( data.miller )
        inp.addCRYS_ANOM_DATA(
            "xtal",
            "wave",
            data.fplus,
            data.sigfplus,
            data.pplus,
            data.fminus,
            data.sigfminus,
            data.pminus,
            )

        inp.setCOMP_AVER()
        inp.setWAVE( self._hyss_search.wavelength )
        inp.setATOM_CHAN_BFAC_WILS( True )
        inp.setHKLO( False )
        inp.setXYZO( False )
        inp.setMUTE( True )
        inp.setHAND( "OFF" )
        inp.addLLGC_SCAT( self._hyss_search.scattering_type )
        inp.setMACA_PROT( "OFF" )

        return inp


    def get_phaser_base_llg(self):
        """Return the base score stored by preprocess_and_store_data."""

        assert self.base_score_key() in self._hyss_search.extra
        return self._hyss_search.extra[ self.base_score_key() ]
    

class phaser_composite_rescoring(phaser_hyss_component):
    """
    Phaser adaptor for use in solution rescoring with Hyss

    Scores produced here carry both the phaser LLG (relative to the base
    score) and the hyss correlation coefficient.
    """

    def __init__(
        self,
        hyss_search,
        complete,
        llgc_sigma = None,
        min_clustering_count = 4,
        stddev_multiplier = 3.0,
        out=sys.stdout,
        ):
        """
        complete: value for phaser's setLLGC_COMP (None keeps the default)
        llgc_sigma: value for phaser's setLLGC_SIGM (None keeps the default)
        min_clustering_count: minimum number of scores before any cluster
            analysis is attempted
        stddev_multiplier: the two LLG clusters count as separated when their
            means differ by more than this many pooled standard deviations
        """
        # The base constructor validates the search object and already runs
        # preprocess_and_store_data(), so it is not repeated here.
        super( phaser_composite_rescoring, self ).__init__( hyss_search = hyss_search, out = out)
        self._sp = self._hyss_search.search_params
        self._max_spread = self._sp.max_relative_spread_top_correlations
        # Parameters
        self._extra_params = [
            phaser_setter(
                method = "setLLGC_COMP",
                value = complete,
                ),
            phaser_setter(
                method = "setLLGC_SIGM",
                value = llgc_sigma,
                ),
            ]
        self._min_clustering_count = min_clustering_count
        self._stddev_multiplier = stddev_multiplier


    def terminate(self, top_scores, valid_scores):
        """
        Decide whether the search can stop.

        Returns True when the correlation-based criteria are met, or when the
        LLGs of `valid_scores` split into two clusters whose means differ by
        more than stddev_multiplier pooled standard deviations.
        """

        # A stream of None falls back to sys.stdout, as the print statement
        # this replaces did
        out = sys.stdout if self.out is None else self.out

        min_top = top_scores[-1].correlation()
        # Correlation of the lowest-scoring valid solution, floored at
        # min_low_correlation
        lowest_scoring = min( valid_scores, key = lambda s: s.score() )
        min_valid = max( self._sp.min_low_correlation, lowest_scoring.correlation() )

        # Ralf's termination criteria on CCs
        if ( self._sp.compare_sites_if_correlation_is_over <= min_top
          or self._sp.high_low_correlation_factor * min_valid <= min_top ):
            out.write( "Phaser scoring: identified solutions from correlation coefficient\n" )
            return True

        # Do cluster analysis if there are enough points
        if len( valid_scores ) < self._min_clustering_count:
            return False

        ( lower, upper ) = self.llg_cluster_analysis( scores = valid_scores )

        lstats = flex.mean_and_variance( flex.double( [ e.llg() for e in lower ] ) )
        ustats = flex.mean_and_variance( flex.double( [ e.llg() for e in upper ] ) )
        lower_mean = lstats.mean()
        upper_mean = ustats.mean()
        lower_count = len( lower )
        upper_count = len( upper )

        # A single-member cluster has no sample variance, so the spread of
        # the other cluster stands in for the pooled value
        if lower_count == 1:
            assert 1 < upper_count
            pooled = ustats.unweighted_sample_standard_deviation()

        elif upper_count == 1:
            assert 1 < lower_count
            pooled = lstats.unweighted_sample_standard_deviation()

        else:
            nom = ( ( lower_count - 1 ) * lstats.unweighted_sample_variance()
                + ( upper_count - 1 ) * ustats.unweighted_sample_variance() )
            pooled = math.sqrt( nom / ( lower_count + upper_count - 2 ) )

        if self._stddev_multiplier * pooled < ( upper_mean - lower_mean ):
            out.write( "Phaser scoring: identified solutions from bimodal LLG distribution\n" )
            return True

        return False


    def llg_cluster_analysis(self, scores):
        """
        Split `scores` into two clusters by LLG using k-means and return
        them as ( lower, upper ).
        """

        kcl = cluster.KMeansClustering(
            [ ( s, ) for s in scores ],
            distance = lambda x,y: abs( x[0].llg() - y[0].llg() )
            )
        ( cl_lower, cl_upper ) = kcl.getclusters( 2 )
        lower = [ e[0] for e in cl_lower ]
        upper = [ e[0] for e in cl_upper ]

        # Compare the LLG values themselves rather than the score objects, so
        # the ordering does not depend on how the score class implements
        # comparison operators
        lower_llgs = [ s.llg() for s in lower ]
        upper_llgs = [ s.llg() for s in upper ]

        if max( lower_llgs ) <= min( upper_llgs ):
            return ( lower, upper )

        assert max( upper_llgs ) <= min( lower_llgs )
        return ( upper, lower )


    def minimize(self, structure):
        """
        Refine `structure` with phaser, replace its scatterers in place with
        the atoms phaser returns, and return the LLG-and-correlation score of
        the result.
        """

        inp = self.get_phaser_input()
        real_scat_type = self._hyss_search.scattering_type

        for s in structure.scatterers():
            ( x, y, z ) = s.site
            inp.addATOM( "xtal", real_scat_type, x, y, z, s.occupancy )

        result = self.run_phaser( input = inp )
        structure.erase_scatterers()
        # Returned atoms are relabelled with the scattering type of the hyss
        # scatterer
        hyss_scat_type = self._hyss_search.xray_scatterer.scattering_type
        structure.add_scatterers(
            flex.xray_scatterer(
                [ scat.customized_copy( scattering_type = hyss_scat_type )
                    for scat in result.getAtoms() ]
                )
            )

        correlation = self._hyss_search.correlation_calculation(
          f_obs = self._hyss_search.structure_factors.q_all,
          structure = structure
          )

        return phaser_llg_and_correlation_score(
            llg = -result.getLogLikelihood() - self.get_phaser_base_llg(),
            cc = correlation
            )


    def top_group_threshold(self, scores):
        """
        Return the score (correlation set to 0) above which solutions belong
        to the top group.

        The threshold starts at the best LLG minus the allowed relative
        spread.  With enough scores, the LLGs are also clustered
        hierarchically at two standard deviations; if that gives more than
        one cluster, the threshold is lowered to the largest per-cluster
        minimum when that is smaller.
        """

        assert scores
        llgs = [ s.llg() for s in scores ]
        best = max( llgs )
        threshold = best - abs( self._max_spread * best )

        if self._min_clustering_count <= len( scores ):
            stats = flex.mean_and_variance( flex.double( llgs ) )
            stddev = stats.unweighted_sample_standard_deviation()
            hcl = cluster.HierarchicalClustering(
                data = llgs,
                distance_function = lambda x, y: float( abs( x - y ) )
                )
            clust = hcl.getlevel( 2.0 * stddev )

            if 1 < len( clust ):
                threshold = min( threshold, max( [ min( c ) for c in clust ] ) )

        return phaser_llg_and_correlation_score( llg = threshold, cc = 0 )


    def get_phaser_input(self):
        """
        Return the base phaser input with the extra parameters given at
        construction applied.
        """

        inp = self.get_phaser_base_input()

        for param in self._extra_params:
            # Setters take the input object through the keyword `input`
            param( input = inp )

        return inp
    
    
class phaser_extrapolation(phaser_hyss_component):
    """
    Extrapolation using phaser completion
    """

    def get_extrapolation_structure_and_score(self, fragment, n_sites):
        """
        Run phaser on the sites of `fragment` and return a tuple of
        ( structure built from the first n_sites atoms phaser returns,
        phaser_llg_score relative to the base score ).
        """
        search = self._hyss_search
        phaser_input = self.get_phaser_input()
        element = search.scattering_type

        for scatterer in fragment.scatterers():
            ( x, y, z ) = scatterer.site
            phaser_input.addATOM(
                "xtal", element, x, y, z, scatterer.occupancy )

        result = self.run_phaser( input = phaser_input )

        # Atoms are relabelled with the scattering type of the hyss scatterer
        label = search.xray_scatterer.scattering_type
        relabelled = [
            atom.customized_copy( scattering_type = label )
            for atom in result.getAtoms()[:n_sites]
            ]
        structure = xray.structure(
            scatterers = flex.xray_scatterer( relabelled ),
            crystal_symmetry = search.f_obs_niggli
            )

        gain = -result.getLogLikelihood() - self.get_phaser_base_llg()

        return ( structure, phaser_llg_score( llg = gain ) )


    def get_phaser_input(self):
        """
        Return the base phaser input with LLG completion switched on
        ( one cycle, sigma 3.0 ).
        """
        phaser_input = self.get_phaser_base_input()
        phaser_input.setLLGC_COMP( True )
        phaser_input.setLLGC_NCYC( 1 )
        phaser_input.setLLGC_SIGM( 3.0 )

        return phaser_input
    
    
class phaser_map_based_extrapolation(phaser_hyss_component):
    """
    Extrapolation using peaks picked from the phaser LLG map
    """

    def get_fft_map(self, fragment):
        """
        Run phaser on the sites of `fragment` and return a tuple of
        ( negated log-likelihood of the run, FFT map of the LLG coefficients ).
        """
        search = self._hyss_search
        phaser_input = self.get_phaser_input()
        element = search.scattering_type

        for scatterer in fragment.scatterers():
            ( x, y, z ) = scatterer.site
            phaser_input.addATOM(
                "xtal", element, x, y, z, scatterer.occupancy )

        result = self.run_phaser( input = phaser_input )

        # Calculate LLG map; coefficients are looked up under the upper-case
        # scattering type
        map_label = element.upper()
        map_indices = result.getMiller()
        coefficients = flex.polar(
            result.getLLG_F( map_label ),
            result.getLLG_PHI( map_label ),
            True
            )
        llg_coefficients = search.f_obs_niggli.customized_copy(
            indices = map_indices,
            anomalous_flag = False,
            data = coefficients
            )

        negated_log_likelihood = -result.getLogLikelihood()
        llg_map = llg_coefficients.fft_map(
            symmetry_flags = maptbx.use_space_group_symmetry
            )

        return ( negated_log_likelihood, llg_map )


    def get_extrapolation_structure_and_score(self, fragment, n_sites):
        """
        Pick peaks from the LLG map of `fragment`, add the fragment's own
        sites as fixed sites, and return a tuple of ( structure with n_sites
        sites built from those clusters, phaser_llg_score relative to the
        base score ).
        """
        ( llg, fft_map ) = self.get_fft_map( fragment = fragment )

        # Find peaks
        peaks = fft_map.peak_search(
            parameters = self._hyss_search.peak_search_parameters
            )

        # Append original fragment atoms, all at the same height: the
        # maximum grid height, or 1 when none is reported
        fixed_height = peaks.max_grid_height()

        if fixed_height is None:
            fixed_height = 1

        for scatterer in fragment.scatterers():
            peaks.append_fixed_site(
                site = scatterer.site,
                height = fixed_height
                )

        structure = structure_from_clusters.build(
            cluster_analysis = peaks,
            xray_scatterer = self._hyss_search.xray_scatterer,
            n_sites = n_sites
            )
        relative_llg = llg - self.get_phaser_base_llg()

        return ( structure, phaser_llg_score( llg = relative_llg ) )


    def get_phaser_input(self):
        """
        Return the base phaser input with LLG completion switched off and
        LLG map output requested.
        """
        phaser_input = self.get_phaser_base_input()
        phaser_input.setLLGC_COMP( False )
        # XXX did setLLGM replace setLLGC_MAPS?
        phaser_input.setLLGM( True )

        return phaser_input
