import os, sys
import libtbx.load_env
import iotbx.pdb
from scitbx.array_family import flex
from scitbx.math import superpose, matrix
from cctbx import maptbx
from cctbx.maptbx import superpose_maps

def ccp4_map(cg, file_name, map_data):
  """Write map_data to file_name in CCP4 format via iotbx.mrcfile.

  Parameters:
    cg: object providing unit_cell() and space_group() (here a crystal
      symmetry); both are forwarded to the writer.
    file_name: path of the map file to write.
    map_data: real-space map values to write.

  The map is written with a single empty label.
  """
  from iotbx import mrcfile as _mrc
  # Pull the symmetry pieces out first so the writer call reads cleanly.
  unit_cell = cg.unit_cell()
  space_group = cg.space_group()
  _mrc.write_ccp4_map(
    file_name   = file_name,
    unit_cell   = unit_cell,
    space_group = space_group,
    map_data    = map_data,
    labels      = flex.std_string([""]))
  
def _sigma_scaled_map(xrs, d_min, resolution_factor):
  """Return the sigma-scaled, unpadded real-space Fcalc map of xrs.

  Parameters:
    xrs: xray structure whose calculated structure factors are used.
    d_min: high-resolution limit passed to structure_factors().
    resolution_factor: grid resolution factor passed to fft_map().
  """
  fft_map = xrs.structure_factors(d_min=d_min).f_calc().fft_map(
    resolution_factor = resolution_factor)
  fft_map.apply_sigma_scaling()
  return fft_map.real_map_unpadded()

def exercise(args, d_min=1.5, resolution_factor=1./4):
  """Superpose model B onto model A and write the maps and moved model.

  Parameters:
    args: sequence whose first two items are the file names of model A and
      model B (read with iotbx.pdb.input); any further items are ignored.
    d_min: high-resolution limit for the calculated maps (default 1.5).
    resolution_factor: FFT grid resolution factor (default 1/4).

  Files written to the current directory:
    A.ccp4    -- sigma-scaled Fcalc map of model A.
    B.ccp4    -- sigma-scaled Fcalc map of model B.
    moved.pdb -- model B with its coordinates transformed by the
                 least-squares fit of B's sites onto A's sites.
    AB.ccp4   -- map A resampled by superpose_maps onto map B's grid using
                 the fit's rotation and translation, written with model B's
                 crystal symmetry.
  """
  # Load first model, generate its map, and write map out
  xrsA = iotbx.pdb.input(file_name = args[0]).xray_structure_simple()
  map_data_A = _sigma_scaled_map(xrsA, d_min, resolution_factor)
  ccp4_map(cg=xrsA.crystal_symmetry(), file_name="A.ccp4", map_data=map_data_A)
  # Load second model, generate its map, and write map out. The pdb input
  # object is kept because its hierarchy is written out as moved.pdb below.
  pdb_inpB = iotbx.pdb.input(file_name = args[1])
  xrsB = pdb_inpB.xray_structure_simple()
  map_data_B = _sigma_scaled_map(xrsB, d_min, resolution_factor)
  ccp4_map(cg=xrsB.crystal_symmetry(), file_name="B.ccp4", map_data=map_data_B)
  # Find the operator superposing B's sites onto A's sites.
  # NOTE(review): a pairwise site fit presumably requires both models to
  # contain the same atoms in the same order -- TODO confirm / validate.
  fit = superpose.least_squares_fit(
    reference_sites = xrsA.sites_cart(),
    other_sites     = xrsB.sites_cart())
  # Move model B with the fit operator and output it
  sites_cart_in_b_superposed_on_a = fit.rt() * xrsB.sites_cart()
  hb = pdb_inpB.construct_hierarchy()
  hb.atoms().set_xyz(sites_cart_in_b_superposed_on_a)
  hb.write_pdb_file(file_name="moved.pdb")
  # Resample map A onto map B's grid using the fit operator and output it
  target_map = superpose_maps(
      unit_cell_1        = xrsA.unit_cell(),
      unit_cell_2        = xrsB.unit_cell(),
      map_data_1         = map_data_A,
      n_real_2           = map_data_B.focus(),
      rotation_matrix    = fit.r.elems,
      translation_vector = fit.t.elems,
      wrapping=True)
  ccp4_map(cg=xrsB.crystal_symmetry(), file_name="AB.ccp4", map_data=target_map)
  
if (__name__ == "__main__"):
  # exercise() reads two model files from the arguments; fail with a usage
  # message instead of an IndexError traceback when they are missing.
  if (len(sys.argv) < 3):
    sys.exit("usage: %s model_A.pdb model_B.pdb"
      % os.path.basename(sys.argv[0]))
  exercise(args=sys.argv[1:])
