ML Training Blueprints

Note

TensorFlow and LightGBM blueprints are developed in separate branches, where they can be accessed. They will not be merged into the main branch until the release of FastIoT 2.0.

What are ML Training Blueprints?

ML Training Blueprints are templates for training an ML model in a microservice architecture via the FastIoT framework. Each blueprint is designed as an example of how to train a specific model type with a specific framework (for example, a regression model in PyTorch or a classifier in LightGBM). Most frameworks have some peculiarities when it comes to loading data. ML Training Blueprints are designed to showcase how one can implement data loading and preprocessing in a microservice architecture. They also showcase how to train a model and store it in a model repository.

Pytorch Regression Blueprint (WandB)

Note

WandB support might be removed in the future in favor of MLflow.

Pytorch Regression WandB
  1import asyncio
  2import logging
  3import uuid
  4
  5import wandb
  6
  7import torch
  8
  9import numpy as np
 10import pandas as pd
 11
 12from fastiot.core import FastIoTService, loop
 13from rich.progress import Progress
 14from torch.utils.data import Dataset
 15
 16from blueprint_dev_v2.ml_lifecycle_utils.ml_lifecycle_broker_facade import (
 17    request_get_processed_data_points_count,
 18    request_get_processed_data_points_page
 19)
 20from src.blueprint_dev_v2.logger.logger import log
 21
 22from torch import nn, optim
 23
 24
 25class DemonstratorNeuralNet(nn.Module):
 26    """
 27    A simple neural network for demonstration purposes.
 28
 29    Attributes
 30    ----------
 31    layer_1 : torch.nn.Linear
 32        The first linear layer.
 33    layer_2 : torch.nn.Linear
 34        The second linear layer.
 35    layer_3 : torch.nn.Linear
 36        The third linear layer.
 37
 38    Methods
 39    -------
 40    forward(x)
 41        Forward pass through the network.
 42    """
 43    def __init__(self, input_dim, hidden_dim, output_dim, *args, **kwargs):
 44        """
 45        Initialize the network.
 46
 47        Parameters
 48        ----------
 49        input_dim
 50        hidden_dim
 51        output_dim
 52        args
 53        kwargs
 54        """
 55        super().__init__(*args, **kwargs)
 56        self.layer_1 = nn.Linear(input_dim, hidden_dim)
 57        self.layer_2 = nn.Linear(hidden_dim, hidden_dim)
 58        self.layer_3 = nn.Linear(hidden_dim, output_dim)
 59
 60    def forward(self, x):
 61        """
 62        Forward pass through the network.
 63
 64        Parameters
 65        ----------
 66        x
 67            The input to the network.
 68
 69        Returns
 70        -------
 71        torch.Tensor
 72            The output of the network.
 73        """
 74        x = torch.relu(self.layer_1(x))
 75        x = torch.relu(self.layer_2(x))
 76        x = self.layer_3(x)
 77        return x
 78
 79
 80class PageDataset(Dataset):
 81    """
 82    A custom dataset for the pytorch regression service.
 83
 84    Attributes
 85    ----------
 86    _page_size : int
 87        The size of a page.
 88    _total_pages : int
 89        The total number of pages.
 90    _num_entries_in_db : int
 91        The total number of entries in the database.
 92    _current_page : int
 93        The current page.
 94    _fast_iot_service : FastIoTService
 95        The fast iot service.
 96    _broker_timeout : float
 97        The broker timeout.
 98    _page_df : pd.DataFrame
 99        The page dataframe.
100
101    Methods
102    -------
103    __len__()
104        Return the length of the dataset.
105    _init_total_pages()
106        Initialize the total number of pages.
107    _get_page_df(page)
108        Get the dataframe for a page.
109    init_dataset()
110        Initialize the dataset.
111    has_next_page()
112        Check if there is a next page.
113    load_next_page()
114        Load the next page.
115    __getitem__(idx)
116        Get an item from the dataset.
117    """
118    _page_size: int
119    _total_pages: int
120    _num_entries_in_db: int
121    _current_page: int
122
123    _fast_iot_service: FastIoTService
124    _broker_timeout: float
125
126    _page_df: pd.DataFrame
127
128    def __init__(self, fast_iot_service: FastIoTService, page_size: int, broker_timeout=10):
129        """
130        Initialize the dataset.
131
132        Parameters
133        ----------
134        fast_iot_service
135        page_size
136        broker_timeout
137        """
138        self._fast_iot_service = fast_iot_service
139        self._broker_timeout = broker_timeout
140
141        self._page_size = page_size
142
143    def __len__(self):
144        """
145        Return the length of the dataset.
146
147        Returns
148        -------
149        int
150            The length of the dataset.
151        """
152        return len(self._page_df)
153
154    async def _init_total_pages(self):
155        """
156        Initialize the total number of pages.
157
158        Returns
159        -------
160        int
161            The total number of pages.
162        """
163        # count
164        count: int = await request_get_processed_data_points_count(fiot_service=self._fast_iot_service)
165        self._num_entries_in_db = count
166        self._total_pages = int(np.ceil(self._num_entries_in_db / self._page_size))
167
168    async def _get_page_df(self, page: int) -> pd.DataFrame:
169        """
170        Get the dataframe for a page.
171
172        Parameters
173        ----------
174        page
175            The page. (A slice of the data present in the database.)
176
177        Returns
178        -------
179        pd.DataFrame
180            The dataframe for the page.
181        """
182        # query the db_service for the number of raw data points
183        page: list[dict] = await request_get_processed_data_points_page(
184            fiot_service=self._fast_iot_service,
185            page=page,
186            page_size=self._page_size
187        )
188        return pd.DataFrame(page)
189
190    async def init_dataset(self):
191        """
192        Initialize the dataset.
193
194        Returns
195        -------
196        None
197        """
198        # init total number of pages
199        await self._init_total_pages()
200        df = await self._get_page_df(page=0)
201        self._page_df = df
202        self._current_page = 0
203
204    def has_next_page(self):
205        """
206        Check if there is a next page.
207
208        Returns
209        -------
210        bool
211            True if there is a next page, False otherwise.
212        """
213        return self._current_page < self._total_pages
214
215    @property
216    def num_pages(self):
217        """
218        The total number of pages.
219
220        Returns
221        -------
222        """
223        if self._total_pages is None:
224            log.warn("total pages not initialized. init_page() needs to called and awaited first.")
225        return self._total_pages
226
227    async def load_next_page(self):
228        """
229        Load the next page.
230
231        Returns
232        -------
233        None
234        """
235        if self._current_page is None:
236            log.error("page not initialized. init_page() needs to called and awaited first.")
237            raise ValueError("page not initialized. init_page() needs to called and awaited first.")
238
239        if self._current_page >= self._total_pages:
240            log.error("no more pages available")
241            raise ValueError("no more pages available")
242
243        self._current_page += 1
244        df = await self._get_page_df(page=self._current_page)
245        self._page_df = df
246
247    def __getitem__(self, idx):  # idx means index of the chunk.
248        """
249        Get an item from the dataset.
250
251        Parameters
252        ----------
253        idx
254
255        Returns
256        -------
257        tuple
258            The input and output data.
259        """
260        # drop index column
261        temp = self._page_df
262        temp = temp.iloc[idx]
263
264        y_data = np.array([temp.pop("aufbereiteter_wert")])
265        x_data = temp.to_numpy()
266
267        # The following condition is actually needed in Pytorch. Otherwise, for our particular example,
268        # the iterator will be an infinite loop.
269        # Readers can verify this by removing this condition.
270        if idx == self.__len__():
271            raise IndexError
272
273        return x_data, y_data
274
275
class MlPytorchRegressionService(FastIoTService):
    """
    A service that trains a pytorch regression model with wandb experiment tracking.

    Methods
    -------
    _start()
        Runs when the service starts.
    _stop()
        Runs when the service stops.
    get_model()
        Create a model instance.
    training_loop()
        One training cycle, repeated by FastIoT's ``@loop`` decorator.
    train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model without experiment tracking.
    train_model_with_wandb_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model, log metrics to wandb and upload the weights as a wandb artifact.
    """

    async def _start(self):
        """
        Runs when the service starts.
        """
        log.info("MlPytorchRegressionService started")

        # The following requests are the ones PageDataset relies on;
        # comment them in to check that they work:
        # count: int = await request_get_processed_data_points_count(fiot_service=self)
        # page = await request_get_processed_data_points_page(fiot_service=self, page=0, page_size=10)
        # pageDataset = PageDataset(fast_iot_service=self, page_size=10)
        # await pageDataset.init_dataset()
        # await pageDataset.load_next_page()

    async def _stop(self):
        """
        Runs when the service stops.
        """
        log.info("MlPytorchRegressionService stopped")

    def get_model(self) -> DemonstratorNeuralNet:
        """
        Create a fresh, untrained model instance.

        ``input_dim`` has to match the number of feature columns that PageDataset
        yields (all columns except the target).

        Returns
        -------
        DemonstratorNeuralNet
            A model instance.
        """
        return DemonstratorNeuralNet(
            input_dim=15,
            hidden_dim=10,
            output_dim=1
        )

    @loop
    async def training_loop(self):
        """
        Train a new model once, then hand back a 24 h sleep.

        Returns
        -------
        Awaitable
            ``asyncio.sleep(24 h)``; presumably awaited by FastIoT's ``@loop``
            before the next iteration (see FastIoT's ``loop`` documentation).
        """
        model = self.get_model()
        loss_fn = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        dataset = PageDataset(fast_iot_service=self, page_size=10)

        # await self.train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer)
        await self.train_model_with_wandb_tracking(dataset, model, loss_fn, optimizer)

        return asyncio.sleep(24 * 60 * 60)

    async def _train_over_pages(self, dataset: PageDataset, model: DemonstratorNeuralNet, loss_fn: nn.MSELoss,
                                optimizer: optim.Adam, epochs: int, batch_size: int, shuffle: bool,
                                progress_description: str, on_batch=None) -> int:
        """
        Shared training core: for every page, run ``epochs`` passes over that page.

        ``dataset.init_dataset()`` must already have been awaited.

        Parameters
        ----------
        dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle
            See the public training methods.
        progress_description
            Label of the rich progress bar (one step per page and epoch).
        on_batch : callable, optional
            Called after every optimizer step as
            ``on_batch(page, epoch, batch_idx, optimizer_step, loss_value)``.

        Returns
        -------
        int
            The number of optimizer steps performed.
        """
        progress = Progress()
        task_id = progress.add_task(progress_description, total=dataset.num_pages * epochs)

        optimizer_step = 0
        with progress:
            for page in range(dataset.num_pages):
                # A fresh loader per page: the dataset only exposes the loaded page.
                data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

                for epoch in range(epochs):
                    for batch_idx, (x, y) in enumerate(data_loader):
                        optimizer.zero_grad()
                        y_pred = model(x.to(torch.float32))
                        loss = loss_fn(y_pred, y.to(torch.float32))
                        loss.backward()
                        optimizer.step()
                        if on_batch is not None:
                            on_batch(page, epoch, batch_idx, optimizer_step, loss.item())
                        optimizer_step += 1
                    progress.update(task_id, advance=1)

                # Do not request a page beyond the last one.
                if dataset.has_next_page():
                    await dataset.load_next_page()
        return optimizer_step

    async def train_model_without_experiment_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                                      loss_fn: nn.MSELoss,
                                                      optimizer: optim.Adam, epochs: int = 5, batch_size: int = 5,
                                                      shuffle: bool = True):
        """
        Train the model without experiment tracking.

        Parameters
        ----------
        dataset
            Not yet initialised page dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            Loss applied to (prediction, target).
        optimizer
            Optimizer over ``model``'s parameters.
        epochs
            Passes over each page.
        batch_size
            DataLoader batch size.
        shuffle
            Whether the DataLoader shuffles rows within a page.
        """
        log.info("Starting training loop without experiment tracking.")
        await dataset.init_dataset()
        await self._train_over_pages(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle,
                                     progress_description="[cyan]Training...")
        log.info("Training loop without experiment tracking completed.")
        # save model
        # here you can implement a saving mechanism for the model

    async def train_model_with_wandb_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                              loss_fn: nn.MSELoss, optimizer: optim.Adam, epochs: int = 5,
                                              batch_size: int = 5, shuffle: bool = True):
        """
        Train the model, log per-batch metrics to wandb and upload the weights as an artifact.

        The wandb run is finished in every case: with exit code 1 if training or
        the upload fails, normally otherwise. The temporary local weights file is
        removed afterwards.

        Parameters
        ----------
        dataset
            Not yet initialised page dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            Loss applied to (prediction, target).
        optimizer
            Optimizer over ``model``'s parameters.
        epochs
            Passes over each page.
        batch_size
            DataLoader batch size.
        shuffle
            Whether the DataLoader shuffles rows within a page.
        """
        log.info("Starting training loop with wandb tracking.")
        await dataset.init_dataset()

        # Initialize a new wandb run
        config_dict = {
            "epochs": epochs,
            "batch_size": batch_size,
            "shuffle": shuffle,
            "optimizer": str(optimizer),
            "loss_function": str(loss_fn)
        }
        run_id = uuid.uuid4()
        wandb_run = wandb.init(
            project="KIOptipack-dev",
            config=config_dict,
            group="MVDP-pytorch-regression",
            name=f"run_{run_id}",
        )
        model_name = f"model_{run_id}.pth"

        def log_batch(page, epoch, batch_idx, optimizer_step, loss_value):
            # Log metrics with wandb
            wandb_run.log({
                "loss": loss_value,
                "epoch": epoch,
                "page": page,
                "optimizer_step": optimizer_step
            })
            log.debug(f"page: {page}, epoch: {epoch}, batch_idx: {batch_idx}, loss: {loss_value}")

        try:
            # Log gradients and model parameters
            wandb.watch(model)
            await self._train_over_pages(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle,
                                         progress_description="[cyan]Training", on_batch=log_batch)
            log.info("Training loop with wandb tracking completed.")

            # save model
            torch.save(model.state_dict(), model_name)
            artifact = wandb.Artifact(
                "DemonstratorNeuralNet",
                incremental=True,
                type="pytorch-regression-model",
                description="Pytorch regression model, saved using the state_dict method.",
                metadata={
                    # str(): keep metadata plain JSON (json cannot encode uuid.UUID)
                    "run_id": str(run_id),
                    "model_name": model_name,
                    "class": model.__class__.__name__,
                    "input_dim": model.layer_1.in_features,
                    "output_dim": model.layer_3.out_features,
                },
            )
            artifact.add_file(model_name)
            wandb_run.log_artifact(artifact)
        except BaseException:
            # Mark the run as failed (also on task cancellation) instead of leaving it open.
            wandb_run.finish(exit_code=1)
            raise
        else:
            # Finish the wandb run
            wandb_run.finish()
        finally:
            # delete the local model file
            self._remove_local_file(model_name)

    @staticmethod
    def _remove_local_file(path):
        """
        Best-effort removal of a temporary local file; failures are only logged.
        """
        import os
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # nothing was written, e.g. training failed before saving
        except OSError as e:
            log.warning(f"Error while deleting local model file: {e}")
410                "model_name": model_name,
411                "class": model.__class__.__name__,
412                "input_dim": model.layer_1.in_features,
413                "output_dim": model.layer_3.out_features,
414            }
415        )
416        artifact.add_file(model_name)
417        wandb_run.log_artifact(artifact)
418        # Finish the wandb run
419        wandb_run.finish()
420
421        # delete to local model file
422        try:
423            import os
424            os.remove(model_name)
425        except Exception as e:
426            log.warning(f"Error while deleting local model file: {e}")
427
428
if __name__ == "__main__":
    # Debug-level logging for local runs. Lower the level, or drop this line so the
    # `FASTIOT_LOG_LEVEL` environment variable decides the logging configuration.
    logging.basicConfig(level=logging.DEBUG)
    MlPytorchRegressionService.main()

Pytorch Regression Blueprint (MLflow)

Note

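Make sure the MLflow tracking server is running before starting the service, e.g. with `mlflow server --host 127.0.0.1 --port 8080` (this matches the service's default `MLFLOW_TRACKING_URI`).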

Pytorch Regression MLflow
  1"""
  2mlflow server --host 127.0.0.1 --port 8080
  3"""
  4
  5import asyncio
  6import logging
  7import pprint
  8import uuid
  9
 10import mlflow
 11import random
 12
 13import torch
 14
 15import numpy as np
 16import pandas as pd
 17
 18from fastiot.core import FastIoTService, Subject, subscribe, loop
 19from fastiot.core.core_uuid import get_uuid
 20from fastiot.core.time import get_time_now
 21from fastiot.msg.thing import Thing
 22from rich.progress import Progress
 23from torch.utils.data import Dataset
 24
 25from blueprint_dev_v2.ml_lifecycle_utils.ml_lifecycle_broker_facade import request_get_processed_data_points_count, \
 26    request_get_all_raw_data_points, request_get_processed_data_points_page
 27from src.blueprint_dev_v2.logger.logger import log
 28
 29from torch import nn, optim
 30
 31
 32class DemonstratorNeuralNet(nn.Module):
 33    """
 34    A simple neural network for demonstration purposes.
 35
 36    Attributes
 37    ----------
 38    layer_1 : torch.nn.Linear
 39        The first linear layer.
 40    layer_2 : torch.nn.Linear
 41        The second linear layer.
 42    layer_3 : torch.nn.Linear
 43        The third linear layer.
 44
 45    Methods
 46    -------
 47    forward(x)
 48        Forward pass through the network.
 49    """
 50
 51    def __init__(self, input_dim, hidden_dim, output_dim, *args, **kwargs):
 52        """
 53        Initialize the network.
 54
 55        Parameters
 56        ----------
 57        input_dim
 58        hidden_dim
 59        output_dim
 60        args
 61        kwargs
 62        """
 63        super().__init__(*args, **kwargs)
 64        self.layer_1 = nn.Linear(input_dim, hidden_dim)
 65        self.layer_2 = nn.Linear(hidden_dim, hidden_dim)
 66        self.layer_3 = nn.Linear(hidden_dim, output_dim)
 67
 68    def forward(self, x):
 69        """
 70        Forward pass through the network.
 71
 72        Parameters
 73        ----------
 74        x
 75            The input to the network.
 76
 77        Returns
 78        -------
 79        torch.Tensor
 80            The output of the network.
 81        """
 82        x = torch.relu(self.layer_1(x))
 83        x = torch.relu(self.layer_2(x))
 84        x = self.layer_3(x)
 85        return x
 86
 87
 88class PageDataset(Dataset):
 89    """
 90    A custom dataset for the pytorch regression service.
 91
 92    Attributes
 93    ----------
 94    _page_size : int
 95        The size of a page.
 96    _total_pages : int
 97        The total number of pages.
 98    _num_entries_in_db : int
 99        The total number of entries in the database.
100    _current_page : int
101        The current page.
102    _fast_iot_service : FastIoTService
103        The fast iot service.
104    _broker_timeout : float
105        The broker timeout.
106    _page_df : pd.DataFrame
107        The page dataframe.
108
109    Methods
110    -------
111    __len__()
112        Return the length of the dataset.
113    _init_total_pages()
114        Initialize the total number of pages.
115    _get_page_df(page)
116        Get the dataframe for a page.
117    init_dataset()
118        Initialize the dataset.
119    has_next_page()
120        Check if there is a next page.
121    load_next_page()
122        Load the next page.
123    __getitem__(idx)
124        Get an item from the dataset.
125    """
126    _page_size: int
127    _total_pages: int
128    _num_entries_in_db: int
129    _current_page: int
130
131    _fast_iot_service: FastIoTService
132    _broker_timeout: float
133
134    _page_df: pd.DataFrame
135
136    def __init__(self, fast_iot_service: FastIoTService, page_size: int, broker_timeout=10):
137        """
138        Initialize the dataset.
139
140        Parameters
141        ----------
142        fast_iot_service
143        page_size
144        broker_timeout
145        """
146        self._fast_iot_service = fast_iot_service
147        self._broker_timeout = broker_timeout
148
149        self._page_size = page_size
150
151    def __len__(self):
152        """
153        Return the length of the dataset.
154
155        Returns
156        -------
157        int
158            The length of the dataset.
159        """
160        return len(self._page_df)
161
162    async def _init_total_pages(self):
163        """
164        Initialize the total number of pages.
165
166        Returns
167        -------
168        int
169            The total number of pages.
170        """
171        # count
172        count: int = await request_get_processed_data_points_count(fiot_service=self._fast_iot_service)
173        self._num_entries_in_db = count
174        self._total_pages = int(np.ceil(self._num_entries_in_db / self._page_size))
175
176    async def _get_page_df(self, page: int) -> pd.DataFrame:
177        """
178        Get the dataframe for a page.
179
180        Parameters
181        ----------
182        page
183            The page. (A slice of the data present in the database.)
184
185        Returns
186        -------
187        pd.DataFrame
188            The dataframe for the page.
189        """
190        # query the db_service for the number of raw data points
191        page: list[dict] = await request_get_processed_data_points_page(
192            fiot_service=self._fast_iot_service,
193            page=page,
194            page_size=self._page_size
195        )
196        return pd.DataFrame(page)
197
198    async def init_dataset(self):
199        """
200        Initialize the dataset.
201
202        Returns
203        -------
204        None
205        """
206        # init total number of pages
207        await self._init_total_pages()
208        df = await self._get_page_df(page=0)
209        self._page_df = df
210        self._current_page = 0
211
212    def has_next_page(self):
213        """
214        Check if there is a next page.
215
216        Returns
217        -------
218        bool
219            True if there is a next page, False otherwise.
220        """
221        return self._current_page < self._total_pages
222
223    @property
224    def num_pages(self):
225        """
226        The total number of pages.
227
228        Returns
229        -------
230        """
231        if self._total_pages is None:
232            log.warn("total pages not initialized. init_page() needs to called and awaited first.")
233        return self._total_pages
234
235    async def load_next_page(self):
236        """
237        Load the next page.
238
239        Returns
240        -------
241        None
242        """
243        if self._current_page is None:
244            log.error("page not initialized. init_page() needs to called and awaited first.")
245            raise ValueError("page not initialized. init_page() needs to called and awaited first.")
246
247        if self._current_page >= self._total_pages:
248            log.error("no more pages available")
249            raise ValueError("no more pages available")
250
251        self._current_page += 1
252        df = await self._get_page_df(page=self._current_page)
253        self._page_df = df
254
255    def __getitem__(self, idx):  # idx means index of the chunk.
256        """
257        Get an item from the dataset.
258
259        Parameters
260        ----------
261        idx
262
263        Returns
264        -------
265        tuple
266            The input and output data.
267        """
268        # drop index column
269        temp = self._page_df
270        temp = temp.iloc[idx]
271
272        y_data = np.array([temp.pop("aufbereiteter_wert")])
273        x_data = temp.to_numpy()
274
275        # The following condition is actually needed in Pytorch. Otherwise, for our particular example,
276        # the iterator will be an infinite loop.
277        # Readers can verify this by removing this condition.
278        if idx == self.__len__():
279            raise IndexError
280
281        return x_data, y_data
282
283
class MlPytorchRegressionMlflowService(FastIoTService):
    """
    A service for training a pytorch model with mlflow experiment tracking.

    Attributes
    ----------
    MLFLOW_TRACKING_URI : str
        The mlflow tracking uri.
    MLFLOW_REGISTERED_MODEL_NAME : str
        Name under which trained models are registered in the MLflow model registry.

    Methods
    -------
    _start()
        Start the service.
    _stop()
        Stop the service.
    get_model()
        Get the model.
    training_loop()
        The training loop.
    train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model without experiment tracking.
    train_model_with_mlflow_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model with mlflow tracking and register it in the MLflow model registry.
    """
    MLFLOW_TRACKING_URI = "http://127.0.0.1:8080"
    MLFLOW_REGISTERED_MODEL_NAME = "MyModel"

    async def _start(self):
        """
        Runs when the service starts: points the mlflow client at ``MLFLOW_TRACKING_URI``.
        """
        log.info("MlPytorchRegressionMlflowService started")
        log.info(f"Setting MLFlow tracking uri to {self.MLFLOW_TRACKING_URI}")
        mlflow.set_tracking_uri(self.MLFLOW_TRACKING_URI)

    async def _stop(self):
        """
        Runs when the service stops.
        """
        log.info("MlPytorchRegressionMlflowService stopped")

    def get_model(self) -> DemonstratorNeuralNet:
        """
        Create a fresh, untrained model instance.

        ``input_dim`` has to match the number of feature columns that PageDataset
        yields (all columns except the target).

        Returns
        -------
        DemonstratorNeuralNet
            A model instance.
        """
        return DemonstratorNeuralNet(
            input_dim=15,
            hidden_dim=10,
            output_dim=1
        )

    @loop
    async def training_loop(self):
        """
        Train and register a new model once, then hand back a 24 h sleep.

        Returns
        -------
        Awaitable
            ``asyncio.sleep(24 h)``; presumably awaited by FastIoT's ``@loop``
            before the next iteration (see FastIoT's ``loop`` documentation).
        """
        model = self.get_model()
        loss_fn = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        dataset = PageDataset(fast_iot_service=self, page_size=10)

        # await self.train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer)
        await self.train_model_with_mlflow_tracking(dataset, model, loss_fn, optimizer)

        return asyncio.sleep(24 * 60 * 60)

    async def _train_over_pages(self, dataset: PageDataset, model: DemonstratorNeuralNet, loss_fn: nn.MSELoss,
                                optimizer: optim.Adam, epochs: int, batch_size: int, shuffle: bool,
                                progress_description: str, on_batch=None) -> int:
        """
        Shared training core: for every page, run ``epochs`` passes over that page.

        ``dataset.init_dataset()`` must already have been awaited.

        Parameters
        ----------
        dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle
            See the public training methods.
        progress_description
            Label of the rich progress bar (one step per page and epoch).
        on_batch : callable, optional
            Called after every optimizer step as
            ``on_batch(page, epoch, batch_idx, optimizer_step, loss_value)``.

        Returns
        -------
        int
            The number of optimizer steps performed.
        """
        progress = Progress()
        task_id = progress.add_task(progress_description, total=dataset.num_pages * epochs)

        optimizer_step = 0
        with progress:
            for page in range(dataset.num_pages):
                # A fresh loader per page: the dataset only exposes the loaded page.
                data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

                for epoch in range(epochs):
                    for batch_idx, (x, y) in enumerate(data_loader):
                        optimizer.zero_grad()
                        y_pred = model(x.to(torch.float32))
                        loss = loss_fn(y_pred, y.to(torch.float32))
                        loss.backward()
                        optimizer.step()
                        if on_batch is not None:
                            on_batch(page, epoch, batch_idx, optimizer_step, loss.item())
                        optimizer_step += 1
                    progress.update(task_id, advance=1)

                # Do not request a page beyond the last one.
                if dataset.has_next_page():
                    await dataset.load_next_page()
        return optimizer_step

    async def train_model_without_experiment_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                                      loss_fn: nn.MSELoss,
                                                      optimizer: optim.Adam, epochs: int = 5, batch_size: int = 5,
                                                      shuffle: bool = True):
        """
        Train the model without experiment tracking.

        Parameters
        ----------
        dataset
            Not yet initialised page dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            Loss applied to (prediction, target).
        optimizer
            Optimizer over ``model``'s parameters.
        epochs
            Passes over each page.
        batch_size
            DataLoader batch size.
        shuffle
            Whether the DataLoader shuffles rows within a page.
        """
        log.info("Starting training loop without experiment tracking.")
        await dataset.init_dataset()
        await self._train_over_pages(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle,
                                     progress_description="[cyan]Training...")
        log.info("Training loop without experiment tracking completed.")
        # save model
        # here you can implement a saving mechanism for the model

    async def train_model_with_mlflow_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                               loss_fn: nn.MSELoss, optimizer: optim.Adam, epochs: int = 5,
                                               batch_size: int = 5, shuffle: bool = True):
        """
        Train the model with mlflow tracking, log it, and register it in the model registry.

        Everything happens inside one ``mlflow.start_run()`` context. Run
        parameters and per-batch metrics are logged. The trained model is then
        logged under the artifact path ``model`` and registered as
        ``MLFLOW_REGISTERED_MODEL_NAME``.

        Parameters
        ----------
        dataset
            Not yet initialised page dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            Loss applied to (prediction, target).
        optimizer
            Optimizer over ``model``'s parameters.
        epochs
            Passes over each page.
        batch_size
            DataLoader batch size.
        shuffle
            Whether the DataLoader shuffles rows within a page.
        """
        log.info("Starting training loop with mlflow tracking.")
        await dataset.init_dataset()

        def log_batch(page, epoch, batch_idx, optimizer_step, loss_value):
            # Log metrics with mlflow
            metrics = {
                "loss": loss_value,
                "epoch": epoch,
                "page": page,
                "optimizer_step": optimizer_step
            }
            mlflow.log_metrics(metrics, step=optimizer_step)
            log.debug(f"page: {page}, epoch: {epoch}, batch_idx: {batch_idx}, loss: {loss_value}")

        with mlflow.start_run() as run:
            # Record the run configuration (the wandb variant logs the same settings as config).
            # Short class names keep the values well within mlflow's parameter length limit.
            mlflow.log_params({
                "epochs": epochs,
                "batch_size": batch_size,
                "shuffle": shuffle,
                "optimizer": type(optimizer).__name__,
                "learning_rate": optimizer.param_groups[0]["lr"],
                "loss_function": type(loss_fn).__name__,
            })

            await self._train_over_pages(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle,
                                         progress_description="[cyan]Training", on_batch=log_batch)
            log.info("Training loop with mlflow tracking completed.")

            mlflow.pytorch.log_model(pytorch_model=model, artifact_path="model")

            model_uri = f"runs:/{run.info.run_id}/model"
            model_details = mlflow.register_model(model_uri=model_uri, name=self.MLFLOW_REGISTERED_MODEL_NAME)
            log.info(f"registered model in mlflow model registry. Details: \n {pprint.pformat(dict(model_details))}")
468
469
if __name__ == "__main__":
    # Debug-level logging for local development runs.
    logging.basicConfig(level=logging.DEBUG)
    MlPytorchRegressionMlflowService.main()