from __future__ import division
from libtbx import easy_run
import time
from libtbx.test_utils import approx_equal
import iotbx.pdb
from iotbx import reflection_file_reader
from cctbx import miller
from cctbx import maptbx
from scitbx.array_family import flex




def run_polder(pdb_code, selection):
  """Run phenix.polder on <pdb_code>.pdb and <pdb_code>.mtz.

  The command line passes sphere_radius=5, uses `pdb_code` as the output
  file prefix and `selection` as the solvent-exclusion mask selection, and
  redirects the program's stdout to <pdb_code>_polder.log. The command is
  echoed before it is executed through easy_run.call; its exit status is
  not inspected here.
  """
  polder_args = (
    "phenix.polder",
    "%s.pdb" % pdb_code,
    "%s.mtz" % pdb_code,
    "sphere_radius=5",
    'output_file_name_prefix="%s" ' % pdb_code,
    'solvent_exclusion_mask_selection="%s" ' % selection,
    "> %s_polder.log" % pdb_code,
  )
  command_line = " ".join(polder_args)
  print(command_line)
  easy_run.call(command_line)


def get_map(cg, mc):
  """Return the sigma-scaled, unpadded real-space map for coefficients `mc`.

  `cg` is the crystal gridding on which the FFT of `mc` is evaluated.
  """
  transformed = miller.fft_map(crystal_gridding=cg, fourier_coefficients=mc)
  transformed.apply_sigma_scaling()
  real_map = transformed.real_map_unpadded()
  return real_map

def get_map_stats(map, sites_frac):
  """Sample `map` at every fractional site using eight-point interpolation.

  Despite its name, this computes no statistics. It returns the raw
  interpolated values as a flex.double, in the same order as `sites_frac`,
  and leaves any min/max/mean reduction to the caller.
  """
  return flex.double(
    [map.eight_point_interpolation(site) for site in sites_frac])


def exercise(pdb_code, selection):
  """Compare polder and omit difference maps at a selected set of atoms.

  Runs phenix.polder (via run_polder) on <pdb_code>.pdb / <pdb_code>.mtz and
  reads the map coefficients from <pdb_code>_polder_map_coeffs.mtz. It then
  computes sigma-scaled polder and omit maps on one shared grid, samples both
  at the atoms of <pdb_code>.pdb matched by `selection`, and prints the
  min/max/mean of the sampled values for each map.

  Args:
    pdb_code: basename of the input .pdb/.mtz files; also the output prefix.
    selection: atom selection string. It is used both for the polder
      solvent-exclusion mask and to pick the sites where the maps are
      sampled.

  Raises:
    AssertionError: if either expected map-coefficient array is missing from
      the polder output file, or if `selection` matches no atoms.
  """
  run_polder(pdb_code=pdb_code, selection=selection)
  # Open the map-coefficient file written by phenix.polder.
  file_name = pdb_code + '_polder_map_coeffs.mtz'
  miller_arrays = reflection_file_reader.any_reflection_file(
    file_name=file_name).as_miller_arrays()
  mc_polder, mc_omit = None, None
  for ma in miller_arrays:
    lbl = ma.info().label_string()
    if lbl == "mFo-DFc_polder,PHImFo-DFc_polder":
      mc_polder = ma.deep_copy()
    elif lbl == "mFo-DFc_omit,PHImFo-DFc_omit":
      mc_omit = ma.deep_copy()
  assert mc_polder is not None, \
    "polder map coefficients not found in %s" % file_name
  assert mc_omit is not None, \
    "omit map coefficients not found in %s" % file_name
  # Both maps are computed on one grid derived from the polder coefficients.
  cg = maptbx.crystal_gridding(
    unit_cell=mc_polder.unit_cell(),
    d_min=mc_polder.d_min(),
    resolution_factor=0.25,
    space_group_info=mc_polder.space_group_info())
  map_polder = get_map(cg=cg, mc=mc_polder)
  map_omit = get_map(cg=cg, mc=mc_omit)
  # Coordinates of the selected atoms, fractionalized for map interpolation.
  pdb_file_name = pdb_code + '.pdb'
  pdb_hierarchy = iotbx.pdb.input(
    file_name=pdb_file_name).construct_hierarchy()
  sel = pdb_hierarchy.atom_selection_cache().selection(string=selection)
  # An empty selection leaves nothing to sample; fail with a clear message
  # instead of an obscure error from min_max_mean further down.
  assert sel.count(True) > 0, \
    "selection %r matches no atoms in %s" % (selection, pdb_file_name)
  sites_cart_lig = pdb_hierarchy.atoms().extract_xyz().select(sel)
  sites_frac_lig = mc_polder.unit_cell().fractionalize(sites_cart_lig)
  mp = get_map_stats(map=map_polder, sites_frac=sites_frac_lig)
  mo = get_map_stats(map=map_omit, sites_frac=sites_frac_lig)
  mmm_mp = mp.min_max_mean().as_tuple()
  mmm_o = mo.min_max_mean().as_tuple()
  print("%s: map min/max/mean" % pdb_code)
  print("Polder map : %7.3f %7.3f %7.3f" % mmm_mp)
  print("Omit       : %7.3f %7.3f %7.3f" % mmm_o)


if __name__ == "__main__":
  # Run only when executed as a script. exercise() shells out to
  # phenix.polder and reads/writes files, so it must not run on import.
  exercise(
    pdb_code='1aba',
    selection='chain A and resseq 88')