GridValuesDataset
Included in QATK.MLDFT
- class GridValuesDataset(dataset_dirs, probe_count_train=None, probe_count_val=None, validation_ratio=None, seed=None, grid_sampler=None, **kwargs)
- Parameters:
dataset_dirs – List of directories containing the training data (HDF5 files). Each directory may contain several files, and each file may hold several ElectronDifferenceDensity objects.
probe_count_train – The number of probe nodes used in each training step. Default: 1000.
probe_count_val – The number of probe nodes used for validation. Default: 5000.
validation_ratio – The fraction of the dataset used for validation. Default: 0.05.
seed – The seed used to randomly split the data into training and validation sets. Default: 91824789.
grid_sampler – The grid sampling strategy to use (uses a different seed by default). Default: UniformRandomGridSampler().
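As a rough illustration of what the probe counts control, the sketch below draws a fixed number of distinct probe nodes uniformly at random from a 3D grid. It assumes, without confirmation from this reference, that a uniform random sampler picks grid points without replacement; `sample_probe_indices`, its arguments, and the grid shape are hypothetical and are not the UniformRandomGridSampler API.

```python
import random
from typing import List, Tuple


def sample_probe_indices(
    grid_shape: Tuple[int, int, int], probe_count: int, seed: int
) -> List[Tuple[int, int, int]]:
    """Hypothetical sketch: pick `probe_count` distinct grid points uniformly at random."""
    nx, ny, nz = grid_shape
    total = nx * ny * nz
    # Never ask for more probes than there are grid points.
    count = min(probe_count, total)
    rng = random.Random(seed)
    # Sample flat indices without replacement, then unravel them to (i, j, k).
    flat = rng.sample(range(total), count)
    return [(f // (ny * nz), (f // nz) % ny, f % nz) for f in flat]


# Example: 1000 training probes and 5000 validation probes on a 40x40x40 grid.
train_probes = sample_probe_indices((40, 40, 40), probe_count=1000, seed=1)
val_probes = sample_probe_indices((40, 40, 40), probe_count=5000, seed=2)
```

Sampling only a subset of grid points per step keeps each step cheap compared with evaluating the full grid; the validation count is larger here simply because the documented defaults are 1000 and 5000.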
- cutoffAngstrom()
Return the edge cutoff radius used in the model, in Ångström.
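An edge cutoff radius generally decides which node pairs are connected in a graph. The sketch below shows that idea with a brute-force radius graph over Cartesian positions in Å. It ignores periodic boundary conditions, uses a strict "closer than cutoff" test as an assumption, and is not necessarily how QATK.MLDFT builds its graphs; `radius_edges` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def radius_edges(positions: Sequence[Vec3], cutoff: float) -> List[Tuple[int, int]]:
    """Hypothetical sketch: directed edges (i, j) for node pairs closer than `cutoff`.

    Brute force O(N^2); periodic images are not considered.
    """
    edges = []
    for i, p in enumerate(positions):
        for j, q in enumerate(positions):
            if i != j and math.dist(p, q) < cutoff:
                edges.append((i, j))
    return edges


# Three nodes on a line, 1.0 and 3.0 Angstrom apart, with a 2.0 Angstrom cutoff.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (4.0, 0.0, 0.0)]
print(radius_edges(positions, cutoff=2.0))  # [(0, 1), (1, 0)]
```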
- static electronegativity(atomic_number)
Return the electronegativity value for a given atomic number.
- getMemberIndex(file, object_id)
Return the index in self.member_list of the dataset item identified by the given file path and object_id.
- Parameters:
file – The file path to the dataset file.
object_id – The object ID to search for.
- Returns:
The index of the member in the list.
- Raises:
NLValueError – If the member is not found in the list.
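The lookup can be pictured as a linear search over the (file, object_id) tuples returned by memberList(). The sketch below is a plain-Python stand-in: it raises the built-in ValueError in place of NLValueError, treats object IDs as strings (an assumption), and `member_index` as well as the example file names and object IDs are hypothetical.

```python
from typing import List, Tuple

Member = Tuple[str, str]


def member_index(member_list: List[Member], file: str, object_id: str) -> int:
    """Hypothetical sketch of getMemberIndex: find (file, object_id) in member_list."""
    for index, (member_file, member_id) in enumerate(member_list):
        if member_file == file and member_id == object_id:
            return index
    # ValueError stands in for NLValueError here.
    raise ValueError(f"Member ({file!r}, {object_id!r}) not found in the dataset.")


members = [
    ("data/run1.hdf5", "obj_a"),
    ("data/run1.hdf5", "obj_b"),
    ("data/run2.hdf5", "obj_a"),
]
print(member_index(members, "data/run2.hdf5", "obj_a"))  # 2
```

Both the file and the object ID must match, since the same object ID can appear in more than one file.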
- gridSampler()
Return the grid sampling strategy.
- indicesTraining()
Return the list of indices used for training.
- indicesValidation()
Return the list of indices used for validation.
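A seeded split of this kind can be sketched as follows: shuffle the member indices with a fixed seed and reserve a validation_ratio fraction for validation. The shuffling method, the rounding rule, and the helper `split_indices` are assumptions for illustration; the split actually performed by GridValuesDataset may differ.

```python
import random
from typing import List, Tuple


def split_indices(
    num_members: int, validation_ratio: float = 0.05, seed: int = 91824789
) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: deterministic train/validation split of member indices."""
    indices = list(range(num_members))
    # A dedicated Random instance makes the split reproducible for a given seed.
    random.Random(seed).shuffle(indices)
    # Assumption: the validation size is the rounded fraction of all members.
    num_val = round(validation_ratio * num_members)
    validation = sorted(indices[:num_val])
    training = sorted(indices[num_val:])
    return training, validation


training, validation = split_indices(100)
print(len(training), len(validation))  # 95 5
```

Fixing the seed means the same members end up in the validation set on every run, so validation losses from different training runs remain comparable.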
- memberList()
Return the list of (file, object_id) tuples for all dataset members.
- numComponents()
Return the number of components used in the model.
- probeCounts()
Return the array of probe counts for each member.
- scalingFactor()
Return the scaling factor used in the model.
- seed()
Return the seed used for the random train/validation split.
- setModelParameters(cutoff, num_components, scaling_factor)
Set the model's parameters on the dataset. This must be done before the dataset is used for training or validation, because these parameters are needed to process the grid values and construct the graphs.
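The ordering requirement is the kind of precondition a class can enforce itself. The toy class below refuses to return the cutoff until model parameters have been set. It only mirrors the documented method in snake case; `ToyGridDataset`, its attribute, its error message, and the parameter values in the example are hypothetical placeholders and say nothing about how GridValuesDataset stores or checks these values.

```python
from typing import Optional, Tuple


class ToyGridDataset:
    """Hypothetical stand-in: data access is refused until model parameters are set."""

    def __init__(self) -> None:
        # (cutoff, num_components, scaling_factor); None until set_model_parameters runs.
        self._model_parameters: Optional[Tuple[float, int, float]] = None

    def set_model_parameters(
        self, cutoff: float, num_components: int, scaling_factor: float
    ) -> None:
        self._model_parameters = (cutoff, num_components, scaling_factor)

    def cutoff(self) -> float:
        # Graph construction needs the cutoff, so fail early if it was never set.
        if self._model_parameters is None:
            raise RuntimeError("Call set_model_parameters() before using the dataset.")
        return self._model_parameters[0]


dataset = ToyGridDataset()
dataset.set_model_parameters(cutoff=4.0, num_components=8, scaling_factor=1.0)
print(dataset.cutoff())  # 4.0
```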
- trainingDataLoader(batch_size)
Return a DataLoader for the training data.
- validationDataLoader(batch_size)
Return a DataLoader for the validation data.
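Conceptually, both loaders group dataset members into batches of batch_size. The sketch below shows that grouping on plain index lists, reshuffling the training indices and keeping the validation order fixed. The shuffling behavior is an assumption (common practice for training loaders, not stated here), `make_batches` is hypothetical, and the DataLoader objects returned by these methods are not reproduced.

```python
import random
from typing import Iterator, List, Optional


def make_batches(
    indices: List[int], batch_size: int, shuffle: bool, seed: Optional[int] = None
) -> Iterator[List[int]]:
    """Hypothetical sketch: yield batches of at most `batch_size` indices."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    order = list(indices)
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        # The final batch may be smaller than batch_size.
        yield order[start:start + batch_size]


validation_batches = list(make_batches([10, 11, 12, 13, 14], batch_size=2, shuffle=False))
print(validation_batches)  # [[10, 11], [12, 13], [14]]
```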