from __future__ import division
import iotbx.pdb
from cctbx.array_family import flex
import time
import sys
import mmtbx.f_model
from cctbx import maptbx
from iotbx import reflection_file_utils
from libtbx.utils import null_out
import mmtbx.utils
from mmtbx import map_tools
from mmtbx import masks
from cStringIO import StringIO

def prepare_f_obs_and_flags(f_obs, r_free_flags):
  """
  Clean up observed data and (optional) free-R flags before use.

  Steps, in order:
    - keep only entries for which f_obs.data() > 0 (the same selection is
      applied to r_free_flags when it is given);
    - convert each array with as_non_anomalous_array() and merge equivalents,
      carrying the observation type over from the pre-merge array;
    - when r_free_flags is given, reduce both arrays to their common sets.

  r_free_flags may be None, in which case it is returned unchanged.

  Returns the tuple (f_obs, r_free_flags).
  """
  def _merged_non_anomalous(miller_array):
    # Non-anomalous, merged copy that keeps the observation type of the input.
    averaged = miller_array.as_non_anomalous_array().merge_equivalents()
    return averaged.array().set_observation_type(miller_array)

  positive = f_obs.data()>0
  f_obs = _merged_non_anomalous(f_obs.select(positive))
  if (r_free_flags is None):
    return f_obs, r_free_flags

  r_free_flags = _merged_non_anomalous(r_free_flags.select(positive))
  f_obs, r_free_flags = f_obs.common_sets(r_free_flags)
  return f_obs, r_free_flags

def get_fmodel(pdb_file, hkl_file):
  print pdb_file, hkl_file
  pdb_inp = iotbx.pdb.input(file_name=pdb_file)
  xrs = pdb_inp.xray_structure_simple()
  #
  rfs = reflection_file_utils.reflection_file_server(
    crystal_symmetry = xrs.crystal_symmetry(),
    force_symmetry   = True,
    reflection_files = hkl_file,
    err              = StringIO())
  determine_data_and_flags_result = mmtbx.utils.determine_data_and_flags(
    reflection_file_server  = rfs,
    keep_going              = True,
    log                     = StringIO())
  f_obs = determine_data_and_flags_result.f_obs
  r_free_flags = determine_data_and_flags_result.r_free_flags
  print "f-obs labels:", f_obs.info().labels
  try: print "r-free-flags labels:", r_free_flags.info().labels
  except: pass
  #
  fo, fl = prepare_f_obs_and_flags(
    f_obs=f_obs, r_free_flags=r_free_flags)
  return xrs, fo, fl

def exercise(args):
  processed_args = mmtbx.utils.process_command_line_args(
    args=args, log=null_out())
  xrs, f_obs, flags = get_fmodel(pdb_file=processed_args.pdb_file_names[0],
    hkl_file=processed_args.reflection_files)
  fmodel = mmtbx.f_model.manager(
    xray_structure = xrs,
    f_obs          = f_obs,
    r_free_flags   = flags)
  fmodel.update_all_scales(update_f_part1=False)
  print fmodel.r_work(), fmodel.r_free()
  fmodel.show(show_header=False, show_approx=False)
  #
  # compute map before truncation
  # here XXX
  #
  mc_diff_mask_allreso = map_tools.electron_density_map(
    fmodel = fmodel).map_coefficients(
      map_type         = "mFo-DFc",
      isotropize       = True,
      fill_missing     = False)
  fmodel = fmodel.resolution_filter(d_max=6) 
  zero = fmodel.f_calc().array(
    data = flex.complex_double(fmodel.f_calc().data().size(),0))
  fmodel.update(f_mask = zero)
  fmodel.update_all_scales(update_f_part1=False)
  print fmodel.r_work(), fmodel.r_free()
  fmodel.show(show_header=False, show_approx=False)
  #
  # compute map after truncation
  # here XXX
  #
  mc_diff_low_res = map_tools.electron_density_map(
    fmodel = fmodel).map_coefficients(
      map_type         = "mFo-DFc",
      isotropize       = True,
      fill_missing     = False)

  mtz_dataset = mc_diff_mask_allreso.as_mtz_dataset(
    column_root_label = "mFo-DFc_mask_allreso")
  mtz_dataset.add_miller_array(
    miller_array      = mc_diff_low_res,
    column_root_label = "mFo-DFc_low_res") 
  mtz_object = mtz_dataset.mtz_object()
  mtz_object.write(file_name = 'low_res_map_coeffs.mtz')


if (__name__ == "__main__"):
  t0 = time.time()
  exercise(sys.argv[1:])
  print "Total time: %-8.4f"%(time.time()-t0)
  print "OK"
