#! /usr/bin/env python

import os
import sys
import glob
import shutil

# The exact statement this script makes sure every visited .py file contains.
FIX = "from __future__ import division"

def needs_fix(p, fix=None):
  """Return True if the file at *p* does not yet contain the import *fix*.

  *fix* defaults to the module-level FIX.  Differences in whitespace and
  a trailing ``#`` comment are ignored, and a combined import such as
  ``from __future__ import absolute_import, division`` (optionally in
  parentheses) counts as already containing the wanted name.  Every line
  of the file is scanned.
  """
  if fix is None:
    fix = FIX
  wanted = " ".join(fix.split())
  # Split e.g. "from __future__ import division" into the statement head
  # ("from __future__ import ") and the single imported name ("division").
  module, sep, name = wanted.rpartition(" import ")
  head = module + sep if module else None
  with open(p) as f:
    for aline in f:
      # Drop a trailing comment, then collapse runs of whitespace.
      spaced = " ".join(aline.split("#", 1)[0].split())
      if spaced == wanted:
        return False
      if head and spaced.startswith(head):
        # Same statement head importing several names at once.
        names = spaced[len(head):].strip("()")
        if name in [n.strip() for n in names.split(",")]:
          return False
  return True


def fixit(p, fix=None):
  """Insert the import statement *fix* into the Python file at *p*.

  *fix* defaults to the module-level FIX.  The import is placed just
  before the first line of real code, i.e. after any leading blank lines,
  comments and string literals (such as the module docstring).  Leading
  comments stay above it so that a shebang and a PEP 263 coding
  declaration remain on line 1 or 2, where the interpreter looks for
  them.  If the file contains no code at all, the import is appended at
  the end.  Before the file is rewritten, the original is copied to
  ``p + ".unfixed"``.
  """
  if fix is None:
    fix = FIX
  # A parenthesised single string prints the same on Python 2 and 3.
  print("fixing: %s" % (p,))
  new = []
  inserted = False
  with open(p) as f:
    lines = iter(f)
    for aline in lines:
      strp = aline.strip()
      if not strp or strp.startswith("#"):
        # Blank lines and comments (shebang, coding line, license header)
        # stay above the import.
        new.append(aline)
        continue
      quote = strp[:3]
      if quote in ('"""', "'''"):
        new.append(aline)
        # A triple-quoted string that does not close on its opening line
        # runs until a line ending with the same quote.
        if len(strp) < 6 or not strp.endswith(quote):
          for aline in lines:
            new.append(aline)
            if aline.strip().endswith(quote):
              break
      elif strp[0] in "\"'" and strp.endswith(strp[0]):
        # A one-line string literal, e.g. a single-quoted docstring.
        new.append(aline)
      else:
        # First real statement: put the import in front of it and copy
        # the remainder of the file unchanged.
        new.extend(["\n", fix, "\n\n", aline])
        new.extend(lines)
        inserted = True
        break
  if not inserted:
    # No code found (empty, comment- or docstring-only file): append the
    # import so the file actually gains it instead of being rewritten as-is.
    if new and not new[-1].endswith("\n"):
      new.append("\n")
    new.extend([fix, "\n"])
  shutil.copy(p, p + ".unfixed")
  with open(p, "w") as f:
    f.write("".join(new))

def main():
  """Run needs_fix/fixit on every ``*.py`` file below a directory tree.

  The root directory is the first command-line argument, or the current
  directory when none is given.  Files for which needs_fix() returns
  False are left untouched.
  """
  top = sys.argv[1] if len(sys.argv) > 1 else "."

  for pth, _dirs, files in os.walk(top):
    # Filter the file names os.walk already returned instead of globbing:
    # glob would treat "[", "*" or "?" in a directory name as wildcards
    # (silently skipping files) and would also return sub-directories
    # named "*.py", which open() cannot read.  Dot-files are skipped, as
    # glob's "*" never matched them either.
    for name in files:
      if name.endswith(".py") and not name.startswith("."):
        p = os.path.join(pth, name)
        if needs_fix(p):
          fixit(p)

if __name__ == "__main__":
  main()
