I'm using a combination of PyTorch Forecasting and PyTorch Lightning and am running into an odd error. The relevant code is below.
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=8)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=8)
# ...
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.05,
    hidden_size=16,  # has the biggest influence on network size
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  # QuantileLoss has 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=10,  # log an example every 10 batches
    reduce_on_plateau_patience=4,  # reduce the learning rate automatically on plateau
)
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)
However, trainer.fit then fails with the error below, and I can't figure out why. I tried changing how val_dataloader is passed to fit, but couldn't get anything to work. What is causing this error, and how do I fix it?
Traceback (most recent call last):
  File "/model.py", line 136, in <module>
    val_dataloaders=val_dataloader,
  File "C:\...\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 553, in fit
    self._run(model)
  File "C:\...\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 912, in _run
    self._pre_dispatch()
  File "C:\...\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 941, in _pre_dispatch
    self._log_hyperparams()
  File "C:\...\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 970, in _log_hyperparams
    self.logger.save()
  File "C:\...\venv\lib\site-packages\pytorch_lightning\utilities\distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "C:\...\venv\lib\site-packages\pytorch_lightning\loggers\tensorboard.py", line 249, in save
    save_hparams_to_yaml(hparams_file, self.hparams)
  File "C:\...\venv\lib\site-packages\pytorch_lightning\core\saving.py", line 405, in save_hparams_to_yaml
    yaml.dump(v)
  File "C:\...\venv\lib\site-packages\yaml\__init__.py", line 290, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "C:\...\venv\lib\site-packages\yaml\__init__.py", line 278, in dump_all
    dumper.represent(data)
  File "C:\...\lib\site-packages\yaml\representer.py", line 27, in represent
    node = self.represent_data(data)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 343, in represent_object
    'tag:yaml.org,2002:python/object:'+function_name, state)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 343, in represent_object
    'tag:yaml.org,2002:python/object:'+function_name, state)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 346, in represent_object
    return self.represent_sequence(tag+function_name, args)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 286, in represent_tuple
    return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "C:\...\venv\lib\site-packages\yaml\representer.py", line 331, in represent_object
    if function.__name__ == '__newobj__':
AttributeError: 'functools.partial' object has no attribute '__name__'
Process finished with exit code 1
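The last frames show where the failure happens: PyYAML's generic `represent_object` serializes an arbitrary object through its `__reduce_ex__(2)` result and then reads `function.__name__` on the returned callable, which a `functools.partial` does not have. The sketch below reproduces the same AttributeError with plain PyYAML (assumed to be installed). `PartialReduced` and `rebuild` are hypothetical stand-ins for whichever object in the hyperparameters returned a partial; they are not pandas or Lightning code.

```python
import functools

import yaml  # PyYAML, the library at the bottom of the traceback


def rebuild(value):
    """Hypothetical reconstructor, used only for this demonstration."""
    return PartialReduced(value)


class PartialReduced:
    """Hypothetical class whose pickling data uses a functools.partial as
    the reconstructor, mimicking the value that broke the hparams dump."""

    def __init__(self, value):
        self.value = value

    def __reduce_ex__(self, protocol):
        # (callable, args): the callable is a partial, which has no __name__
        return (functools.partial(rebuild), (self.value,))


def try_dump(obj):
    """Return the YAML text, or the exception yaml.dump raised instead."""
    try:
        return yaml.dump(obj)
    except Exception as exc:  # report any failure, not just AttributeError
        return exc


# The plain value dumps fine; the PartialReduced value triggers the error.
result = try_dump({"ok": 1, "bad": PartialReduced(3)})
print(type(result).__name__, result)
```

The nesting differs from the real traceback (there the partial sits a few mappings deeper), but the failing check in `represent_object` is the same one.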
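This turned out to be caused by a recent pandas upgrade. The TensorBoard logger saves the model's hyperparameters to YAML before training starts, and with the newer pandas, PyYAML fails while serializing one of those values. Rolling back to pandas 1.2.5 resolved the issue: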
pip install --upgrade pandas==1.2.5
Details on the problem are in this pandas issue:
https://github.com/pandas-dev/pandas/issues/42748
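Before downgrading, it can help to confirm which hyperparameter is the culprit. The traceback shows `save_hparams_to_yaml` calling `yaml.dump(v)` on each value, so a minimal sketch, assuming PyYAML is installed, is to dump each value the same way and collect the failures. `find_unyamlable` is a hypothetical helper, not part of Lightning or PyTorch Forecasting; the `threading.Lock` entry is just an unrelated value that PyYAML cannot represent, used to exercise the helper.

```python
import threading

import yaml  # PyYAML


def find_unyamlable(hparams):
    """Hypothetical helper: return {name: exception} for every value that
    yaml.dump cannot serialize, dumping each value on its own the way the
    save_hparams_to_yaml frame in the traceback does."""
    failures = {}
    for name, value in hparams.items():
        try:
            yaml.dump(value)
        except Exception as exc:  # the question's case raised AttributeError
            failures[name] = exc
    return failures


# Demo values: two plain numbers plus a lock, which cannot be pickled/dumped.
sample = {"learning_rate": 0.05, "hidden_size": 16, "lock": threading.Lock()}
failures = find_unyamlable(sample)
for name, exc in failures.items():
    print(f"{name}: {type(exc).__name__}: {exc}")
```

Against the model in the question this would be something like `find_unyamlable(dict(tft.hparams))` before calling `trainer.fit`, assuming `tft.hparams` holds the same values the TensorBoard logger later writes.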