# -*- coding: utf-8 -*-
# -------------------------------------------------------------
# Bulk Configuration
# -------------------------------------------------------------

# Set up lattice: the cell is cubic, so a single edge length (in Angstrom)
# defines all three axis-aligned lattice vectors.
cell_edge = 10.559292205578911
vector_a = [cell_edge, 0.0, 0.0]*Angstrom
vector_b = [0.0, cell_edge, 0.0]*Angstrom
vector_c = [0.0, 0.0, cell_edge]*Angstrom
lattice = UnitCell(vector_a, vector_b, vector_c)

# Define elements: 32 Hf atoms followed by 64 O atoms (96 in total, i.e. a
# 1:2 Hf:O ratio), one entry per row of fractional_coordinates.
elements = [Hafnium]*32 + [Oxygen]*64

# Define coordinates
# One [a, b, c] row per atom, 96 rows in total, given as fractions of the
# lattice vectors. Rows are listed in the same order as `elements`: the first
# 32 rows belong to the Hafnium entries, the remaining 64 to the Oxygen entries.
fractional_coordinates = [[ 0.24941383618 ,  0.588908558001,  0.721482382728],
                          [ 0.902855608094,  0.719311096057,  0.503324733774],
                          [ 0.667773850045,  0.680021940826,  0.315112130795],
                          [ 0.366051408441,  0.397870673559,  0.419761956664],
                          [ 0.402185685125,  0.937865521348,  0.738330840107],
                          [ 0.721933221283,  0.651477971505,  0.011405373011],
                          [ 0.802870790841,  0.390726574845,  0.139927625652],
                          [ 0.320785363317,  0.164137591684,  0.033560443696],
                          [ 0.375926436776,  0.719763134578,  0.333932254362],
                          [ 0.00348302074 ,  0.086709693209,  0.911120104968],
                          [ 0.20771836228 ,  0.83015833482 ,  0.991715244064],
                          [ 0.362695814842,  0.096344858485,  0.431126156569],
                          [ 0.476717433039,  0.876712805957,  0.068670223918],
                          [ 0.351358665305,  0.519692883241,  0.094315300364],
                          [ 0.525217288218,  0.437052464567,  0.879633151035],
                          [ 0.705196705606,  0.855428643414,  0.791629278211],
                          [ 0.878184581399,  0.927916545828,  0.217583615261],
                          [ 0.23207171114 ,  0.217455358051,  0.735154359269],
                          [ 0.671706736126,  0.117231777413,  0.963643600579],
                          [ 0.817282805511,  0.36340170513 ,  0.804152093636],
                          [ 0.594174398458,  0.620649184791,  0.636499414818],
                          [ 0.903516534077,  0.096908590632,  0.561614408785],
                          [ 0.185317578193,  0.882930422275,  0.509346537725],
                          [ 0.762405231515,  0.38849340748 ,  0.492784744308],
                          [ 0.544162451925,  0.25195397921 ,  0.653169201868],
                          [ 0.608950103883,  0.93025028558 ,  0.50702464697 ],
                          [ 0.066129448571,  0.200104512632,  0.304150389092],
                          [ 0.608394526732,  0.176018727799,  0.23702966881 ],
                          [ 0.066496464619,  0.442434451326,  0.536037242962],
                          [ 0.05433759723 ,  0.641201369156,  0.25423989688 ],
                          [ 0.98891732283 ,  0.795803846329,  0.772234112962],
                          [ 0.057808815739,  0.561230452414,  0.983634394272],
                          [ 0.279068727518,  0.537498897632,  0.525608287916],
                          [ 0.44008248514 ,  0.564742799056,  0.737874139022],
                          [ 0.170521217557,  0.055030482874,  0.401934159979],
                          [ 0.686165221727,  0.534594390296,  0.201618176546],
                          [ 0.440718428366,  0.241177468471,  0.319172982909],
                          [ 0.523289806344,  0.680287129434,  0.457252989328],
                          [ 0.657745322496,  0.443610277121,  0.702753349166],
                          [ 0.73874647351 ,  0.043944835248,  0.144572597241],
                          [ 0.733330682334,  0.462139189608,  0.954685122268],
                          [ 0.404919684943,  0.320456174065,  0.770299948873],
                          [ 0.588738214224,  0.004274694557,  0.829722659715],
                          [ 0.893013107766,  0.431415314967,  0.634955021664],
                          [ 0.113237652051,  0.932540729764,  0.858198025755],
                          [ 0.740479066476,  0.583727295913,  0.479122990855],
                          [ 0.344110892442,  0.249432914223,  0.550136813505],
                          [ 0.517935583258,  0.580214580029,  0.015414590368],
                          [ 0.758162359153,  0.981278439109,  0.618883784773],
                          [ 0.966623275059,  0.795937585934,  0.336476453353],
                          [ 0.382304240773,  0.833884248634,  0.901043024078],
                          [ 0.909482297748,  0.054799823918,  0.367091498701],
                          [ 0.843201367225,  0.190372412267,  0.900254184443],
                          [ 0.159699260491,  0.200203661295,  0.922240974345],
                          [ 0.053968934598,  0.645430123217,  0.622085493361],
                          [ 0.925826052659,  0.540892818164,  0.134289561292],
                          [ 0.165416319115,  0.398670669672,  0.691544111348],
                          [ 0.665733443228,  0.27030104043 ,  0.085098136345],
                          [ 0.282483316799,  0.786253470408,  0.653037521779],
                          [ 0.173209147742,  0.35429440692 ,  0.393576993577],
                          [ 0.962210703148,  0.27100432793 ,  0.47105163203 ],
                          [ 0.65491740313 ,  0.666378789789,  0.819675632631],
                          [ 0.974112975927,  0.421409643317,  0.894372042151],
                          [ 0.259824979584,  0.030232543926,  0.626385150949],
                          [ 0.820749958436,  0.746258146649,  0.671512093507],
                          [ 0.896015209182,  0.685048874803,  0.913407674178],
                          [ 0.000802912835,  0.006003393166,  0.090941440308],
                          [ 0.298147168479,  0.968372352246,  0.089306954055],
                          [ 0.729665916716,  0.307145499105,  0.322405853303],
                          [ 0.076474450163,  0.740809301522,  0.092488519106],
                          [ 0.37222270705 ,  0.899283890626,  0.426985393272],
                          [ 0.784014922167,  0.759411172403,  0.174067084567],
                          [ 0.168403482866,  0.51991729442 ,  0.155216845056],
                          [ 0.554826279748,  0.82363245738 ,  0.657957284668],
                          [ 0.3287726963  ,  0.725156788691,  0.128929595389],
                          [ 0.738256801819,  0.225122614513,  0.595458330136],
                          [ 0.646163421945,  0.836535442404,  0.986390847321],
                          [ 0.134918140314,  0.69592600534 ,  0.852968772801],
                          [ 0.498251781913,  0.067868400654,  0.594930954447],
                          [ 0.025292079595,  0.148412684824,  0.706236792884],
                          [ 0.624139981164,  0.259315995186,  0.833783994032],
                          [ 0.569547861564,  0.04540227083 ,  0.365772506996],
                          [ 0.411388952542,  0.343261752668,  0.032043272541],
                          [ 0.220572417983,  0.19714194811 ,  0.190840977687],
                          [ 0.947974714731,  0.276665306637,  0.177529342417],
                          [ 0.869013638066,  0.939693437027,  0.851562567373],
                          [ 0.352304382485,  0.091298512383,  0.848633137133],
                          [ 0.507068863785,  0.078739681492,  0.081120908203],
                          [ 0.191515178593,  0.734685592687,  0.371755703642],
                          [ 0.40055984606 ,  0.530829320339,  0.282988602457],
                          [ 0.549769330001,  0.388670168995,  0.502944191824],
                          [ 0.008816974987,  0.889659572675,  0.58704419235 ],
                          [ 0.972656965386,  0.552146944663,  0.404502024642],
                          [ 0.531051809837,  0.794173156049,  0.236180870299],
                          [ 0.740293266149,  0.822730667041,  0.421446693542],
                          [ 0.256830517134,  0.523327690969,  0.92422328655 ]]

# Set up configuration: combine the lattice, the species list and the atomic
# positions into one bulk structure, then write that structure to disk.
bulk_configuration = BulkConfiguration(
    fractional_coordinates=fractional_coordinates,
    elements=elements,
    bravais_lattice=lattice,
)
nlsave('active-learning.hdf5', bulk_configuration)

# -------------------------------------------------------------
# Reference Calculator
# -------------------------------------------------------------
# LCAO calculator; it is passed to the MTP training study and to the
# active-learning simulation further down in this script.
k_point_sampling = KpointDensity(density_a=4.0*Angstrom)

numerical_accuracy_parameters = NumericalAccuracyParameters(
    k_point_sampling=k_point_sampling,
    density_mesh_cutoff=125.0*Hartree,
)

calculator = LCAOCalculator(
    numerical_accuracy_parameters=numerical_accuracy_parameters,
)

# Compile initial training data.

# ---------------------------------------------
# The initial data is expected to be CrystalTrainingRandomDisplacements
# output (NOTE(review): taken from the original comment -- confirm the
# actual content of the file).
# ---------------------------------------------
training_set = nlread('initial_training_sets.hdf5')

# When the file yields more than one object, merge the members of all of
# them into a single flat list; a single object is left wrapped as read.
if len(training_set) > 1:
    training_set = [entry for group in training_set for entry in group]

#---------------------------------------------
#Uncomment the lines below only if the data is a Trajectory or MDTrajectory object.
#The reference configurations are then loaded from a trajectory object in the file;
#the trajectory may contain energy, force and stress data as well.
#To extract 100 uniformly distributed snapshots from the trajectory, also enable
#sample_size=100; the snapshots are recomputed using the calculator set above.
#For multiple trajectories create training_set_1, training_set_2, etc.
#https://docs.quantumatk.com/manual/Types/TrainingSet/TrainingSet.html
#---------------------------------------------
#trajectory=nlread('initial_training_sets.hdf5',Trajectory)[-1]
#trajectory=nlread('initial_training_sets.hdf5',MDTrajectory)[-1]
#training_set = TrainingSet(trajectory, 
#                        #sample_size=100, 
#                        calculator=calculator,
#                        recalculate_training_data=True)            

# Compute reference data using the calculator defined above.

# ---------------------------------------------
# Set up the MTP training study.
# For multiple training sets, use [training_set_1, training_set_2].
# ---------------------------------------------
moment_tensor_potential_training = MomentTensorPotentialTraining(
    filename='reference-data.hdf5',
    object_id='reference calculation',
    training_sets=training_set,
    calculator=calculator,
    calculate_stress=True,
    number_of_processes_per_task=8,
    random_seed=13345,
    log_filename_prefix='mtp_',
)

# ---------------------------------------------
# Run the study.
# NOTE(review): no fitting parameters are passed to the study above, so this
# step presumably only generates the reference data rather than fitting an
# MTP -- confirm against the QuantumATK manual.
# ---------------------------------------------
moment_tensor_potential_training.update()

# Wrapped in a list; handed to ActiveLearningSimulation below as its
# initial_training_data argument.
initial_training_data = [moment_tensor_potential_training]

# Set MTP training parameters.

# ---------------------------------------------
# Non-linear coefficients, with optimization switched on.
# https://docs.quantumatk.com/manual/Types/NonLinearCoefficientsParameters/NonLinearCoefficientsParameters.html
# ---------------------------------------------
non_linear_coefficients_parameters = NonLinearCoefficientsParameters(
    perform_optimization=True,
    energy_only=True,
    max_steps=1,
    random_seed=500,
)

# ---------------------------------------------
# MTP fitting parameters; every value below is set explicitly.
# https://docs.quantumatk.com/manual/Types/MomentTensorPotentialFittingParameters/MomentTensorPotentialFittingParameters.html
# ---------------------------------------------
fitting_parameters = MomentTensorPotentialFittingParameters(
    mtp_filename='mtp_potential.mtp',
    basis_size=300,
    inner_cutoff_radii=1.0*Angstrom,
    tapering_cutoff_radii=1.1*Angstrom,
    outer_cutoff_radii=5.0*Angstrom,
    forces_cap=100.0*eV/Angstrom,
    non_linear_coefficients_parameters=non_linear_coefficients_parameters,
)



# Set up active learning MD.

# ---------------------------------------------
# Active-learning driver: ties together the fitting parameters, the initial
# training data and the reference calculator defined earlier in this script.
# ---------------------------------------------
active_learning = ActiveLearningSimulation(
    reference_calculator=calculator,
    fitting_parameters=fitting_parameters,
    initial_training_data=initial_training_data,
    mtp_study_filename='HfO2-final-dataset',
    mtp_study_object_id='HfO2 amorphous',
    candidate_trajectory_filename='active_learning_candidates.hdf5',
    candidate_threshold=1.0,
    retrain_threshold=3.0,
    check_interval=20,
    max_forces_check=10.0*eV/Angstrom,
)

# ---------------------------------------------
# Set up a high-temperature MD at 3000 K.
# The NPT integrator below is the active ensemble; a Langevin alternative is
# kept commented out after it.
# ---------------------------------------------
initial_velocity = MaxwellBoltzmannDistribution(
    temperature=3000.0*Kelvin,
    enforce_temperature=True,
    remove_center_of_mass_momentum=True,
    # No fixed seed is given for the initial velocities.
    random_seed=None,
)

# -------------------
# NPT
# -------------------
method = NPTMartynaTobiasKlein(
    initial_velocity=initial_velocity,
    time_step=1*femtoSecond,
    reservoir_temperature=3000*Kelvin,
    reservoir_pressure=1*bar,
    thermostat_timescale=100*femtoSecond,
    barostat_timescale=500*femtoSecond,
    chain_length=3,
    # Zero rates: the temperature and pressure targets are not ramped.
    heating_rate=0*Kelvin/picoSecond,
    compression_rate=0*bar/femtoSecond,
)

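#------------------------
#Langevin (alternative to the NPT method above; uncomment to use it instead)
#NOTE: reservoir_temperature below is 6000 K, not the 3000 K used for the
#initial velocities and the NPT method.
#------------------------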
#method = Langevin(
#    time_step=1*femtoSecond,
#    reservoir_temperature=6000*Kelvin,
#    friction=0.05*femtoSecond**-1,
#    initial_velocity=initial_velocity,
#)


constraints = [FixCenterOfMass()]

# ---------------------------------------------
# Run the MD simulation on bulk_configuration through the active learning
# object.
# ---------------------------------------------

# Scale the volume to facilitate melting. This happens after the structure
# was written to 'active-learning.hdf5', so that file holds the unscaled cell.
bulk_configuration.scaleVolume(1.1)

md_trajectory = active_learning.runMolecularDynamics(
    bulk_configuration,
    method=method,
    constraints=constraints,
    steps=200000,
    log_interval=500,
    trajectory_filename='active_learning_3000K.hdf5',
    domain_decomposition_pattern=[1, 1, 1],
)
