Source code for modelforge.train.tuning

"""
This module contains functions and classes for hyperparameter tuning and distributed training using Ray Tune.
"""

import torch

from ray import air, tune


from typing import Type
from ray.tune.schedulers import ASHAScheduler


def tune_model(
    model: torch.nn.Module,
    dataset: torch.utils.data.Dataset,
    num_samples: int = 100,
    name: str = "tune",
):
    """Run a Ray Tune hyperparameter search for a model.

    Parameters
    ----------
    model : torch.nn.Module
        The model to tune. Its ``_config_prior()`` method supplies the
        search space handed to Ray Tune.
    dataset : torch.utils.data.Dataset
        The dataset to use for tuning. Not consumed yet: the trial
        objective is still a placeholder.
    num_samples : int, optional
        The number of trials to sample from the search space.
        Default is 100.
    name : str, optional
        Name of the Ray run, forwarded to ``RunConfig``. Default is "tune".

    Returns
    -------
    ray.tune.ResultGrid
        The object returned by ``tune.Tuner.fit()``.
    """
    # access the model's configuration prior
    # TODO: not finalized yet
    # NOTE(review): RayTuner.tune_with_ray calls ``config_prior()`` (no
    # leading underscore) - confirm which spelling the models implement.
    config_prior = model._config_prior()

    def objective(config):
        # Ray Tune calls function trainables with the sampled ``config``;
        # the actual training loop has not been written yet.
        raise NotImplementedError

    scheduler = ASHAScheduler(
        time_attr="training_iteration",
        metric="rmse_vl",
        # The metric is an RMSE, so lower is better.
        mode="min",
        max_t=100,
        grace_period=10,
        reduction_factor=3,
        brackets=1,
    )

    tune_config = tune.TuneConfig(
        scheduler=scheduler,
        # Honour the caller's request instead of a hard-coded trial count.
        num_samples=num_samples,
    )

    run_config = air.RunConfig(
        name=name,
        verbose=1,
    )

    tuner = tune.Tuner(
        tune.with_resources(objective, {"cpu": 1, "gpu": 1}),
        param_space=config_prior,
        tune_config=tune_config,
        run_config=run_config,
    )

    return tuner.fit()
