Trainer
Main Trainer Module
- class gretel_trainer.trainer.Trainer(project_name: str = 'trainer', model_type: _BaseConfig | None = None, cache_file: str | None = None, overwrite: bool = True, session: ClientConfig | None = None)
Automated model training and synthetic data generation tool
- Parameters:
project_name (str, optional) – Gretel project name. Defaults to “trainer”.
model_type (_BaseConfig, optional) – Options include GretelLSTM(), GretelACTGAN(). If unspecified, the best option will be chosen at train time based on the training dataset.
cache_file (str, optional) – Select a path to save or load the cache file. Default is [project_name]-runner.json.
overwrite (bool, optional) – Overwrite previous progress. Defaults to True.
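The constructor defaults above can be made explicit in a short sketch. The helper names `default_cache_file` and `build_trainer` are hypothetical and not part of gretel_trainer. The import is deferred so the snippet loads without the package installed; calling `build_trainer` assumes gretel_trainer is installed and that Gretel credentials are already configured.

```python
def default_cache_file(project_name: str = "trainer") -> str:
    """Return the documented default cache path: [project_name]-runner.json."""
    return f"{project_name}-runner.json"


def build_trainer(project_name: str = "trainer", overwrite: bool = True):
    """Construct a Trainer and let it choose the model type at train time."""
    # Deferred import: gretel_trainer must be installed for this call to work.
    from gretel_trainer.trainer import Trainer

    return Trainer(
        project_name=project_name,
        model_type=None,  # None: the best option is chosen from the training dataset
        cache_file=default_cache_file(project_name),
        overwrite=overwrite,
    )
```

Passing `cache_file` explicitly is redundant with the default, but it makes the location of the cache obvious when the project is reloaded later.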
- generate(num_records: int = 500, seed_df: DataFrame | None = None) → DataFrame
Generate synthetic data
- Parameters:
num_records (int, optional) – Number of records to generate from model. Defaults to 500.
seed_df (pd.DataFrame, optional) – Pandas DataFrame of values to seed the model with. Defaults to None.
- Returns:
Synthetic data.
- Return type:
pd.DataFrame
- get_sqs_score() → int
Return the average SQS synthetic data quality score.
Requires that the model has already been trained.
- classmethod load(cache_file: str = 'trainer-runner.json', project_name: str = 'trainer', session: ClientConfig | None = None) → Trainer
Load an existing project from a cache.
- Parameters:
cache_file (str, optional) – Valid file path to load the cache file from. Defaults to [project_name]-runner.json.
project_name (str, optional) – Gretel project name. This should match the original project. Defaults to “trainer”
- Returns:
A Trainer instance with an initialized StrategyRunner.
- Return type:
Trainer
- train(dataset_path: str, delimiter: str = ',', round_decimals: int = 4, seed_fields: list | None = None)
Train a model on the dataset
- Parameters:
dataset_path (str) – Path or URL to CSV
delimiter (str, optional) – Delimiter to use when reading the dataset. Defaults to comma (“,”).
round_decimals (int, optional) – Round decimals in CSV as preprocessing step. Defaults to 4.
seed_fields (list, optional) – List of fields that can be used for conditional generation. Defaults to None.
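The methods above are typically used together: train, check the quality score, then generate. The following is a minimal sketch of that sequence with conditional generation. The helper `train_and_sample` is hypothetical, `trainer` is assumed to be an already constructed Trainer, and it is an assumption here that the columns of `seed_df` should correspond to the fields passed as `seed_fields`.

```python
def train_and_sample(trainer, dataset_path: str, seed_field: str, seed_df):
    """Train with one seed field, then return (sqs_score, synthetic_data)."""
    # seed_fields marks the columns that can later be used for conditional generation.
    trainer.train(
        dataset_path,
        delimiter=",",
        round_decimals=4,
        seed_fields=[seed_field],
    )
    # get_sqs_score requires that the model has been trained.
    score = trainer.get_sqs_score()
    # Seed the model with the caller-supplied values for the seed field.
    synthetic = trainer.generate(seed_df=seed_df)
    return score, synthetic
```

With a real Trainer, `seed_df` would be a pandas DataFrame, for example one built with `pd.DataFrame({"state": ["CA", "NY"]})` for a hypothetical `state` column.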