# This is part of Qbics for fitting xtb parameters.
# Author: Jun Zhang

import sys,os,copy

class xtbFit:
    """Fit GFN-xTB parameters of selected elements to reference data with Qbics.

    Workflow (see fitParams): read reference geometries, energies and
    gradients; collect the numeric entries of the chosen elements' blocks in a
    parameter file; then repeatedly write a modified parameter file
    (``temp.txt``), run Qbics on ``temp.inp`` for each reference geometry, and
    update the parameters by gradient descent using forward finite
    differences.

    All scratch files live in the current working directory.  ``temp.inp`` is
    NOT written by this class; it must exist before fitting starts (the
    WriteTemp helper of this script creates it).
    """

    # Location of the Qbics executable used when none is given explicitly.
    DEFAULT_QBICS = "/disk/zhangjun/soft-source-env/qbics-source/bin/qbics-linux-cpu"

    def __init__(self, qbics=None):
        """Store the path of the Qbics executable.

        Args:
            qbics: Path to the Qbics binary; ``None`` selects DEFAULT_QBICS.
        """
        self.Qbics = self.DEFAULT_QBICS if qbics is None else qbics

    def readParams(self, fn):
        """Return the lines of parameter file ``fn`` (newlines kept)."""
        with open(fn, 'r') as f:
            params_lines = f.readlines()
        return params_lines

    def writeParams(self, params_lines, fn):
        """Write ``params_lines`` verbatim to file ``fn``."""
        with open(fn, 'w') as f:
            f.writelines(params_lines)

    def updateParams(self, params_lines, variables):
        """Return a copy of ``params_lines`` with the given values substituted.

        Args:
            params_lines: Lines of the parameter file (left unmodified).
            variables: ``[[line_index, word_index], value]`` entries, as
                produced by getOptimizableParams.

        Returns:
            A new list of lines.  Each line that receives a value is re-split
            on whitespace and re-joined with single spaces; the value is
            formatted as ``%.6f``.
        """
        # Strings are immutable, so a shallow copy is enough to leave the
        # caller's list untouched.
        new_params_lines = list(params_lines)
        for (line_idx, word_idx), val in variables:
            words = new_params_lines[line_idx].split()
            words[word_idx] = "%.6f" % (val)
            new_params_lines[line_idx] = " ".join(words) + "\n"
        return new_params_lines

    def getOptimizableParams(self, params_lines, Zs):
        """Collect the numeric parameters of the elements ``Zs``.

        For each atomic number ``Z``, the block starting at the first line
        beginning with ``"$Z=%2d" % Z`` and ending at the next ``$end`` line is
        located.  On every line from the second line after the header up to
        (not including) ``$end``, each whitespace-separated field except the
        first becomes one variable.  A warning is printed when no complete
        block is found for ``Z``.

        Returns:
            List of ``[[line_index, word_index], value]`` entries.
        """
        variables = []
        for Z in Zs:
            header = "$Z=%2d" % (Z)
            start_idx = -1
            end_idx = 0
            for i, line in enumerate(params_lines):
                if line.startswith(header):
                    start_idx = i
                if line.startswith("$end") and start_idx >= 0:
                    end_idx = i
                    break
            if start_idx < 0 or end_idx == 0:
                print("Warning: no complete parameter block found for Z=%d." % (Z))
                continue
            # NOTE(review): the line right after the header and the first field
            # of each line are skipped -- presumably labels / non-optimizable
            # data; confirm against the parameter-file format.
            for i in range(start_idx+2, end_idx):
                words = params_lines[i].split()
                for j in range(1, len(words)):
                    variables.append([[i, j], float(words[j])])
        return variables

    def buildReferencePoints(self, traj_fn, grad_fn):
        """Read reference geometries, energies and gradients.

        Both files are concatenated XYZ-style frames with the same line count.
        The atom count is taken from the first line of the trajectory and is
        assumed identical for every frame.  Each frame's energy is the 10th
        whitespace field of its second (comment) line.  Each gradient row is
        the list of floats following the first field of a gradient atom line.
        Exits the program with status 1 if the files are inconsistent.

        Returns:
            List of ``[coords, energy, grads]`` where ``coords`` is the list of
            raw trajectory lines of the frame.
        """
        with open(traj_fn, 'r') as f:
            traj_lines = f.readlines()
        with open(grad_fn, 'r') as f:
            grad_lines = f.readlines()
        if len(traj_lines) != len(grad_lines):
            print("Error: The numbers of lines for trajectory (%d) and gradient (%d) files are different." % (len(traj_lines), len(grad_lines)))
            sys.exit(1)
        num_atoms = int(traj_lines[0])
        frame_len = num_atoms + 2
        num_refs = len(traj_lines)/frame_len
        if num_refs != int(num_refs):
            print("Error: Trajectory file has incorrect number of reference data: %.5f" % (num_refs))
            sys.exit(1)
        ref_points = []
        for i in range(int(num_refs)):
            first, last = i*frame_len, (i+1)*frame_len
            coords = traj_lines[first:last]
            energy = float(coords[1].split()[9])
            grads = [[float(a) for a in line.split()[1:]] for line in grad_lines[first+2:last]]
            ref_points.append([coords, energy, grads])
        return ref_points

    def calcEnergyAndGrad(self, coords):
        """Run Qbics on one geometry and return its energy and gradient.

        Args:
            coords: Lines of one XYZ frame, written verbatim to ``temp.xyz``.

        Returns:
            ``[energy, grads]`` read from ``temp-grad.xyz``: the energy is the
            third field of its second line; ``grads`` holds, for every later
            line, the floats following the first field.

        Raises:
            FileNotFoundError: if the run did not produce ``temp-grad.xyz``.
        """
        cmd = self.Qbics+" temp.inp > temp.out"
        with open("temp.xyz", 'w') as f:
            f.writelines(coords)
        # Delete any output from a previous geometry so that a failed run
        # raises instead of silently returning stale energies and gradients.
        if os.path.exists("temp-grad.xyz"):
            os.remove("temp-grad.xyz")
        # NOTE(review): the exit status of the shell command is not checked.
        os.system(cmd)
        with open("temp-grad.xyz", 'r') as f:
            grad_lines = f.readlines()
        energy = float(grad_lines[1].split()[2])
        grads = [[float(a) for a in line.split()[1:]] for line in grad_lines[2:]]
        return [energy, grads]

    def calcLoss(self, ref_points, params_lines, variables, grad_weight, i):
        """Evaluate the fitting loss for one set of parameter values.

        Args:
            ref_points: ``[coords, energy, grads]`` entries (buildReferencePoints).
            params_lines: Lines of the base parameter file.
            variables: ``[[line, word], value]`` entries to substitute.
            grad_weight: Weight of the gradient error in the total loss.
            i: Iteration index; when ``i >= 0`` the parameters used are also
                saved as ``params-<i>.txt``.  Pass a negative value to skip.

        Returns:
            ``(energy_error, grad_error, total_loss)``.  ``energy_error`` is
            the sum of squared differences of mean-centred energies, so a
            constant offset between reference and prediction is not
            penalized.  ``grad_error`` is the sum of squared gradient
            component differences.  ``total_loss = energy_error +
            grad_weight * grad_error``.

        Raises:
            ValueError: if ``ref_points`` is empty.
        """
        if not ref_points:
            raise ValueError("calcLoss needs at least one reference point.")
        new_params_lines = self.updateParams(params_lines, variables)
        # temp.inp points Qbics at this file via "vparam temp.txt".
        new_params_fn = "temp.txt"
        self.writeParams(new_params_lines, new_params_fn)
        if i >= 0:
            self.writeParams(new_params_lines, "params-%i.txt" % (i))
        pred_points = [self.calcEnergyAndGrad(ref_point[0]) for ref_point in ref_points]
        # Energy error on mean-centred energies.
        n = len(ref_points)
        ref_energies = [ref_point[1] for ref_point in ref_points]
        pred_energies = [pred_point[0] for pred_point in pred_points]
        mean_ref_energy = sum(ref_energies) / n
        mean_pred_energy = sum(pred_energies) / n
        energy_error = sum(((ref_e - mean_ref_energy) - (pred_e - mean_pred_energy))**2
                           for ref_e, pred_e in zip(ref_energies, pred_energies))
        # Gradient error.  Iterate over the reference atoms and index the
        # prediction, so that extra trailing rows in the output are ignored.
        grad_error = 0.
        for ref_point, pred_point in zip(ref_points, pred_points):
            pred_grads = pred_point[1]
            for k, ref_grad in enumerate(ref_point[2]):
                pred_grad = pred_grads[k]
                grad_error += ((ref_grad[0] - pred_grad[0])**2 + (ref_grad[1] - pred_grad[1])**2 + (ref_grad[2] - pred_grad[2])**2)
        total_loss = energy_error + grad_weight * grad_error
        return energy_error, grad_error, total_loss

    def fitParams(self, params_fn, traj_fn, grad_fn, Zs, maxIters, grad_weight, conv_threshold,
                  learning_rate=1.5, dx=0.001):
        """Optimize the parameters of elements ``Zs`` by gradient descent.

        Args:
            params_fn: Initial parameter file.
            traj_fn, grad_fn: Reference trajectory and gradient files.
            Zs: Atomic numbers whose parameter blocks are optimized.
            maxIters: Maximum number of iterations.
            grad_weight: Weight of the gradient error in the total loss.
            conv_threshold: Stop once the total loss (the printed value is its
                square root) drops below this threshold.
            learning_rate: Step size of the gradient-descent update.
            dx: Forward finite-difference step used for dL/dP.

        Writes ``params-<i>.txt`` every iteration, prints progress and removes
        the scratch files at the end.
        """
        ref_points = self.buildReferencePoints(traj_fn, grad_fn)
        print("Number of reference data points:", len(ref_points))
        params_lines = self.readParams(params_fn)
        variables = self.getOptimizableParams(params_lines, Zs)
        print("Number of optimizable parameters:", len(variables))
        for i in range(maxIters):
            print("Iteration %d:" % (i))
            print(" Calculating loss ...")
            # Growing subset (10, 40, 90, ... points, capped at the full set),
            # so the early iterations stay cheap.
            nref = min(10*(i+1)**2, len(ref_points))
            print(" Use: %d points" % (nref))
            subset = ref_points[:nref]
            energy_loss, grad_loss, total_loss = self.calcLoss(subset, params_lines, variables, grad_weight, i)
            print(" Parameters saved to: params-%d.txt" % (i))
            print(" Energy loss:           %.8f Hartree" % (energy_loss**0.5))
            # RMS energy error per point actually evaluated (nref, not the full
            # set); 627.510 converts Hartree to kcal/mol.
            print(" Energy loss (per mol): %.8f kcal/mol" % ((energy_loss/nref)**0.5*627.510))
            print(" Gradient loss:         %.8f Hartree/Bohr" % (grad_loss**0.5))
            print(" Total loss:            %.8f " % total_loss**0.5, flush = True)
            if total_loss < conv_threshold:
                print("Converged.")
                break
            print(" Learning rate: %.8f" % (learning_rate))
            print(" Predicting variable movement:")
            print("  %3s: %15s %15s" % ("#", "P", "dL/dP"))
            # Forward finite difference dL/dP_j, one extra loss evaluation per
            # parameter, on the same subset as total_loss.
            dp = []
            for j in range(len(variables)):
                shifted = copy.deepcopy(variables)
                shifted[j][1] += dx
                _, _, shifted_loss = self.calcLoss(subset, params_lines, shifted, grad_weight, -1)
                dp.append((shifted_loss - total_loss) / dx)
                print("  %3d: %15.8f %15.8f" % (j, variables[j][1], dp[j]), flush = True)
            # NOTE(review): the update made in the final iteration is never
            # evaluated or saved; params-<i>.txt holds the values before it.
            for var, grad in zip(variables, dp):
                var[1] -= learning_rate * grad
        # Remove scratch files; missing ones (e.g. no Qbics run) are skipped.
        for fn in ("temp.xyz", "temp-grad.xyz", "temp.out", "temp.txt"):
            if os.path.exists(fn):
                os.remove(fn)