class RayTuner:
    """Hyperparameter tuning and distributed training of a model with Ray."""

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Initializes the RayTuner with the given model.

        Parameters
        ----------
        model : torch.nn.Module
            The model instance to be tuned and trained using Ray. It is
            passed to ``Trainer.fit`` and must provide ``config_prior()``
            for ``tune_with_ray``.
        """
        self.model = model
        # Set by ``tune_with_ray``. Declared here so ``train_func`` never
        # touches a missing attribute; ``None`` is also Lightning's own
        # default for the dataloader arguments of ``Trainer.fit``.
        self.train_dataloader = None
        self.val_dataloader = None

    def train_func(self):
        """
        Defines the training function executed on each Ray Train worker.

        Builds a PyTorch Lightning trainer that uses Ray's DDP strategy,
        Ray's Lightning environment plugin and Ray's report callback, then
        fits ``self.model`` on ``self.train_dataloader`` and
        ``self.val_dataloader``.

        Note:
            This function is meant to be handed to a Ray ``TorchTrainer``
            (see ``get_ray_trainer``).

            NOTE(review): this function takes no ``config`` argument, so the
            values sampled from ``train_loop_config`` in ``tune_with_ray``
            do not appear to reach the model here - confirm how the
            hyperparameters are meant to be applied.
        """
        from ray.train.lightning import (
            RayDDPStrategy,
            RayLightningEnvironment,
            RayTrainReportCallback,
            prepare_trainer,
        )
        import lightning as pl

        # Configure PyTorch Lightning trainer with Ray DDP strategy
        trainer = pl.Trainer(
            devices="auto",
            accelerator="auto",
            strategy=RayDDPStrategy(find_unused_parameters=True),
            callbacks=[RayTrainReportCallback()],
            plugins=[RayLightningEnvironment()],
            enable_progress_bar=False,
        )
        trainer = prepare_trainer(trainer)

        # Fit the model using the trainer
        trainer.fit(self.model, self.train_dataloader, self.val_dataloader)

    def get_ray_trainer(
        self,
        number_of_workers: int = 2,
        use_gpu: bool = False,
        checkpoint_score_attribute: str = "val/energy/rmse",
    ):
        """
        Initializes and returns a Ray Trainer for distributed training.

        Configures a Ray ``TorchTrainer`` that runs ``train_func`` on the
        requested number of workers and keeps the two best checkpoints.

        Parameters
        ----------
        number_of_workers : int, optional
            The number of distributed workers to use, by default 2.
        use_gpu : bool, optional
            Specifies whether to use GPUs for training, by default False.
            When True each worker requests one CPU and one GPU, otherwise
            one CPU.
        checkpoint_score_attribute : str, optional
            Name of the reported metric used to rank checkpoints (lower is
            better), by default "val/energy/rmse".

        Returns
        -------
        TorchTrainer
            The configured Ray Trainer for distributed training.
        """
        from ray.train import CheckpointConfig, RunConfig, ScalingConfig
        from ray.train.torch import TorchTrainer

        # Configure scaling for Ray Trainer
        scaling_config = ScalingConfig(
            num_workers=number_of_workers,
            use_gpu=use_gpu,
            resources_per_worker={"CPU": 1, "GPU": 1} if use_gpu else {"CPU": 1},
        )

        # Configure run settings for Ray Trainer
        run_config = RunConfig(
            checkpoint_config=CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute=checkpoint_score_attribute,
                checkpoint_score_order="min",
            ),
        )

        # Define and return the TorchTrainer
        return TorchTrainer(
            self.train_func,
            scaling_config=scaling_config,
            run_config=run_config,
        )

    def tune_with_ray(
        self,
        train_dataloader,
        val_dataloader,
        number_of_epochs: int = 5,
        number_of_samples: int = 10,
        number_of_ray_workers: int = 2,
        train_on_gpu: bool = False,
        metric: str = "val/per_system_energy/rmse",
    ):
        """
        Performs hyperparameter tuning using Ray Tune.

        This method sets up and starts a Ray Tune hyperparameter tuning
        session, utilizing the ASHA scheduler for efficient trial scheduling
        and early stopping. The search space is ``self.model.config_prior()``.

        Parameters
        ----------
        train_dataloader : DataLoader
            The DataLoader for training data.
        val_dataloader : DataLoader
            The DataLoader for validation data.
        number_of_epochs : int, optional
            The maximum number of epochs for training, by default 5.
        number_of_samples : int, optional
            The number of samples (trial runs) to perform, by default 10.
        number_of_ray_workers : int, optional
            The number of Ray workers to use for distributed training,
            by default 2.
        train_on_gpu : bool, optional
            Whether to use GPUs for training, by default False.
        metric : str, optional
            The metric to minimize; used for trial evaluation, early
            stopping and checkpoint ranking, by default
            "val/per_system_energy/rmse".

        Returns
        -------
        ray.tune.ResultGrid
            The object returned by ``tune.Tuner.fit()``, holding the result
            of every trial.
        """
        from ray import tune
        from ray.tune.schedulers import ASHAScheduler

        # ``train_func`` reads the dataloaders from the instance.
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        # Initialize Ray Trainer; rank checkpoints by the same metric the
        # tuner optimizes so both refer to a value that is actually tracked.
        ray_trainer = self.get_ray_trainer(
            number_of_workers=number_of_ray_workers,
            use_gpu=train_on_gpu,
            checkpoint_score_attribute=metric,
        )

        # Configure ASHA scheduler for early stopping; metric and mode are
        # supplied through the TuneConfig below.
        scheduler = ASHAScheduler(
            max_t=number_of_epochs, grace_period=1, reduction_factor=2
        )

        # Define tuning configuration
        tune_config = tune.TuneConfig(
            metric=metric,
            mode="min",
            scheduler=scheduler,
            num_samples=number_of_samples,
        )

        # Initialize and run the tuner
        tuner = tune.Tuner(
            ray_trainer,
            param_space={"train_loop_config": self.model.config_prior()},
            tune_config=tune_config,
        )
        return tuner.fit()