
import pylab
import scipy
import scipy.integrate

# Load the quantities saved by the MD simulation from its HDF5 output.
# Each one is looked up by object id, and the last entry of what nlread
# returns for that id is kept.
# NOTE(review): the "_8" suffix presumably identifies one particular block of
# the simulation workflow - confirm against the script that wrote the file.
data_file = 'md_methanol_data.hdf5'
time, pressure_tensor, volume_avg, temperature, time_step = (
    nlread(data_file, object_id=object_id)[-1]
    for object_id in (
        "Time_8",
        "Pressure_8",
        "Volume_Average_8",
        "Temperature_8",
        "Time_Step_8",
    )
)

# Set up the calculation to create 100 time based estimates of the viscosity.
# N is the number of stored pressure samples; N_steps counts the sample
# points including the one at index 0, i.e. N_steps - 1 = 100 estimates.
N = pressure_tensor.shape[0]
N_steps = 101

# Stride (in stored samples) between two consecutive estimates.
# Dividing N - 1 instead of N keeps the last sample index,
# (N_steps - 1) * skip, inside the data even when N is an exact multiple of
# 100; for every other N the result equals the former int(N / 100).
skip = (N - 1) // (N_steps - 1)
if skip < 1:
    # A zero stride would otherwise surface as an obscure
    # "slice step cannot be zero" error further down.
    raise ValueError(
        "At least {} pressure samples are required to form {} viscosity "
        "estimates, got {}".format(N_steps, N_steps - 1, N)
    )

# One time stamp per sample point. The slice is capped at N_steps entries so
# that it always lines up with the N_steps accumulated pressure integrals.
# NOTE(review): assumes `time` holds one entry per pressure sample - confirm.
time_skip = time[::skip][:N_steps]

# Collect the five pressure-tensor combinations used for the viscosity:
# the three off-diagonal elements (0,1), (0,2), (1,2) and half the difference
# between consecutive diagonal elements, (0,0)-(1,1) and (1,1)-(2,2).
# Each entry is a series over the first axis of pressure_tensor.
# NOTE(review): indices 0, 1, 2 presumably correspond to x, y, z - confirm.
shear_components = (
    pressure_tensor[:, 0, 1],
    pressure_tensor[:, 0, 2],
    pressure_tensor[:, 1, 2],
    (pressure_tensor[:, 0, 0] - pressure_tensor[:, 1, 1]) / 2,
    (pressure_tensor[:, 1, 1] - pressure_tensor[:, 2, 2]) / 2,
)

# One row per component, N samples per row, carried in units of J/m^3.
P_shear = numpy.zeros((len(shear_components), N), dtype=float) * Joule / Meter**3
for row, component in enumerate(shear_components):
    P_shear[row] = component

# At increasing time lengths, calculate the viscosity based on that part of the simulation.
# pressure_integral[t] holds the mean, over the rows of P_shear, of the squared
# time integral of that row taken over the first t*skip samples.
# Entry 0 is left at zero (no samples integrated).
# The builtin float replaces the numpy.float alias, which newer NumPy
# releases no longer provide.
pressure_integral = numpy.zeros(N_steps, dtype=float) * (Second * Joule / Meter**3)**2

# scipy.integrate.trapz was renamed to trapezoid and the old name was later
# removed; use whichever one the installed SciPy offers.
_trapezoid = getattr(scipy.integrate, 'trapezoid', None) or scipy.integrate.trapz

# The unit conversions do not depend on the loop variables, so do them once
# instead of once per (t, component) pair.
shear_values = P_shear.inUnitsOf(Joule/Meter**3)
dt_seconds = time_step.inUnitsOf(Second)
n_components = len(shear_values)

for t in range(1, N_steps):
    total_step = t*skip

    for component in shear_values:
        # Trapezoidal rule over the leading total_step samples, which are
        # taken to be uniformly spaced by time_step.
        integral = _trapezoid(y=component[:total_step], dx=dt_seconds)
        # Re-attach the unit that was stripped for the numerical integration.
        integral *= Second * Joule / Meter**3
        pressure_integral[t] += integral**2 / n_components

# Turn the accumulated pressure integrals into viscosity estimates:
#     viscosity = pressure_integral * volume / (2 * kB * T * elapsed time)
# Index 0 is dropped from both arrays; the first time stamp would otherwise
# put a zero in the denominator.
kbT = boltzmann_constant * temperature
elapsed_time = time_skip[1:]
viscosity = pressure_integral[1:] * volume_avg / (2*kbT*elapsed_time)

# Report the estimate based on the longest stretch of the simulation.
# millisecond*Pa is mPa*s, which is numerically the same as centipoise (cP).
final_viscosity = viscosity[-1].inUnitsOf(millisecond*Pa)
print("Viscosity is {} cP".format(final_viscosity))

# Plot how the viscosity estimate develops as more of the simulation is used.
# Convert to plain numbers first: time in ps, viscosity in mPa*s (= cP).
elapsed_ps = time_skip[1:].inUnitsOf(ps)
viscosity_cP = viscosity.inUnitsOf(millisecond*Pa)

pylab.figure()
pylab.plot(elapsed_ps, viscosity_cP, label='Viscosity')
pylab.xlabel('Time (ps)')
pylab.ylabel('Viscosity (cP)')
pylab.legend()
pylab.show()