# Do NOT change the above class (xtbFit).

def WriteTemp():
    """Create the Qbics input file ``temp.inp`` in the working directory.

    The input requests an xtb gradient (``grad xtb``) for the geometry in
    ``temp.xyz``, using the parameter file ``temp.txt`` (``vparam``) with the
    fixed settings chrg 0, uhf 0 and gfn 2.  Any existing ``temp.inp`` is
    overwritten.
    """
    template = """xtb
    chrg   0
    uhf    0
    gfn    2
    vparam temp.txt
end
mol
    temp.xyz
end
task
    grad xtb
end"""
    with open("temp.inp", 'w') as inp_file:
        inp_file.write(template)

if __name__ == "__main__":
    # Input files: initial parameters, reference trajectory, reference gradients.
    params_fn, traj_fn, grad_fn = "params-init.txt", "traj.xyz", "grad.xyz"
    # Atomic numbers whose parameter blocks are fitted (53 = iodine).
    Zs = [53]
    # Optimizer settings: iteration cap, gradient-error weight, loss threshold.
    maxIters, grad_weight, conv_threshold = 4, 0.1, 1E-3
    # The Qbics input must exist before any loss evaluation.
    WriteTemp()
    fitter = xtbFit()
    fitter.fitParams(params_fn, traj_fn, grad_fn, Zs, maxIters, grad_weight, conv_threshold)
    print("Done.")
