I want to extract all the logged data so I can make the plots myself, without TensorBoard. My understanding is that all the logs with loss and accuracy are stored in a defined directory, since TensorBoard draws its line graphs from them:
```
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/
```
However, I wonder how all the logs can be extracted from the logger in PyTorch Lightning. Here is the code for the training part:
```python
# model
ssl_classifier = SSLImageClassifier(lr=lr)

# train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')
trainer = pl.Trainer(progress_bar_refresh_rate=20,
                     gpus=1,
                     max_epochs=max_epoch,
                     logger=logger,
                     )
trainer.fit(ssl_classifier, train_loader, val_loader)
```
I confirmed that trainer.logger.log_dir returned the directory where the logs seem to be saved, and that trainer.logger.log_metrics returned <bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>.
trainer.logged_metrics returned only the logs of the final epoch, like:
```
{'epoch': 19,
 'train_acc': tensor(1.),
 'train_loss': tensor(0.1038),
 'val_acc': 0.6499999761581421,
 'val_loss': 1.2171183824539185}
```
How can I get the logged metrics for every epoch?
Lightning does not store all logs by itself. All it does is stream them into the logger instance, and the logger decides what to do with them.
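To make this division of labour concrete, here is a plain-Python sketch of the pattern. It is not Lightning's actual code: InMemoryLogger and run_fake_training are hypothetical stand-ins and the loss values are made up. The loop keeps only the latest value per key (much like what trainer.logged_metrics showed in the question), while the logger receives every call and keeps all of them.

```python
from typing import Any, Dict, List


class InMemoryLogger:
    """Hypothetical logger that keeps every metrics dict it receives."""

    def __init__(self) -> None:
        self.records: List[Dict[str, Any]] = []

    def log_metrics(self, metrics: Dict[str, Any], step: int) -> None:
        # Store a new dict (with the step) so nothing is overwritten later.
        self.records.append({"step": step, **metrics})


def run_fake_training(logger: InMemoryLogger, losses: List[float]) -> Dict[str, Any]:
    """Stand-in for a training loop that streams metrics to the logger."""
    latest: Dict[str, Any] = {}  # keeps only the most recent value per key
    for epoch, loss in enumerate(losses):
        metrics = {"epoch": epoch, "train_loss": loss}
        latest.update(metrics)  # values from earlier epochs are replaced
        logger.log_metrics(metrics, step=epoch)  # full history lives in the logger
    return latest


logger = InMemoryLogger()
latest = run_fake_training(logger, [0.9, 0.5, 0.1])
print(latest)               # {'epoch': 2, 'train_loss': 0.1}
print(len(logger.records))  # 3
```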
The best way to retrieve all logged metrics is by having a custom callback:
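```python
from pytorch_lightning.callbacks import Callback


class MetricTracker(Callback):
    def __init__(self):
        self.collection = []

    def on_validation_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx=0):
        # outputs is whatever validation_step returned; assumes a dict with 'val_acc'
        vacc = outputs['val_acc']  # you can access them here
        self.collection.append(vacc)  # track them

    def on_validation_epoch_end(self, trainer, module):
        # copy the dict so this snapshot stays fixed even if the same object is updated later
        elogs = dict(trainer.logged_metrics)  # access it here
        self.collection.append(elogs)
        # do whatever is needed
```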
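You can then access everything that was logged from the callback instance:
```python
from pytorch_lightning import Trainer

cb = MetricTracker()
trainer = Trainer(callbacks=[cb])
trainer.fit(ssl_classifier, train_loader, val_loader)
cb.collection  # do your plotting and stuff
```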
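After training, cb.collection holds a mix of per-batch values and per-epoch dicts. A minimal sketch of turning the per-epoch dicts into one list per metric, ready to plot, could look like the following; collection_to_series is a hypothetical helper and the sample values are invented.

```python
from collections import defaultdict
from typing import Any, Dict, List


def collection_to_series(collection: List[Any]) -> Dict[str, List[float]]:
    """Turn per-epoch metric dicts into one list per metric name.

    Non-dict entries (such as the per-batch values appended in
    on_validation_batch_end) are skipped.
    """
    series: Dict[str, List[float]] = defaultdict(list)
    for entry in collection:
        if not isinstance(entry, dict):
            continue
        for name, value in entry.items():
            # float() works for Python numbers and for single-element torch tensors
            series[name].append(float(value))
    return dict(series)


# Made-up values shaped like snapshots of trainer.logged_metrics
collection = [
    0.55,  # a per-batch value, ignored
    {"epoch": 0, "val_loss": 1.4, "val_acc": 0.5},
    {"epoch": 1, "val_loss": 1.2, "val_acc": 0.65},
]
series = collection_to_series(collection)
print(series["val_loss"])  # [1.4, 1.2]
```

The lists can be passed directly to a plotting library such as matplotlib. A metric only gets an entry for epochs in which it was logged, so lists for different metrics may differ in length.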
The accepted answer is not fundamentally wrong, but it does not follow the official (current) guidelines of PyTorch Lightning.
As described here: https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#make-a-custom-logger
the suggested approach is to write a class like:
```python
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment


class MyLogger(LightningLoggerBase):
    @property
    def name(self):
        return "MyLogger"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics)
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass
```
Looking inside the class LightningLoggerBase, one can see which other functions could be overridden.
Here is a minimal logger of mine. It is not optimised at all, but it is a good first attempt. I will edit this answer if I improve it.
```python
import collections

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


class History_dict(LightningLoggerBase):
    def __init__(self):
        super().__init__()
        # defaultdict creates an empty list for any key accessed for the first time
        self.history = collections.defaultdict(list)

    @property
    def name(self):
        return "Logger_custom_plot"

    @property
    def version(self):
        return "1.0"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        for metric_name, metric_value in metrics.items():
            if metric_name != 'epoch':
                self.history[metric_name].append(metric_value)
            # 'epoch' arrives with every call (e.g. once per logged loss), so only
            # append it when it differs from the last recorded epoch
            elif not self.history['epoch'] or self.history['epoch'][-1] != metric_value:
                self.history['epoch'].append(metric_value)

    def log_hyperparams(self, params):
        pass
```
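The epoch bookkeeping in log_metrics is easiest to see without Lightning in the loop. The sketch below copies that logic into a hypothetical update_history function and feeds it made-up metric dicts, assuming each epoch produces several log_metrics calls that all carry the same 'epoch' value.

```python
import collections
from typing import Any, Dict, List


def update_history(history: Dict[str, List[Any]], metrics: Dict[str, Any]) -> None:
    """Same bookkeeping as History_dict.log_metrics, without Lightning."""
    for name, value in metrics.items():
        if name != "epoch":
            history[name].append(value)
        elif not history["epoch"] or history["epoch"][-1] != value:
            # 'epoch' repeats across calls within one epoch; record it once
            history["epoch"].append(value)


history = collections.defaultdict(list)
# Made-up calls: two log_metrics calls per epoch (train, then validation)
calls = [
    {"train_loss": 0.9, "epoch": 0},
    {"val_loss": 1.3, "epoch": 0},
    {"train_loss": 0.4, "epoch": 1},
    {"val_loss": 1.2, "epoch": 1},
]
for metrics in calls:
    update_history(history, metrics)

print(history["epoch"])     # [0, 1]
print(history["val_loss"])  # [1.3, 1.2]
```

With the real class, passing logger=History_dict() to the Trainer and calling trainer.fit(...) would leave the same kind of lists in logger.history. Since the epoch is stored once per epoch, history['epoch'] can be shorter than a metric that is logged several times per epoch.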