Reputation: 37
I implemented an algorithm in a Decentralized Federated Learning (DFL) environment. When I experimented with MNIST and Fashion-MNIST, I achieved an accuracy of 80–90%. However, when testing with CIFAR-10, the accuracy dropped to ~60%, whereas I expect 70–80%. I am evaluating using the entire test set of the dataset.
I have carefully prepared the data and cannot identify any issues. Is there anything specific I should consider when working with CIFAR-10 in a federated learning setting?
def get_data(args):
    """Load the raw train/test arrays for the dataset named by ``args.dataset``.

    No pixel scaling or normalization happens here; images are returned with
    the values they have on disk and are expected to be transformed later
    (e.g. by the batch iterator).

    Args:
        args: Object exposing ``dataset`` (one of ``'mnist'``,
            ``'fashion-mnist'``, ``'cifar10'``) and ``data_path`` (directory
            that contains ``<dataset>.npz`` for the MNIST variants, or the
            ``<dataset>/`` folder used as the torchvision root for CIFAR-10).

    Returns:
        Tuple ``(train_X, train_y, test_X, test_y)``. Images have the channel
        on axis 1; labels are ``int64`` arrays for every dataset.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    if args.dataset in ('mnist', 'fashion-mnist'):
        data_file = f"{args.data_path}/{args.dataset}.npz"
        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once the arrays have been read into memory.
        with np.load(data_file) as dataset:  # load the data
            train_X = dataset['x_train']
            train_y = dataset['y_train'].astype(np.int64)
            test_X = dataset['x_test']
            test_y = dataset['y_test'].astype(np.int64)
        if args.dataset == 'fashion-mnist':
            # reshape (rather than expand_dims) so that both flattened and
            # 28x28 image layouts end up as (N, 1, 28, 28).
            train_X = np.reshape(train_X, (-1, 1, 28, 28))
            test_X = np.reshape(test_X, (-1, 1, 28, 28))
        else:
            # Insert a singleton channel axis at position 1.
            train_X = np.expand_dims(train_X, 1)
            test_X = np.expand_dims(test_X, 1)
    elif args.dataset == 'cifar10':
        # Only load data, transformation done later.
        # NOTE: download is not requested, so the dataset must already be
        # present under the root directory.
        cifar_root = f"{args.data_path}/{args.dataset}/"
        trainset = torchvision.datasets.CIFAR10(root=cifar_root, train=True)
        testset = torchvision.datasets.CIFAR10(root=cifar_root, train=False)
        # Move the last axis (channels, per torchvision's layout) to axis 1.
        train_X = trainset.data.transpose([0, 3, 1, 2])
        test_X = testset.data.transpose([0, 3, 1, 2])
        # Force int64 so labels have the same dtype as in the MNIST branch
        # regardless of the platform's default integer width.
        train_y = np.array(trainset.targets, dtype=np.int64)
        test_y = np.array(testset.targets, dtype=np.int64)
    else:
        raise ValueError(f"Unknown dataset: {args.dataset!r}")
    return train_X, train_y, test_X, test_y
def data_loader(dataset, inputs, targets, batch_size, is_train=True):
    """Yield ``(batch_inputs, batch_targets)`` pairs for one pass over the data.

    Every yielded input batch is converted to ``float32`` and divided by 255;
    for ``dataset == 'cifar10'`` it is additionally shifted and scaled by the
    module-level ``CIFAR10_TRAIN_MEAN`` / ``CIFAR10_TRAIN_STD``.

    Training mode draws each batch by including every example independently
    with probability ``batch_size / n_examples``, so batch sizes vary around
    ``batch_size`` and an example may appear in several batches or in none.
    Evaluation mode walks the data in order in consecutive slices of
    ``batch_size``, ending with a shorter slice when the sizes do not divide.

    This is a generator, so argument validation runs on the first ``next()``.

    Args:
        dataset: Dataset name; only ``'cifar10'`` enables mean/std
            normalization.
        inputs: Array of examples, indexed along axis 0.
        targets: Array of labels with the same length as ``inputs``.
        batch_size: Target (training) or exact (evaluation) batch size;
            must be positive.
        is_train: Selects random sampling (``True``) or sequential slicing
            (``False``).

    Yields:
        Tuples ``(normalized_inputs, targets)`` for each batch.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    def cifar10_norm(x):
        # In-place is safe: x is always the fresh float32 array built in
        # prepare(), never a view of the caller's data.
        x -= CIFAR10_TRAIN_MEAN
        x /= CIFAR10_TRAIN_STD
        return x

    def no_norm(x):
        return x

    norm_func = cifar10_norm if dataset == 'cifar10' else no_norm

    def prepare(batch):
        # Scale pixel values to the 0-1 range, then apply dataset statistics.
        return norm_func(batch.astype(np.float32) / 255.)

    assert inputs.shape[0] == targets.shape[0]
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    n_examples = inputs.shape[0]
    if n_examples == 0:
        # Nothing to iterate; also avoids dividing by zero below.
        return

    if is_train:
        sample_rate = batch_size / n_examples
        # At least one step, so a shard smaller than batch_size (common for
        # small per-client splits) still produces a batch instead of an
        # empty epoch.
        num_blocks = max(1, int(n_examples // batch_size))
        for _ in range(num_blocks):
            mask = np.random.rand(n_examples) < sample_rate
            # Skip the rare draw that selects no example at all.
            if mask.any():
                yield prepare(inputs[mask]), targets[mask]
        return

    # Evaluation: consecutive slices; the final one holds the remainder.
    for start in range(0, n_examples, batch_size):
        stop = start + batch_size
        yield prepare(inputs[start:stop]), targets[start:stop]
Upvotes: 0
Views: 29