Reputation: 3
```
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# Define the deep neural network model
class DNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(DNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out
# Load the breast cancer dataset
data = load_breast_cancer()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define the number of training rounds and the number of clients
num_rounds = 100
num_clients = 2
batch_size = 10
# Split the data into equal chunks for each client
X_splits = np.array_split(X_train, num_clients)
y_splits = np.array_split(y_train, num_clients)
# Define the loss function and optimizer
criterion = nn.BCELoss()
# Perform federated learning
global_model = DNN(X_train.shape[1], 16, 1)
optimizer = optim.SGD(global_model.parameters(), lr=.01)
for i in range(num_rounds):
    local_models = []
    for j in range(num_clients):
        # Create a local model by copying the current global model
        local_model = DNN(X_train.shape[1], 16, 1)
        local_model.load_state_dict(global_model.state_dict())
        # Create a dataloader for the local client's data
        local_X = torch.tensor(X_splits[j], dtype=torch.float32)
        local_y = torch.tensor(y_splits[j], dtype=torch.float32)
        local_dataset = torch.utils.data.TensorDataset(local_X, local_y)
        local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=True)
        # Train the local model
        local_optimizer = optim.SGD(local_model.parameters(), lr=0.1)
        for inputs, labels in local_dataloader:
            local_optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels.view(-1, 1))
            loss.backward()
            local_optimizer.step()
        # Add the trained local model to the list of local models
        local_models.append(local_model)
    # Aggregate the local models to create a global model
    with torch.no_grad():
        for global_param, local_params in zip(global_model.parameters(), zip(*[local_model.parameters() for local_model in local_models])):
            global_param.data += torch.stack(local_params).sum(0) / num_clients
    # Evaluate the global model on the train dataset
    global_model.eval()
    with torch.no_grad():
        global_outputs = global_model(torch.tensor(X_train, dtype=torch.float32))
        global_loss = criterion(global_outputs, torch.tensor(y_train, dtype=torch.float32).view(-1, 1))
        global_pred = (global_outputs > 0.5).int().numpy().flatten()
        accuracy = np.mean(global_pred == y_train)
    print(f"Round {i}, train accuracy: {accuracy}")
```
The code works perfectly up to num_rounds=96, but when num_rounds is greater than or equal to 97, it shows an error:
```
RuntimeError                              Traceback (most recent call last)
in <cell line: 47>()
     79     with torch.no_grad():
     80         global_outputs = global_model(torch.tensor(X_train, dtype=torch.float32))
---> 81         global_loss = criterion(global_outputs, torch.tensor(y_train, dtype=torch.float32).view(-1, 1))
     82         global_pred = (global_outputs > 0.5).int().numpy().flatten()
     83         accuracy = np.mean(global_pred == y_train)

2 frames
/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   3093         weight = weight.expand(new_size)
   3094 
-> 3095     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
   3096 
   3097 

RuntimeError: all elements of input should be between 0 and 1
```
Upvotes: 0
Views: 198
Reputation: 3958
It seems like your dataloader is not returning labels that fall in the desired range in all cases (it is usually safe to assume that the outputs of the sigmoid activation function fall in this range, though you could double-check that as well). I recommend checking conformance with an assertion:
```
for inputs, labels in local_dataloader:
    ...
    assert labels.max() <= 1 and labels.min() >= 0, "Labels violate assumed range"
    assert outputs.max() <= 1 and outputs.min() >= 0, "Outputs violate assumed range"
    loss = criterion(outputs, labels.view(-1, 1))
    ...
```
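If neither assertion fires inside the client loop, the same kind of check can be applied right before the global BCELoss call that the traceback points to. Here is a minimal sketch under that assumption, reusing global_model, X_train, y_train and criterion from the question; the torch.isfinite guard is added because NaN/Inf predictions also fail BCELoss's "between 0 and 1" requirement:

```
# Range check before the global loss computation (uses the question's
# global_model, X_train, y_train and criterion).
with torch.no_grad():
    global_outputs = global_model(torch.tensor(X_train, dtype=torch.float32))
    global_targets = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)

    # NaN/Inf predictions also raise "all elements of input should be between 0 and 1"
    assert torch.isfinite(global_outputs).all(), "Global outputs contain NaN/Inf"
    assert global_outputs.max() <= 1 and global_outputs.min() >= 0, "Global outputs violate assumed range"

    global_loss = criterion(global_outputs, global_targets)
```

If the assertion on the global outputs is the one that fails, that narrows the problem down to the aggregated global model rather than the clients' data.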
Upvotes: 0