import matplotlib.pyplot as plt

# Number of MD steps per barostat. 5000 keeps the example quick; raise it
# (e.g. to 500000) to converge the histograms.
STEPS = 5000

# One panel each for the pressure, volume and temperature distributions,
# addressed by quantity name rather than by position.
fig, axes_ = plt.subplots(nrows=1, ncols=3)
axes = dict(zip(('pressure', 'volume', 'temperature'), axes_))

# x-axis label (with units) for each panel; every panel plots a probability.
x_axis_labels = {
    'pressure': "Pressure / GPa",
    'volume': r"Volume / nm$^3$",
    'temperature': "Temperature / Kelvin",
}
for quantity, x_label in x_axis_labels.items():
    axes[quantity].set_xlabel(x_label)
    axes[quantity].set_ylabel("Probability")

# Silicon in the diamond structure: an FCC lattice (a = 5.4306 Angstrom) with
# a two-atom basis at fractional (0, 0, 0) and (1/4, 1/4, 1/4).
lattice = FaceCenteredCubic(5.4306*Angstrom)
bulk_configuration = BulkConfiguration(
    bravais_lattice=lattice,
    elements=[Silicon] * 2,
    fractional_coordinates=[[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]],
)

# Classical Tersoff silicon potential (parameter set Tersoff_Si_1988b),
# evaluated by the TremoloX engine.
calculator = TremoloXCalculator(parameters=Tersoff_Si_1988b())
bulk_configuration.setCalculator(calculator)

# Run one MD simulation per barostat and overlay the resulting pressure,
# volume and temperature histograms.
#
# Colors come from the public property-cycle rcParam. The private
# ``Axes._get_lines.prop_cycler`` attribute is not part of matplotlib's API
# and no longer exists from matplotlib 3.8 on.
barostat_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Histogram settings shared by all three panels; density=True normalises each
# histogram to a probability density.
histogram_options = dict(
    bins='auto', histtype='step', alpha=0.5, stacked=True, density=True,
)

for index, barostat_class in enumerate(
        (NPTBerendsen, NPTBernettiBussi, NPTMartynaTobiasKlein)):
    # Set up MD method
    method = barostat_class(
        time_step=1*femtoSecond,
        initial_velocity=MaxwellBoltzmannDistribution(temperature=300*Kelvin),
        # thermostat settings
        reservoir_temperature=300*Kelvin,
        thermostat_timescale=100*femtoSecond,
        heating_rate=0*Kelvin/picoSecond,
        # barostat settings
        reservoir_pressure=1.0*bar,
        barostat_timescale=1000.0*femtoSecond,
    )

    # Run long MD simulation.
    md_trajectory = MolecularDynamics(
        bulk_configuration,
        steps=STEPS,
        log_interval=50,
        method=method
    )

    # One color and legend label per barostat, shared by all three panels.
    color = barostat_colors[index % len(barostat_colors)]
    label = barostat_class.__name__

    # Uncomment to save each trajectory to '<barostat name>.hdf5'.
    #nlsave(label+'.hdf5', md_trajectory, object_id='NPT')

    # Histograms of pressure, temperature and volume along the trajectory.
    pressures = md_trajectory.pressures()
    axes['pressure'].hist(pressures.inUnitsOf(GPa),
                          color=color, label=label, **histogram_options)

    # After the loop, `temperatures` holds the last barostat's samples.
    temperatures = md_trajectory.temperatures()
    axes['temperature'].hist(temperatures.inUnitsOf(Kelvin),
                             color=color, label=label, **histogram_options)

    volumes = md_trajectory.volumes()
    axes['volume'].hist(volumes.inUnitsOf(nanoMeter**3),
                        color=color, label=label, **histogram_options)

# Analytic reference for the temperature histogram: in the canonical ensemble
# the instantaneous temperature has a known distribution, fixed entirely by
# the reservoir temperature and the number of degrees of freedom.

# Three Cartesian degrees of freedom per atom.
# NOTE(review): if the MD removes centre-of-mass motion, the effective count
# may be 3*N - 3; confirm against how MolecularDynamics computes temperature.
Nf = bulk_configuration.numberOfAtoms() * 3

def logGamma(x):
    """Return log(Gamma(x)), evaluated exactly via ``math.lgamma``.

    The logarithm stays finite where Gamma(x) itself would overflow.
    ``math.lgamma`` is exact for every x. The Stirling expansion
    (x - 1/2) log x - x + log(2 pi)/2 is only accurate for large x: at x = 3
    (a two-atom cell) it is off by about 0.028, which would skew the
    normalisation of the canonical distribution by about 3%.

    Parameters
    ----------
    x : float or array_like
        Positive argument(s).

    Returns
    -------
    float for scalar input, otherwise a numpy.ndarray of floats with the
    same shape as ``x``.
    """
    import math

    values = numpy.asarray(x, dtype=float)
    if values.ndim == 0:
        return math.lgamma(float(values))
    # math.lgamma is scalar-only; apply it elementwise for array input.
    return numpy.vectorize(math.lgamma, otypes=[float])(values)


# Compare the temperature histogram with the analytic canonical distribution.
# `temperatures` holds the samples from the last barostat run above.
reservoir_temperature = 300.0 * Kelvin

# Work in plain floats in Kelvin, the same units as the temperature histogram.
# The density below is then per Kelvin, directly comparable to density=True.
# numpy.sort returns a sorted copy, so the trajectory data is left untouched.
temperature_values = numpy.sort(temperatures.inUnitsOf(Kelvin))

# ratio of current temperature to set temperature
r = temperature_values / reservoir_temperature.inUnitsOf(Kelvin)

# With a = Nf/2 the density is
#     P(T) = (a r)^a exp(-a r) / (Gamma(a) T),
# i.e. a Gamma distribution with shape a and mean T_0.
# The exponent is assembled in log space so that the large terms a*log(a)
# and log(Gamma(a)) cancel before exponentiating, which avoids overflow.
half_nf = Nf / 2.0
probability_canonical = numpy.exp(
    -half_nf * (r - numpy.log(r) - numpy.log(half_nf)) - logGamma(half_nf)
) / temperature_values

# Plot canonical temperature distribution in black
axes['temperature'].plot(
    temperature_values, probability_canonical,
    color='black', lw=2, label='Canonical'
)

axes['temperature'].legend()
plt.tight_layout()

plt.show()
