ML Training Blueprints
Note
TensorFlow and LightGBM blueprints are developed in separate branches, where they can be accessed. They will not be merged into the main branch until the release of FastIoT 2.0.
What are ML Training Blueprints?
ML Training Blueprints are templates for training an ML model in a microservice architecture using the FastIoT framework. Each blueprint is an example of how to train a specific model type with a specific framework (for example, a regression model in PyTorch or a classifier in LightGBM). Most frameworks have some peculiarities when it comes to loading data. ML Training Blueprints showcase how to implement data loading and preprocessing in a microservice architecture. They also show how to train a model and store it in a model repository.
PyTorch Regression Blueprint (WandB)
Note
WandB support might be removed in the future in favor of MLflow.
1import asyncio
2import logging
3import uuid
4
5import wandb
6
7import torch
8
9import numpy as np
10import pandas as pd
11
12from fastiot.core import FastIoTService, loop
13from rich.progress import Progress
14from torch.utils.data import Dataset
15
16from blueprint_dev_v2.ml_lifecycle_utils.ml_lifecycle_broker_facade import (
17 request_get_processed_data_points_count,
18 request_get_processed_data_points_page
19)
20from src.blueprint_dev_v2.logger.logger import log
21
22from torch import nn, optim
23
24
class DemonstratorNeuralNet(nn.Module):
    """
    Small fully connected regression network used by the blueprint.

    Data flows ``layer_1 -> ReLU -> layer_2 -> ReLU -> layer_3``. The output layer has
    no activation, so the network can predict arbitrary real values.

    Attributes
    ----------
    layer_1 : torch.nn.Linear
        Input layer mapping ``input_dim -> hidden_dim``.
    layer_2 : torch.nn.Linear
        Hidden layer mapping ``hidden_dim -> hidden_dim``.
    layer_3 : torch.nn.Linear
        Output layer mapping ``hidden_dim -> output_dim``.

    Methods
    -------
    forward(x)
        Compute the network output for a batch of inputs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, *args, **kwargs):
        """
        Build the three linear layers.

        Parameters
        ----------
        input_dim
            Number of input features.
        hidden_dim
            Width of the two hidden representations.
        output_dim
            Number of predicted values.
        args, kwargs
            Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # The creation order fixes the order in which parameters are drawn from the RNG.
        self.layer_1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.layer_2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.layer_3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """
        Run ``x`` through the network.

        Parameters
        ----------
        x
            Input tensor whose last dimension has size ``input_dim``.

        Returns
        -------
        torch.Tensor
            The (un-activated) output of ``layer_3``.
        """
        hidden = nn.functional.relu(self.layer_1(x))
        hidden = nn.functional.relu(self.layer_2(hidden))
        # No activation on the output: this is a regression head.
        return self.layer_3(hidden)
78
79
class PageDataset(Dataset):
    """
    A paged custom dataset for the pytorch regression service.

    Processed data points are fetched from the broker one page (``page_size`` rows) at a
    time, and only the currently loaded page is kept in memory. ``__len__`` and
    ``__getitem__`` therefore refer to the *current page*, not to the whole database.
    ``init_dataset()`` must be awaited before the dataset is used.

    Attributes
    ----------
    _page_size : int
        The number of rows requested per page.
    _target_column : str
        The column returned as the regression target by ``__getitem__``.
    _total_pages : int or None
        The total number of pages; ``None`` until ``init_dataset()`` has been awaited.
    _num_entries_in_db : int or None
        The number of processed data points reported by the broker; ``None`` until initialized.
    _current_page : int or None
        The 0-based index of the loaded page; ``None`` until initialized.
    _fast_iot_service : FastIoTService
        The service used to send the broker requests.
    _broker_timeout : float
        The broker timeout (stored, but not currently passed to the broker requests).
    _page_df : pd.DataFrame
        The currently loaded page; empty until ``init_dataset()`` has been awaited.

    Methods
    -------
    __len__()
        Return the number of rows in the current page.
    _init_total_pages()
        Query the entry count and derive the total number of pages.
    _get_page_df(page)
        Fetch one page as a dataframe.
    init_dataset()
        Initialize the page count and load the first page.
    has_next_page()
        Check if there is a page after the current one.
    load_next_page()
        Load the next page.
    __getitem__(idx)
        Return the (features, target) pair for one row of the current page.
    """
    _page_size: int
    _total_pages: int
    _num_entries_in_db: int
    _current_page: int

    _fast_iot_service: FastIoTService
    _broker_timeout: float

    _page_df: pd.DataFrame

    def __init__(self, fast_iot_service: FastIoTService, page_size: int, broker_timeout=10,
                 target_column: str = "aufbereiteter_wert"):
        """
        Initialize the dataset. No broker request is made here; await ``init_dataset()``.

        Parameters
        ----------
        fast_iot_service
            The service used to send the broker requests.
        page_size
            The number of rows per page. Must be positive.
        broker_timeout
            The broker timeout (currently only stored).
        target_column
            The column used as the regression target; all other columns are features.

        Raises
        ------
        ValueError
            If ``page_size`` is not positive.
        """
        if page_size <= 0:
            raise ValueError(f"page_size must be a positive integer, got {page_size}")

        self._fast_iot_service = fast_iot_service
        self._broker_timeout = broker_timeout
        self._page_size = page_size
        self._target_column = target_column

        # Explicit "not initialized yet" state, so that num_pages / load_next_page can detect
        # a missing init_dataset() call instead of failing with AttributeError.
        self._num_entries_in_db = None
        self._total_pages = None
        self._current_page = None
        self._page_df = pd.DataFrame()

    def __len__(self):
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The number of rows in the currently loaded page.
        """
        return len(self._page_df)

    async def _init_total_pages(self):
        """
        Query the number of processed data points and derive the total number of pages.

        Returns
        -------
        None
            The result is stored in ``_num_entries_in_db`` and ``_total_pages``.
        """
        count: int = await request_get_processed_data_points_count(fiot_service=self._fast_iot_service)
        self._num_entries_in_db = count
        # A trailing, partially filled page still counts as a page.
        self._total_pages = int(np.ceil(self._num_entries_in_db / self._page_size))

    async def _get_page_df(self, page: int) -> pd.DataFrame:
        """
        Get the dataframe for a page.

        Parameters
        ----------
        page
            The page index. (A slice of ``page_size`` rows of the data in the database.)

        Returns
        -------
        pd.DataFrame
            One row per processed data point returned by the broker.
        """
        # Fetch one page of processed data points from the broker.
        records: list[dict] = await request_get_processed_data_points_page(
            fiot_service=self._fast_iot_service,
            page=page,
            page_size=self._page_size
        )
        return pd.DataFrame(records)

    async def init_dataset(self):
        """
        Initialize the page count and load page 0.

        Returns
        -------
        None
        """
        await self._init_total_pages()
        self._page_df = await self._get_page_df(page=0)
        self._current_page = 0

    def has_next_page(self):
        """
        Check if there is a page after the current one.

        Returns
        -------
        bool
            True if a page after the current one exists, False otherwise (also False
            before ``init_dataset()`` has been awaited).
        """
        if self._current_page is None or self._total_pages is None:
            return False
        # Pages are 0-based, so the last valid index is _total_pages - 1.
        return self._current_page + 1 < self._total_pages

    @property
    def num_pages(self):
        """
        The total number of pages.

        Returns
        -------
        int or None
            The page count, or None if ``init_dataset()`` has not been awaited yet.
        """
        if self._total_pages is None:
            log.warning("total pages not initialized. init_dataset() needs to be called and awaited first.")
        return self._total_pages

    async def load_next_page(self):
        """
        Advance to and load the next page.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``init_dataset()`` has not been awaited, or the current page index has
            already reached ``num_pages``.
        """
        if self._current_page is None:
            log.error("page not initialized. init_dataset() needs to be called and awaited first.")
            raise ValueError("page not initialized. init_dataset() needs to be called and awaited first.")

        # NOTE(review): kept lenient for backward compatibility. Callers may still advance
        # once past the last page (to index _total_pages). Use has_next_page() to avoid that.
        if self._current_page >= self._total_pages:
            log.error("no more pages available")
            raise ValueError("no more pages available")

        self._current_page += 1
        self._page_df = await self._get_page_df(page=self._current_page)

    def __getitem__(self, idx):
        """
        Get one row of the current page.

        Parameters
        ----------
        idx
            Row position within the current page (negative values count from the end).

        Returns
        -------
        tuple
            ``(x_data, y_data)``: all non-target columns as an array, and the target
            value wrapped in a 1-element array.

        Raises
        ------
        IndexError
            If ``idx`` is outside the current page. This also terminates Python's legacy
            ``__getitem__``-based iteration protocol.
        """
        page_len = len(self._page_df)
        if not -page_len <= idx < page_len:
            raise IndexError(f"index {idx} is out of range for a page of {page_len} rows")

        row = self._page_df.iloc[idx]
        # Read the target and drop it from the features without mutating the row.
        y_data = np.array([row[self._target_column]])
        x_data = row.drop(labels=self._target_column).to_numpy()
        return x_data, y_data
274
275
class MlPytorchRegressionService(FastIoTService):
    """
    A FastIoT service that trains a ``DemonstratorNeuralNet`` regression model on the
    processed data points, tracking the run with Weights & Biases (wandb).

    Methods
    -------
    _start()
        Runs when the service starts.
    _stop()
        Runs when the service stops.
    get_model()
        Create a fresh model instance.
    training_loop()
        Train a fresh model, then return the delay before the next run.
    train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model without experiment tracking.
    train_model_with_wandb_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model, log metrics to wandb and upload the weights as a wandb artifact.
    """

    async def _start(self):
        """
        Runs when the service starts.
        """
        log.info("MlPytorchRegressionService started")

        # The following requests are needed for the custom dataset.
        # Uncomment them to check that they are working:
        # count: int = await request_get_processed_data_points_count(fiot_service=self)
        # page = await request_get_processed_data_points_page(fiot_service=self, page=0, page_size=10)
        # page_dataset = PageDataset(fast_iot_service=self, page_size=10)
        # await page_dataset.init_dataset()
        # await page_dataset.load_next_page()

    async def _stop(self):
        """
        Runs when the service stops.
        """
        log.info("MlPytorchRegressionService stopped")

    def get_model(self) -> DemonstratorNeuralNet:
        """
        Create a fresh model instance.

        ``input_dim`` must equal the number of feature columns yielded by the dataset
        (all columns except the target).

        Returns
        -------
        DemonstratorNeuralNet
            A newly initialized model.
        """
        return DemonstratorNeuralNet(
            input_dim=15,
            hidden_dim=10,
            output_dim=1
        )

    @loop
    async def training_loop(self):
        """
        Train a freshly initialized model on all processed data points.

        Returns
        -------
        Awaitable
            ``asyncio.sleep(24 h)``. This is handed back to ``@loop``, which (per FastIoT's
            loop convention) awaits it before the next run.
        """
        model = self.get_model()
        loss_fn = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        dataset = PageDataset(fast_iot_service=self, page_size=10)

        # await self.train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer)
        await self.train_model_with_wandb_tracking(dataset, model, loss_fn, optimizer)

        return asyncio.sleep(24 * 60 * 60)

    async def train_model_without_experiment_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                                      loss_fn: nn.MSELoss,
                                                      optimizer: optim.Adam, epochs: int = 5, batch_size: int = 5,
                                                      shuffle: bool = True):
        """
        Train ``model`` page by page without experiment tracking.

        Each page of ``dataset`` is trained on for ``epochs`` epochs before the next page
        is fetched from the broker.

        Parameters
        ----------
        dataset
            The paged dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            The loss function.
        optimizer
            The optimizer over ``model``'s parameters.
        epochs
            The number of epochs per page.
        batch_size
            The batch size of the data loader.
        shuffle
            Whether to shuffle rows within a page.
        """
        log.info("Starting training loop without experiment tracking.")
        await dataset.init_dataset()
        num_pages = dataset.num_pages
        progress = Progress()
        # One progress step per (page, epoch) pair.
        task_id = progress.add_task("[cyan]Training...", total=num_pages * epochs)

        with progress:
            for page in range(num_pages):
                # The dataset only exposes the currently loaded page, so build a loader per page.
                data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

                for epoch in range(epochs):
                    for batch_idx, (x, y) in enumerate(data_loader):
                        optimizer.zero_grad()
                        y_pred = model(x.to(torch.float32))
                        loss = loss_fn(y_pred, y.to(torch.float32))
                        loss.backward()
                        optimizer.step()
                        log.debug(f"page: {page}, epoch: {epoch}, batch_idx: {batch_idx}, loss: {loss.item()}")
                    progress.update(task_id, advance=1)

                # Do not request a page past the end of the data.
                if page + 1 < num_pages:
                    await dataset.load_next_page()

        log.info("Training loop without experiment tracking completed.")
        # Saving the model is left open here; train_model_with_wandb_tracking shows one way to do it.

    async def train_model_with_wandb_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                              loss_fn: nn.MSELoss, optimizer: optim.Adam, epochs: int = 5,
                                              batch_size: int = 5, shuffle: bool = True):
        """
        Train ``model`` page by page, logging metrics to wandb and uploading the trained
        ``state_dict`` as a wandb artifact.

        The wandb run is always finished, and the temporary local model file removed,
        even if training fails.

        Parameters
        ----------
        dataset
            The paged dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            The loss function.
        optimizer
            The optimizer over ``model``'s parameters.
        epochs
            The number of epochs per page.
        batch_size
            The batch size of the data loader.
        shuffle
            Whether to shuffle rows within a page.
        """
        import os

        log.info("Starting training loop with wandb tracking.")
        await dataset.init_dataset()
        num_pages = dataset.num_pages
        progress = Progress()
        # One progress step per (page, epoch) pair.
        task_id = progress.add_task("[cyan]Training", total=num_pages * epochs)

        # Initialize a new wandb run
        config_dict = {
            "epochs": epochs,
            "batch_size": batch_size,
            "shuffle": shuffle,
            "optimizer": str(optimizer),
            "loss_function": str(loss_fn)
        }
        run_id = uuid.uuid4()
        model_name = f"model_{run_id}.pth"
        wandb_run = wandb.init(
            project="KIOptipack-dev",
            config=config_dict,
            group="MVDP-pytorch-regression",
            name=f"run_{run_id}",
        )

        try:
            # Log gradients and model parameters
            wandb.watch(model)

            with progress:
                optimizer_step = 0
                for page in range(num_pages):
                    # The dataset only exposes the currently loaded page, so build a loader per page.
                    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

                    for epoch in range(epochs):
                        for batch_idx, (x, y) in enumerate(data_loader):
                            optimizer.zero_grad()
                            y_pred = model(x.to(torch.float32))
                            loss = loss_fn(y_pred, y.to(torch.float32))
                            loss.backward()
                            optimizer.step()
                            # Log metrics with wandb
                            wandb.log({
                                "loss": loss.item(),
                                "epoch": epoch,
                                "page": page,
                                "optimizer_step": optimizer_step
                            })
                            log.debug(f"page: {page}, epoch: {epoch}, batch_idx: {batch_idx}, loss: {loss.item()}")
                            optimizer_step += 1
                        progress.update(task_id, advance=1)

                    # Do not request a page past the end of the data.
                    if page + 1 < num_pages:
                        await dataset.load_next_page()

            log.info("Training loop with wandb tracking completed.")

            # save model
            torch.save(model.state_dict(), model_name)
            artifact = wandb.Artifact(
                'DemonstratorNeuralNet',
                incremental=True,
                type='pytorch-regression-model',
                description="Pytorch regression model, saved using the state_dict method.",
                metadata={
                    # str(): keep the metadata JSON-friendly (uuid.UUID is not).
                    "run_id": str(run_id),
                    "model_name": model_name,
                    "class": model.__class__.__name__,
                    "input_dim": model.layer_1.in_features,
                    "output_dim": model.layer_3.out_features,
                }
            )
            artifact.add_file(model_name)
            wandb_run.log_artifact(artifact)
        finally:
            # Finish the wandb run even if training failed, so it does not stay open.
            wandb_run.finish()

            # delete the local model file
            try:
                os.remove(model_name)
            except FileNotFoundError:
                pass  # training failed before the model file was written
            except OSError as e:
                log.warning(f"Error while deleting local model file: {e}")
427
428
if __name__ == "__main__":
    # Verbose logging for local runs. Drop (or lower) this call to let the
    # `FASTIOT_LOG_LEVEL` environment variable configure logging instead.
    logging.basicConfig(level="DEBUG")
    MlPytorchRegressionService.main()
PyTorch Regression Blueprint (MLflow)
Note
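Make sure the MLflow server is running before starting the service, e.g. with `mlflow server --host 127.0.0.1 --port 8080` (this matches the `MLFLOW_TRACKING_URI` used below).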
1"""
2mlflow server --host 127.0.0.1 --port 8080
3"""
4
5import asyncio
6import logging
7import pprint
8import uuid
9
10import mlflow
11import random
12
13import torch
14
15import numpy as np
16import pandas as pd
17
18from fastiot.core import FastIoTService, Subject, subscribe, loop
19from fastiot.core.core_uuid import get_uuid
20from fastiot.core.time import get_time_now
21from fastiot.msg.thing import Thing
22from rich.progress import Progress
23from torch.utils.data import Dataset
24
25from blueprint_dev_v2.ml_lifecycle_utils.ml_lifecycle_broker_facade import request_get_processed_data_points_count, \
26 request_get_all_raw_data_points, request_get_processed_data_points_page
27from src.blueprint_dev_v2.logger.logger import log
28
29from torch import nn, optim
30
31
class DemonstratorNeuralNet(nn.Module):
    """
    Small fully connected regression network used by the blueprint.

    Data flows ``layer_1 -> ReLU -> layer_2 -> ReLU -> layer_3``. The output layer has
    no activation, so the network can predict arbitrary real values.

    Attributes
    ----------
    layer_1 : torch.nn.Linear
        Input layer mapping ``input_dim -> hidden_dim``.
    layer_2 : torch.nn.Linear
        Hidden layer mapping ``hidden_dim -> hidden_dim``.
    layer_3 : torch.nn.Linear
        Output layer mapping ``hidden_dim -> output_dim``.

    Methods
    -------
    forward(x)
        Compute the network output for a batch of inputs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, *args, **kwargs):
        """
        Build the three linear layers.

        Parameters
        ----------
        input_dim
            Number of input features.
        hidden_dim
            Width of the two hidden representations.
        output_dim
            Number of predicted values.
        args, kwargs
            Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # The creation order fixes the order in which parameters are drawn from the RNG.
        self.layer_1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.layer_2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.layer_3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """
        Run ``x`` through the network.

        Parameters
        ----------
        x
            Input tensor whose last dimension has size ``input_dim``.

        Returns
        -------
        torch.Tensor
            The (un-activated) output of ``layer_3``.
        """
        hidden = nn.functional.relu(self.layer_1(x))
        hidden = nn.functional.relu(self.layer_2(hidden))
        # No activation on the output: this is a regression head.
        return self.layer_3(hidden)
86
87
class PageDataset(Dataset):
    """
    A paged custom dataset for the pytorch regression service.

    Processed data points are fetched from the broker one page (``page_size`` rows) at a
    time, and only the currently loaded page is kept in memory. ``__len__`` and
    ``__getitem__`` therefore refer to the *current page*, not to the whole database.
    ``init_dataset()`` must be awaited before the dataset is used.

    Attributes
    ----------
    _page_size : int
        The number of rows requested per page.
    _target_column : str
        The column returned as the regression target by ``__getitem__``.
    _total_pages : int or None
        The total number of pages; ``None`` until ``init_dataset()`` has been awaited.
    _num_entries_in_db : int or None
        The number of processed data points reported by the broker; ``None`` until initialized.
    _current_page : int or None
        The 0-based index of the loaded page; ``None`` until initialized.
    _fast_iot_service : FastIoTService
        The service used to send the broker requests.
    _broker_timeout : float
        The broker timeout (stored, but not currently passed to the broker requests).
    _page_df : pd.DataFrame
        The currently loaded page; empty until ``init_dataset()`` has been awaited.

    Methods
    -------
    __len__()
        Return the number of rows in the current page.
    _init_total_pages()
        Query the entry count and derive the total number of pages.
    _get_page_df(page)
        Fetch one page as a dataframe.
    init_dataset()
        Initialize the page count and load the first page.
    has_next_page()
        Check if there is a page after the current one.
    load_next_page()
        Load the next page.
    __getitem__(idx)
        Return the (features, target) pair for one row of the current page.
    """
    _page_size: int
    _total_pages: int
    _num_entries_in_db: int
    _current_page: int

    _fast_iot_service: FastIoTService
    _broker_timeout: float

    _page_df: pd.DataFrame

    def __init__(self, fast_iot_service: FastIoTService, page_size: int, broker_timeout=10,
                 target_column: str = "aufbereiteter_wert"):
        """
        Initialize the dataset. No broker request is made here; await ``init_dataset()``.

        Parameters
        ----------
        fast_iot_service
            The service used to send the broker requests.
        page_size
            The number of rows per page. Must be positive.
        broker_timeout
            The broker timeout (currently only stored).
        target_column
            The column used as the regression target; all other columns are features.

        Raises
        ------
        ValueError
            If ``page_size`` is not positive.
        """
        if page_size <= 0:
            raise ValueError(f"page_size must be a positive integer, got {page_size}")

        self._fast_iot_service = fast_iot_service
        self._broker_timeout = broker_timeout
        self._page_size = page_size
        self._target_column = target_column

        # Explicit "not initialized yet" state, so that num_pages / load_next_page can detect
        # a missing init_dataset() call instead of failing with AttributeError.
        self._num_entries_in_db = None
        self._total_pages = None
        self._current_page = None
        self._page_df = pd.DataFrame()

    def __len__(self):
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The number of rows in the currently loaded page.
        """
        return len(self._page_df)

    async def _init_total_pages(self):
        """
        Query the number of processed data points and derive the total number of pages.

        Returns
        -------
        None
            The result is stored in ``_num_entries_in_db`` and ``_total_pages``.
        """
        count: int = await request_get_processed_data_points_count(fiot_service=self._fast_iot_service)
        self._num_entries_in_db = count
        # A trailing, partially filled page still counts as a page.
        self._total_pages = int(np.ceil(self._num_entries_in_db / self._page_size))

    async def _get_page_df(self, page: int) -> pd.DataFrame:
        """
        Get the dataframe for a page.

        Parameters
        ----------
        page
            The page index. (A slice of ``page_size`` rows of the data in the database.)

        Returns
        -------
        pd.DataFrame
            One row per processed data point returned by the broker.
        """
        # Fetch one page of processed data points from the broker.
        records: list[dict] = await request_get_processed_data_points_page(
            fiot_service=self._fast_iot_service,
            page=page,
            page_size=self._page_size
        )
        return pd.DataFrame(records)

    async def init_dataset(self):
        """
        Initialize the page count and load page 0.

        Returns
        -------
        None
        """
        await self._init_total_pages()
        self._page_df = await self._get_page_df(page=0)
        self._current_page = 0

    def has_next_page(self):
        """
        Check if there is a page after the current one.

        Returns
        -------
        bool
            True if a page after the current one exists, False otherwise (also False
            before ``init_dataset()`` has been awaited).
        """
        if self._current_page is None or self._total_pages is None:
            return False
        # Pages are 0-based, so the last valid index is _total_pages - 1.
        return self._current_page + 1 < self._total_pages

    @property
    def num_pages(self):
        """
        The total number of pages.

        Returns
        -------
        int or None
            The page count, or None if ``init_dataset()`` has not been awaited yet.
        """
        if self._total_pages is None:
            log.warning("total pages not initialized. init_dataset() needs to be called and awaited first.")
        return self._total_pages

    async def load_next_page(self):
        """
        Advance to and load the next page.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``init_dataset()`` has not been awaited, or the current page index has
            already reached ``num_pages``.
        """
        if self._current_page is None:
            log.error("page not initialized. init_dataset() needs to be called and awaited first.")
            raise ValueError("page not initialized. init_dataset() needs to be called and awaited first.")

        # NOTE(review): kept lenient for backward compatibility. Callers may still advance
        # once past the last page (to index _total_pages). Use has_next_page() to avoid that.
        if self._current_page >= self._total_pages:
            log.error("no more pages available")
            raise ValueError("no more pages available")

        self._current_page += 1
        self._page_df = await self._get_page_df(page=self._current_page)

    def __getitem__(self, idx):
        """
        Get one row of the current page.

        Parameters
        ----------
        idx
            Row position within the current page (negative values count from the end).

        Returns
        -------
        tuple
            ``(x_data, y_data)``: all non-target columns as an array, and the target
            value wrapped in a 1-element array.

        Raises
        ------
        IndexError
            If ``idx`` is outside the current page. This also terminates Python's legacy
            ``__getitem__``-based iteration protocol.
        """
        page_len = len(self._page_df)
        if not -page_len <= idx < page_len:
            raise IndexError(f"index {idx} is out of range for a page of {page_len} rows")

        row = self._page_df.iloc[idx]
        # Read the target and drop it from the features without mutating the row.
        y_data = np.array([row[self._target_column]])
        x_data = row.drop(labels=self._target_column).to_numpy()
        return x_data, y_data
282
283
class MlPytorchRegressionMlflowService(FastIoTService):
    """
    A service for training a pytorch model with mlflow experiment tracking.

    Attributes
    ----------
    MLFLOW_TRACKING_URI : str
        The mlflow tracking uri. An MLflow server must be reachable there.
    REGISTERED_MODEL_NAME : str
        The name under which trained models are registered in the MLflow model registry.

    Methods
    -------
    _start()
        Start the service and point mlflow at ``MLFLOW_TRACKING_URI``.
    _stop()
        Stop the service.
    get_model()
        Create a fresh model instance.
    training_loop()
        Train a fresh model, then return the delay before the next run.
    train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model without experiment tracking.
    train_model_with_mlflow_tracking(dataset, model, loss_fn, optimizer, epochs, batch_size, shuffle)
        Train the model with mlflow tracking, then log and register it.
    """
    MLFLOW_TRACKING_URI = "http://127.0.0.1:8080"
    REGISTERED_MODEL_NAME = "MyModel"

    async def _start(self):
        """
        Runs when the service starts.
        """
        log.info("MlPytorchRegressionMlflowService started")
        log.info(f"Setting MLFlow tracking uri to {self.MLFLOW_TRACKING_URI}")
        mlflow.set_tracking_uri(self.MLFLOW_TRACKING_URI)

    async def _stop(self):
        """
        Runs when the service stops.
        """
        log.info("MlPytorchRegressionMlflowService stopped")

    def get_model(self) -> DemonstratorNeuralNet:
        """
        Create a model instance.

        ``input_dim`` must equal the number of feature columns yielded by the dataset
        (all columns except the target).

        Returns
        -------
        DemonstratorNeuralNet
            A newly initialized model.
        """
        return DemonstratorNeuralNet(
            input_dim=15,
            hidden_dim=10,
            output_dim=1
        )

    @loop
    async def training_loop(self):
        """
        Train a freshly initialized model on all processed data points.

        Returns
        -------
        Awaitable
            ``asyncio.sleep(24 h)``. This is handed back to ``@loop``, which (per FastIoT's
            loop convention) awaits it before the next run.
        """
        model = self.get_model()
        loss_fn = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        dataset = PageDataset(fast_iot_service=self, page_size=10)

        # await self.train_model_without_experiment_tracking(dataset, model, loss_fn, optimizer)
        await self.train_model_with_mlflow_tracking(dataset, model, loss_fn, optimizer)

        return asyncio.sleep(24 * 60 * 60)

    async def train_model_without_experiment_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                                      loss_fn: nn.MSELoss,
                                                      optimizer: optim.Adam, epochs: int = 5, batch_size: int = 5,
                                                      shuffle: bool = True):
        """
        Train the model page by page without experiment tracking.

        Each page of ``dataset`` is trained on for ``epochs`` epochs before the next page
        is fetched from the broker.

        Parameters
        ----------
        dataset
            The paged dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            The loss function.
        optimizer
            The optimizer over ``model``'s parameters.
        epochs
            The number of epochs per page.
        batch_size
            The batch size of the data loader.
        shuffle
            Whether to shuffle rows within a page.
        """
        log.info("Starting training loop without experiment tracking.")
        await dataset.init_dataset()
        num_pages = dataset.num_pages
        progress = Progress()
        # One progress step per (page, epoch) pair.
        task_id = progress.add_task("[cyan]Training...", total=num_pages * epochs)

        with progress:
            for page in range(num_pages):
                # The dataset only exposes the currently loaded page, so build a loader per page.
                data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

                for epoch in range(epochs):
                    for batch_idx, (x, y) in enumerate(data_loader):
                        optimizer.zero_grad()
                        y_pred = model(x.to(torch.float32))
                        loss = loss_fn(y_pred, y.to(torch.float32))
                        loss.backward()
                        optimizer.step()
                        log.debug(f"page: {page}, epoch: {epoch}, batch_idx: {batch_idx}, loss: {loss.item()}")
                    progress.update(task_id, advance=1)

                # Do not request a page past the end of the data.
                if page + 1 < num_pages:
                    await dataset.load_next_page()

        log.info("Training loop without experiment tracking completed.")
        # Saving the model is left open here; train_model_with_mlflow_tracking shows one way to do it.

    async def train_model_with_mlflow_tracking(self, dataset: PageDataset, model: DemonstratorNeuralNet,
                                               loss_fn: nn.MSELoss, optimizer: optim.Adam, epochs: int = 5,
                                               batch_size: int = 5, shuffle: bool = True):
        """
        Train the model with mlflow tracking, then log it and register it in the model registry.

        Hyperparameters are logged as run params and per-batch loss as metrics. The trained
        model is logged under the ``model`` artifact path and registered as
        ``REGISTERED_MODEL_NAME``.

        Parameters
        ----------
        dataset
            The paged dataset; ``init_dataset()`` is awaited here.
        model
            The model to train in place.
        loss_fn
            The loss function.
        optimizer
            The optimizer over ``model``'s parameters.
        epochs
            The number of epochs per page.
        batch_size
            The batch size of the data loader.
        shuffle
            Whether to shuffle rows within a page.
        """
        log.info("Starting training loop with mlflow tracking.")
        await dataset.init_dataset()
        num_pages = dataset.num_pages
        progress = Progress()
        # One progress step per (page, epoch) pair.
        task_id = progress.add_task("[cyan]Training", total=num_pages * epochs)

        with mlflow.start_run() as run:
            # Record this run's hyperparameters. Class names keep the param values short.
            mlflow.log_params({
                "epochs": epochs,
                "batch_size": batch_size,
                "shuffle": shuffle,
                "optimizer": type(optimizer).__name__,
                "loss_function": type(loss_fn).__name__,
            })

            with progress:
                optimizer_step = 0
                for page in range(num_pages):
                    # The dataset only exposes the currently loaded page, so build a loader per page.
                    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

                    for epoch in range(epochs):
                        for batch_idx, (x, y) in enumerate(data_loader):
                            optimizer.zero_grad()
                            y_pred = model(x.to(torch.float32))
                            loss = loss_fn(y_pred, y.to(torch.float32))
                            loss.backward()
                            optimizer.step()
                            # Log metrics with mlflow
                            metrics = {
                                "loss": loss.item(),
                                "epoch": epoch,
                                "page": page,
                                "optimizer_step": optimizer_step
                            }
                            mlflow.log_metrics(metrics, step=optimizer_step)
                            log.debug(f"page: {page}, epoch: {epoch}, batch_idx: {batch_idx}, loss: {loss.item()}")
                            optimizer_step += 1
                        progress.update(task_id, advance=1)

                    # Do not request a page past the end of the data.
                    if page + 1 < num_pages:
                        await dataset.load_next_page()

            log.info("Training loop with mlflow tracking completed.")

            mlflow.pytorch.log_model(pytorch_model=model, artifact_path="model")

            model_uri = f"runs:/{run.info.run_id}/model"
            model_details = mlflow.register_model(model_uri=model_uri, name=self.REGISTERED_MODEL_NAME)
            log.info(f"registered model in mlflow model registry. Details: \n {pprint.pformat(dict(model_details))}")
468
469
if __name__ == "__main__":
    # Verbose logging for local runs; drop this call to let FastIoT's
    # `FASTIOT_LOG_LEVEL` environment variable configure logging instead.
    logging.basicConfig(level="DEBUG")
    MlPytorchRegressionMlflowService.main()