DeepDFTModel
Included in QATK.MLDFT
- class DeepDFTModel(model_dir, model_type=None, cutoff=None, num_interactions=None, node_size=None, num_components=None, scaling_factor=None, range_type=None, gaussian_sigma=None)
A model for predicting electron densities, holding the hyperparameters specific to that model. The model directory contains the model parameters obtained from training, the state of the training optimizer, and metadata.
- Parameters:
model_dir (str) – Absolute or relative path to the directory containing the model files.
model_type (DeepDFTModelType) – The type of the model (SchNet/PaiNN).
cutoff (PhysicalQuantity of type length) – The cutoff radius used for determining bonds between atoms. Default: 4.0 * Angstrom
num_interactions (int) – The number of propagation steps “T” in message passing. Default: 3
node_size (int) – The size of the embedded feature vector for each node. Default: 128
num_components (int) – The number of grid-value components in the model (e.g. spin components for polarized calculations). Default: 1
scaling_factor (float) – The number by which the output grid values are divided (scaled down). Default: 1.0
range_type (str or None) – The range type for range-separated models (“short-range” or “long-range”). Default: None
gaussian_sigma (PhysicalQuantity of type length) – The sigma of the Gaussian (in Bohr units) used for the convolution with the density when making range-separated density models. Default: 3.0 * Bohr
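The num_interactions parameter sets the number of message-passing steps T. As a generic illustration of what T propagation steps mean, the sketch below assumes scalar node features and a plain sum-of-neighbours update. It is not the SchNet or PaiNN update used by the model, and the function name message_passing is hypothetical.

```python
from typing import Dict, List


def message_passing(
    features: List[float],
    edges: Dict[int, List[int]],
    num_interactions: int = 3,
) -> List[float]:
    """Run T = num_interactions rounds of a toy message-passing update.

    Each round, every node adds the sum of its neighbours' current
    features to its own feature (a stand-in for a learned update).
    """
    h = list(features)
    for _ in range(num_interactions):
        # Messages are computed only from the previous round's features.
        messages = [sum(h[j] for j in edges.get(i, [])) for i in range(len(h))]
        h = [h[i] + messages[i] for i in range(len(h))]
    return h


# A 4-node chain 0-1-2-3 with a nonzero feature only on node 0.
chain = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
print(message_passing([1, 0, 0, 0], chain, num_interactions=3))  # [4, 5, 3, 1]
```

Information travels one edge per step, so after T steps a node can only be influenced by atoms at most T edges away; in the chain above node 3 stays at zero for T = 2 and first changes at T = 3.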
- absoluteModelDir()
- Returns:
The absolute path to the model directory.
- Return type:
str
- absoluteModelDirs()
- Returns:
A list of absolute paths to all model directories in the parent directory.
- Return type:
list[str]
- absoluteModelFile()
- Returns:
The absolute path to the model file.
- Return type:
str
- atomsToGraphDict(configuration)
Create a dictionary describing the graph representation of the atoms (the nodes). The output is a dictionary of tensors used as input to the DeepDFT model.
- Parameters:
configuration – The bulk configuration to create the graph from.
- Returns:
A dictionary of torch.Tensor objects for the atom nodes, edges, etc.
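Building the graph from a bulk configuration requires finding every pair of atoms, including periodic images, that lies within the cutoff radius. The stdlib sketch below illustrates that neighbour search under stated assumptions: an orthorhombic cell, Cartesian positions already wrapped into the cell, and directed edges. The function name neighbour_edges and its (i, j, image, distance) tuples are hypothetical and do not reflect the keys of the dictionary that atomsToGraphDict actually returns.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Vec = Tuple[float, float, float]


def neighbour_edges(
    positions: Sequence[Vec],
    cell_lengths: Vec,
    cutoff: float = 4.0,
) -> List[Tuple[int, int, Tuple[int, int, int], float]]:
    """Return directed edges (i, j, image, distance) with distance <= cutoff.

    Assumes an orthorhombic periodic cell with atoms wrapped into it;
    `image` is the integer lattice translation applied to atom j.
    """
    # Number of periodic images needed in each direction to cover the cutoff.
    reps = [math.ceil(cutoff / length) for length in cell_lengths]
    ranges = [range(-n, n + 1) for n in reps]
    edges = []
    for i, ri in enumerate(positions):
        for j, rj in enumerate(positions):
            for image in itertools.product(*ranges):
                if i == j and image == (0, 0, 0):
                    continue  # skip the atom itself in the home cell
                d = math.sqrt(sum(
                    (rj[k] + image[k] * cell_lengths[k] - ri[k]) ** 2
                    for k in range(3)
                ))
                if d <= cutoff:
                    edges.append((i, j, image, d))
    return edges
```

For a simple cubic lattice with a 3.0 Angstrom spacing and the default 4.0 Angstrom cutoff, each atom gets the 6 nearest neighbours; the face diagonals at about 4.24 Angstrom fall outside the cutoff.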
- copyHyperparameters(model)
Copy the hyperparameters from another model.
- Parameters:
model – The model to copy the hyperparameters from.
- cutoff()
Return the cutoff radius.
- gaussianSigma()
Return the sigma of the Gaussian used for the convolution with the density (for range separation).
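To illustrate what a Gaussian convolution with the density does for range separation, the sketch below works on a periodic 1D grid. It assumes that the long-range part is the density convolved with a normalized Gaussian of width sigma and that the short-range part is the remainder; this split, and the function name split_density_1d, are illustrative assumptions rather than a description of how QATK builds its range-separated targets.

```python
import math
from typing import List, Tuple


def split_density_1d(
    density: List[float], spacing: float, sigma: float = 3.0
) -> Tuple[List[float], List[float]]:
    """Split a periodic 1D density into (long_range, short_range) parts.

    Assumption: long-range = density convolved with a normalized Gaussian
    of width sigma (same length unit as spacing); short-range = remainder.
    """
    n = len(density)
    # Gaussian weights on the periodic grid, using the minimum-image distance.
    weights = []
    for m in range(n):
        x = min(m, n - m) * spacing
        weights.append(math.exp(-0.5 * (x / sigma) ** 2))
    norm = sum(weights)
    weights = [w / norm for w in weights]  # discrete kernel sums to 1
    # Periodic (circular) convolution of the density with the kernel.
    long_range = [
        sum(density[(i - m) % n] * weights[m] for m in range(n))
        for i in range(n)
    ]
    short_range = [d - lr for d, lr in zip(density, long_range)]
    return long_range, short_range
```

Because the kernel is normalized, the smoothed long-range part carries the full integrated charge, and the short-range part integrates to zero; a larger sigma moves more structure from the long-range part into the short-range part.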
- classmethod initializeFromGridValuesModel(model)
Create a DeepDFTModel from a GridValuesModel, using the model directory and the hyperparameters it contains.
- Parameters:
model – The GridValuesModel to initialize from.
- Returns:
A DeepDFTModel initialized from the given GridValuesModel.
- Return type:
DeepDFTModel
- initializeModel()
Initialize the internal DeepDFT model object used for training.
- modelDir()
Get the directory containing the model files.
- Returns:
The model directory.
- Return type:
str
- modelType()
Return the type of the model (SchNet/PaiNN).
- nlinfo()
- Returns:
The nlinfo.
- Return type:
dict
- nodeSize()
Return the node size the model was trained with.
- numComponents()
Return the number of (spin) components the model was trained with.
- numInteractions()
Return the number of interactions the model was trained with.
- predictGridValuesWithModel(configuration, grid_descriptor, probe_count, gpu_acceleration)
Use the model to predict, for the given configuration, the grid property that the model was trained for.
- Parameters:
configuration – The configuration for which to predict the grid values.
grid_descriptor – The grid descriptor to use for the prediction.
probe_count – Number of probe graph nodes to use.
gpu_acceleration – Flag to indicate whether to use GPU acceleration.
- Returns:
The predicted grid values.
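Predicting values on a full real-space grid usually means evaluating the model on many probe points. The sketch below assumes that probe_count bounds how many probe nodes are evaluated in one pass, so the grid is processed in chunks to limit memory use; that interpretation, the helper name predict_in_batches, and the evaluate callback (standing in for one model forward pass) are all hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Point = Tuple[float, float, float]


def predict_in_batches(
    points: Sequence[Point],
    evaluate: Callable[[Sequence[Point]], List[float]],
    probe_count: int,
) -> List[float]:
    """Evaluate grid points in chunks of at most probe_count probes.

    `evaluate` is a hypothetical stand-in for one forward pass of the
    model on a batch of probe positions; it returns one value per probe.
    """
    if probe_count < 1:
        raise ValueError("probe_count must be positive")
    values: List[float] = []
    for start in range(0, len(points), probe_count):
        batch = points[start:start + probe_count]
        values.extend(evaluate(batch))
    return values
```

Under this assumption, a larger probe_count means fewer, larger passes (faster, more memory), while the concatenated result is the same as evaluating all points at once.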
- rangeType()
Return the range type for range-separated models (“short-range” or “long-range”, or None).
- saveHyperparameters()
Save the hyperparameters of the model to the JSON file in the model directory.
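A JSON file makes the hyperparameters easy to store next to the trained parameters and to read back later. The sketch below shows a round trip with the standard json module; the file name hyperparameters.json, the helper functions, and the choice to store lengths as plain numbers with the unit in the key are illustrative assumptions, not the file layout QATK actually writes.

```python
import json
import os
import tempfile
from typing import Any, Dict

HYPERPARAMETER_FILE = "hyperparameters.json"  # hypothetical file name


def save_hyperparameters(model_dir: str, hyperparameters: Dict[str, Any],
                         filename: str = HYPERPARAMETER_FILE) -> str:
    """Write hyperparameters to a JSON file in model_dir and return its path."""
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, filename)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(hyperparameters, f, indent=2, sort_keys=True)
    return path


def load_hyperparameters(model_dir: str,
                         filename: str = HYPERPARAMETER_FILE) -> Dict[str, Any]:
    """Read the hyperparameters back from model_dir."""
    with open(os.path.join(model_dir, filename), encoding="utf-8") as f:
        return json.load(f)


# Defaults from the parameter list above; units encoded in the key names
# (an illustrative choice).
params = {"cutoff_angstrom": 4.0, "num_interactions": 3, "node_size": 128,
          "num_components": 1, "scaling_factor": 1.0, "range_type": None,
          "gaussian_sigma_bohr": 3.0}
with tempfile.TemporaryDirectory() as tmp:
    path = save_hyperparameters(tmp, params)
    loaded = load_hyperparameters(tmp)
```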
- scalingFactor()
Return the scaling factor used for the targets during the training of the model.
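The scaling factor divides the training targets so that the network works with values of a convenient magnitude. The sketch below shows that division and, as an assumption, the matching multiplication that would map model outputs back to physical values; both helper names are hypothetical.

```python
from typing import List


def scale_targets(values: List[float], scaling_factor: float = 1.0) -> List[float]:
    """Divide training targets by scaling_factor (the values the model learns)."""
    return [v / scaling_factor for v in values]


def unscale_predictions(values: List[float], scaling_factor: float = 1.0) -> List[float]:
    """Multiply model outputs by scaling_factor (assumed inverse step)."""
    return [v * scaling_factor for v in values]


scaled = scale_targets([2.0, 10.0], scaling_factor=4.0)  # [0.5, 2.5]
restored = unscale_predictions(scaled, scaling_factor=4.0)  # [2.0, 10.0]
```

With the default factor of 1.0 both steps are the identity.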
- static torchDevice(gpu_acceleration)
Resolve the GPU acceleration flag and create the corresponding torch device.
- Return type:
torch.device
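Resolving the flag amounts to choosing a device name and passing it to torch.device. The stdlib sketch below shows one possible resolution rule; the rule itself (None meaning "use a GPU if available", True requiring a GPU, False forcing the CPU) and the helper name resolve_device_name are assumptions for illustration, not the documented behaviour of torchDevice.

```python
from typing import Optional


def resolve_device_name(gpu_acceleration: Optional[bool], cuda_available: bool) -> str:
    """Map a GPU-acceleration flag to a PyTorch device name.

    Assumed rule: None means "use the GPU if one is available";
    True requires a GPU; False forces the CPU.
    """
    if gpu_acceleration is None:
        return "cuda" if cuda_available else "cpu"
    if gpu_acceleration and not cuda_available:
        raise RuntimeError("GPU acceleration requested but no CUDA device is available")
    return "cuda" if gpu_acceleration else "cpu"
```

With PyTorch installed, the result would be used as torch.device(resolve_device_name(flag, torch.cuda.is_available())).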
- uniqueString()
Return a unique string representing the state of the object.
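One common way to produce a string that uniquely reflects an object's state is to hash a canonical serialization of that state. The sketch below hashes the model directory and the hyperparameters with SHA-256; which fields uniqueString actually includes is not stated here, so the chosen fields and the helper name state_string are assumptions.

```python
import hashlib
import json
from typing import Any, Dict


def state_string(model_dir: str, hyperparameters: Dict[str, Any]) -> str:
    """Return a deterministic string that changes whenever the state changes."""
    # sort_keys makes the serialization independent of dict insertion order.
    payload = json.dumps(
        {"model_dir": model_dir, "hyperparameters": hyperparameters},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```

Sorting the keys ensures that two models with identical settings map to the same string regardless of the order in which their hyperparameters were set.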