Trainer

Main Trainer Module

class gretel_trainer.trainer.Trainer(project_name: str = 'trainer', model_type: _BaseConfig | None = None, cache_file: str | None = None, overwrite: bool = True, session: ClientConfig | None = None)

Automated model training and synthetic data generation tool.

Parameters:
  • project_name (str, optional) – Gretel project name. Defaults to “trainer”.

  • model_type (_BaseConfig, optional) – Model configuration to use. Options include GretelLSTM() and GretelACTGAN(). If unspecified, a model type is chosen automatically at train time based on the training dataset.

  • cache_file (str, optional) – Path to save the cache file to or load it from. Defaults to [project_name]-runner.json.

  • overwrite (bool, optional) – Overwrite previous progress. Defaults to True.
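The default cache path is derived from the project name. A minimal sketch of that naming rule, assuming it is exactly the [project_name]-runner.json pattern stated above; the helper default_cache_file is hypothetical and not part of gretel_trainer:

```python
def default_cache_file(project_name: str = "trainer") -> str:
    """Hypothetical helper: mirror the documented [project_name]-runner.json default."""
    return f"{project_name}-runner.json"


print(default_cache_file())             # trainer-runner.json
print(default_cache_file("customers"))  # customers-runner.json
```

Because the load() signature below uses the literal default 'trainer-runner.json', a project created with a custom project_name likely needs its cache_file passed explicitly when loading.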

generate(num_records: int = 500, seed_df: DataFrame | None = None) → DataFrame

Generate synthetic data from the trained model.

Parameters:
  • num_records (int, optional) – Number of records to generate from model. Defaults to 500.

  • seed_df (pd.DataFrame, optional) – pandas DataFrame of values to seed (condition) the model with, using fields listed in seed_fields at train time. Defaults to None.

Returns:

Synthetic data.

Return type:

pd.DataFrame
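Seeding conditions generation on known field values. A minimal sketch of preparing seed_df, assuming its columns must be a subset of the seed_fields passed to train(); the helper make_seed_df and the column names are hypothetical:

```python
import pandas as pd


def make_seed_df(rows: list[dict], seed_fields: list[str]) -> pd.DataFrame:
    """Hypothetical helper: build a seed DataFrame restricted to seed_fields."""
    seed_df = pd.DataFrame(rows)
    # Assumption: seeding only works on fields declared via train(seed_fields=...).
    unknown = set(seed_df.columns) - set(seed_fields)
    if unknown:
        raise ValueError(f"columns not in seed_fields: {sorted(unknown)}")
    return seed_df


seed_fields = ["state", "plan"]
seed_df = make_seed_df(
    [{"state": "CA", "plan": "pro"}, {"state": "NY", "plan": "basic"}],
    seed_fields,
)
# The result would then be passed as: trainer.generate(seed_df=seed_df)
```

Validating the columns up front surfaces a mismatch before any generation job is started.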

get_sqs_score() → int

Return the average SQS synthetic data quality score.

Requires that the model has been trained.
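The score is described as an average, which suggests one SQS per trained model is aggregated into a single value. As an illustration only, assuming per-model scores are averaged and the mean is converted with int(); the helper average_sqs and the sample scores are hypothetical:

```python
from statistics import mean


def average_sqs(per_model_scores: list[int]) -> int:
    """Hypothetical helper: average per-model SQS values into one integer score."""
    if not per_model_scores:
        # Mirrors the requirement above: no trained model means no score.
        raise ValueError("no SQS scores available; train a model first")
    # int() truncates toward zero; whether the library rounds or truncates is an assumption.
    return int(mean(per_model_scores))


print(average_sqs([88, 91, 85]))  # 88
```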

classmethod load(cache_file: str = 'trainer-runner.json', project_name: str = 'trainer', session: ClientConfig | None = None) → Trainer

Load an existing project from a cache.

Parameters:
  • cache_file (str, optional) – Valid file path to load the cache file from. Defaults to “trainer-runner.json”.

  • project_name (str, optional) – Gretel project name. This should match the original project. Defaults to “trainer”.

Returns:

A Trainer instance with an initialized StrategyRunner.

Return type:

Trainer

train(dataset_path: str, delimiter: str = ',', round_decimals: int = 4, seed_fields: list | None = None)

Train a model on the dataset.

Parameters:
  • dataset_path (str) – Path or URL to a CSV file.

  • delimiter (str, optional) – Delimiter to use when reading the dataset. Defaults to comma (“,”).

  • round_decimals (int, optional) – Number of decimal places to round numeric values in the CSV to, as a preprocessing step. Defaults to 4.

  • seed_fields (list, optional) – List of fields that can be used for conditional generation. Defaults to None.
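The delimiter and round_decimals options describe preprocessing applied to the CSV before training. A minimal standard-library sketch of that step, assuming non-integer numeric cells are rounded to round_decimals places and all other cells are left untouched; preprocess_csv is hypothetical and not the library's implementation:

```python
import csv
import io


def preprocess_csv(text: str, delimiter: str = ",", round_decimals: int = 4) -> list[list[str]]:
    """Hypothetical helper: parse CSV text and round float-valued cells."""
    rows = []
    for row in csv.reader(io.StringIO(text), delimiter=delimiter):
        cleaned = []
        for cell in row:
            try:
                value = float(cell)
            except ValueError:
                cleaned.append(cell)  # non-numeric cell (e.g. a header): keep as-is
                continue
            if cell.lstrip("+-").isdigit():
                cleaned.append(cell)  # integer cell: nothing to round
            else:
                cleaned.append(str(round(value, round_decimals)))
        rows.append(cleaned)
    return rows


sample = "id;price\n1;3.14159265\n2;2.5\n"
print(preprocess_csv(sample, delimiter=";"))
# [['id', 'price'], ['1', '3.1416'], ['2', '2.5']]
```

Rounding high-precision floats before training reduces the number of distinct values a model has to learn, which is a plausible motivation for this default.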