
import pylab
import scipy
import scipy.integrate

# Green-Kubo estimate of the shear viscosity from a stored MD run.
#
# The script reads the sampled pressure tensor (plus volume, temperature and
# time information) from an HDF5 file, builds the autocorrelation function
# (ACF) of the shear components of the pressure tensor, and integrates it:
#
#     viscosity(t) = V / (kB * T) * integral_0^t <P_shear(0) P_shear(t')> dt'
#
# NOTE(review): nlread, the unit symbols (Pa, GPa, ps, Second, millisecond),
# boltzmann_constant and numpy are not imported here; they are assumed to be
# provided by the environment that runs this script.

# Each quantity is the last object stored under the given id in the file.
data_file = 'md_methanol_data.hdf5'
time = nlread(data_file, object_id="Time")[-1]
pressure_tensor = nlread(data_file, object_id="Pressure")[-1]
volume_avg = nlread(data_file, object_id="Volume_Average")[-1]
temperature = nlread(data_file, object_id="Temperature")[-1]
time_step = nlread(data_file, object_id="Time_Step")[-1]

# Sub-sample the data by the given factor:
# 1 uses all of the gathered data, 10 uses every 10th sample.
# This can be used to speed up the calculation of the autocorrelation function.
factor = 1
pressure_tensor = pressure_tensor[::factor, :, :]
time = time[::factor]

# Collect the five shear components of the pressure tensor in one array:
# the three off-diagonal elements and two half-differences of diagonal elements.
N = pressure_tensor.shape[0]
P_shear = numpy.zeros((5, N), dtype=float) * Pa
P_shear[0] = pressure_tensor[:, 0, 1]
P_shear[1] = pressure_tensor[:, 0, 2]
P_shear[2] = pressure_tensor[:, 1, 2]
P_shear[3] = (pressure_tensor[:, 0, 0] - pressure_tensor[:, 1, 1]) / 2
P_shear[4] = (pressure_tensor[:, 1, 1] - pressure_tensor[:, 2, 2]) / 2

# Calculate the 5 ACFs together, then average them.
# Only lags up to N/2 are evaluated, because beyond that there are not
# enough time windows left for the average to be accurate.
size = N // 2
ACF = numpy.zeros((5, size), dtype=float) * Pa**2
for t in range(size):
    # axis=1 averages over the time origins only, so that every row of ACF
    # holds the autocorrelation of its own shear component.
    ACF[:, t] = numpy.mean(P_shear[:, :N - t] * P_shear[:, t:], axis=1)
ACF_avg = ACF.mean(axis=0)

# Plotting the ACF
pylab.figure()
pylab.plot(time[:ACF_avg.shape[0]].inUnitsOf(ps), ACF_avg.inUnitsOf(GPa**2), label='ACF')
pylab.xlabel('Time (ps)')
pylab.ylabel(r'ACF (GPa$^2$)')
pylab.legend()
pylab.show()

# Integrate the ACF to get the viscosity.
# The units have to be put back in afterwards because the scipy function
# works on plain numbers.
kbT = boltzmann_constant * temperature
# Scale into a separate variable so that ACF_avg keeps its meaning (and unit).
integrand = ACF_avg * (volume_avg / kbT)

# cumtrapz was renamed to cumulative_trapezoid and later removed from SciPy;
# prefer the new name and fall back to the old one on older installations.
cumulative_trapezoid = getattr(scipy.integrate, 'cumulative_trapezoid', None)
if cumulative_trapezoid is None:
    cumulative_trapezoid = scipy.integrate.cumtrapz

# After sub-sampling, consecutive samples are `factor` time steps apart.
# NOTE(review): this assumes time_step is the spacing between stored samples
# before sub-sampling - confirm against how the data file was written.
sample_spacing = time_step * factor

# initial=0 keeps the result the same length as the integrand, so that entry i
# is the integral from lag 0 up to lag i and lines up with time[i] in the plot.
gk_raw = cumulative_trapezoid(
    y=integrand.inUnitsOf(Pa),
    dx=sample_spacing.inUnitsOf(Second),
    initial=0,
)
viscosity = gk_raw * Pa * Second

# Plotting the time evolution of the viscosity estimate
# (1 Pa*millisecond equals 1 centipoise, hence the axis label).
pylab.figure()
pylab.plot(time[:viscosity.shape[0]].inUnitsOf(ps), viscosity.inUnitsOf(Pa*millisecond), label='Viscosity')
pylab.xlabel('Time (ps)')
pylab.ylabel('Viscosity (cP)')
pylab.legend()
pylab.show()


