from __future__ import division
import iotbx.pdb
from cctbx.array_family import flex
import time
import sys
import mmtbx.f_model
from cctbx import maptbx
from iotbx import reflection_file_utils
from libtbx.utils import null_out
import mmtbx.utils
from mmtbx import map_tools
from mmtbx import masks

def prepare_f_obs_and_flags(f_obs, r_free_flags):
  """
  Clean F-obs (and optional R-free flags) before building an f_model manager.

  Steps:
    1. keep only reflections whose F-obs value is strictly positive (the same
       selection is applied to r_free_flags when given);
    2. convert to non-anomalous form and merge symmetry equivalents, copying
       the observation type from the pre-merge array;
    3. when flags are given, restrict both arrays to their common reflections.

  Returns the pair (f_obs, r_free_flags); r_free_flags stays None when None
  was passed in.
  """
  have_flags = (r_free_flags is not None)
  positive = f_obs.data() > 0
  f_obs = f_obs.select(positive)
  if have_flags:
    # NOTE(review): reusing the F-obs selection assumes r_free_flags is
    # index-aligned with f_obs (same size and order) -- confirm at call sites.
    r_free_flags = r_free_flags.select(positive)
  f_obs_merge = f_obs.as_non_anomalous_array().merge_equivalents()
  f_obs = f_obs_merge.array().set_observation_type(f_obs)
  if not have_flags:
    return f_obs, r_free_flags
  flags_merge = r_free_flags.as_non_anomalous_array().merge_equivalents()
  r_free_flags = flags_merge.array().set_observation_type(r_free_flags)
  # Keep only the reflections present in both arrays.
  f_obs, r_free_flags = f_obs.common_sets(r_free_flags)
  return f_obs, r_free_flags

def get_fmodel(pdb_file, hkl_file):
  pdb_inp = iotbx.pdb.input(file_name=pdb_file)
  xrs = pdb_inp.xray_structure_simple()
  #
  rfs = reflection_file_utils.reflection_file_server(
    crystal_symmetry = xrs.crystal_symmetry(),
    force_symmetry   = True,
    reflection_files = hkl_file,
    err              = null_out())
  determine_data_and_flags_result = mmtbx.utils.determine_data_and_flags(
    reflection_file_server  = rfs,
    keep_going              = True,
    log                     = null_out())
  f_obs = determine_data_and_flags_result.f_obs
  r_free_flags = determine_data_and_flags_result.r_free_flags
  print "f-obs labels:", f_obs.info().labels
  try: print "r-free-flags labels", r_free_flags.info().labels
  except: pass
  #
  fo, fl = prepare_f_obs_and_flags(
    f_obs=f_obs, r_free_flags=r_free_flags)
  return xrs, fo, fl

def exercise(args):
  processed_args = mmtbx.utils.process_command_line_args(
    args=args, log=null_out())
  xrs, f_obs, flags = get_fmodel(pdb_file=processed_args.pdb_file_names[0],
    hkl_file=processed_args.reflection_files)
  #
  # FLAT
  #
  fmodel_f = mmtbx.f_model.manager(
    xray_structure = xrs,
    f_obs          = f_obs,
    r_free_flags   = flags)
  fmodel_f.update_all_scales(update_f_part1=False, remove_outliers=False)
  fmodel_f.show(show_header=False, show_approx=False)
  print fmodel_f.r_work(), fmodel_f.r_free()
  #mc_diff_mask_allreso = map_tools.electron_density_map(
  #  fmodel = fmodel_f).map_coefficients(
  #    map_type         = "mFo-DFc",
  #    isotropize       = True,
  #    fill_missing     = False)
  #
  # BABINET (mimiking it with available tools).
  #
  # Do we want to bother implementing it properly? Don't know. My pedantic
  # side says YES... but I have absolutely NO time to do this now or in 
  # foreseeable future. It is a full day of work at most.
  #
  f_calc = fmodel_f.f_calc()
  xrs = fmodel_f.xray_structure.set_b_iso(value=50)
  fmodel_f.update_xray_structure(xray_structure = xrs, update_f_calc=True)
  f_mask = fmodel_f.f_calc()
  f_mask = f_mask.customized_copy(data = f_mask.data()*(-1))
  fmodel_b = mmtbx.f_model.manager(
    f_obs          = fmodel_f.f_obs(),
    r_free_flags   = fmodel_f.r_free_flags(),
    f_mask         = f_mask,
    f_calc         = f_calc)
  fmodel_b.update_all_scales(update_f_part1=False, remove_outliers=False, 
    fast=True)
  fmodel_b.show(show_header=False, show_approx=False)
  print fmodel_b.r_work(), fmodel_b.r_free()

  mc_diff_babinet = map_tools.electron_density_map(
    fmodel = fmodel_b).map_coefficients(
      map_type         = "mFo-DFc",
      isotropize       = True,
      fill_missing     = False)

  mtz_dataset = mc_diff_babinet.as_mtz_dataset(
    column_root_label = "mFo-DFc_babinet")
  #mtz_dataset.add_miller_array(
  #  miller_array      = mc_diff_babinet,
  #  column_root_label = "mFo-DFc_babinet") 
  mtz_object = mtz_dataset.mtz_object()
  mtz_object.write(file_name = 'babinet_map_coeffs.mtz')

if (__name__ == "__main__"):
  t0 = time.time()
  exercise(sys.argv[1:])
  print "Total time: %-8.4f"%(time.time()-t0)
  print "OK"
