Reputation: 5267
After noticing that my custom implementation of first-order MAML might be wrong, I decided to google the official way to do first-order MAML. I found a useful git issue that suggests stopping the tracking of higher-order gradients, which makes complete sense to me: no more derivatives over derivatives. But when I set it to false (so that no higher-order derivatives are tracked), my models stopped training and the .grad field was None, which is obviously wrong.
Is this a bug in higher or what is going on?
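(For reference, the symptom is easy to check for with a small helper like the one below; this is a generic sketch, not part of the original script.)
import torch.nn as nn

def report_missing_meta_grads(model: nn.Module) -> None:
    # After the outer-loss backward(), every meta-parameter should have a
    # populated .grad; with track_higher_grads=False they all stay None.
    missing = [name for name, p in model.named_parameters() if p.grad is None]
    print('parameters with .grad is None:', missing if missing else 'none')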
To reproduce, run the official MAML example that higher provides, slightly modified here. The main code is this:
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
import argparse
import time
import typing
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import higher
from support.omniglot_loaders import OmniglotNShot
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    # device = torch.device('cuda')
    # from uutils.torch_uu import get_device
    # device = get_device()
    gpu_idx = 0  # index of the GPU to use (if any)
    device = torch.device(f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu")
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way)).to(device)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        # plot(log)
def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False,
                # track_higher_grads=True,
                track_higher_grads=False,
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        assert meta_opt.param_groups[0]['params'][0].grad is not None
        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })
def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch + 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })
def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == '__main__':
    main()
Note:
I asked a similar question here, Would making the gradient "data" by detaching them implement first order MAML using PyTorch's higher library?, but that one is slightly different. It asks about a custom implementation that detaches the gradients directly to make them "data". This one asks why setting track_higher_grads=False
breaks the population of gradients -- which, as I understand it, it should not.
related:
Explain the reasoning for why the solution here works, i.e. why
track_higher_grads = True
...
diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads])
computes FO MAML, but:
new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
for p, index in zip(group['params'], mapping):
if self._track_higher_grads:
new_params[index] = p
else:
new_params[index] = p.detach().requires_grad_() # LIKELY THIS LINE!!!
does not allow FO to work properly and leaves .grad as None (i.e. it does not populate the grad field). The assignment with p.detach().requires_grad_()
honestly looks the same to me. The .requires_grad_()
even seems extra "safe".
Upvotes: 1
Views: 1352
Reputation: 8855
The reason track_higher_grads=False
doesn't work here is that it detaches the post-adaptation parameters themselves rather than just the gradients (see here). So you get no gradient at all from your outer-loop loss. What you really want is to detach only the inner-loop-computed gradients, but leave the (otherwise trivial) computation graph between the model initialization and the adapted parameters intact.
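To make the distinction concrete, here is a minimal plain-autograd sketch (an illustration of the mechanism only, not higher's actual code; w0 stands in for the meta-parameters):
import torch

# w0 plays the role of the model initialization (the meta-parameters).
w0 = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([0.5, -1.0])

def inner_loss(w):
    return ((w * x) ** 2).sum()

# Inner-loop gradient, kept differentiable as higher would.
g = torch.autograd.grad(inner_loss(w0), w0, create_graph=True)[0]

# (a) Detach only the inner gradient (first-order MAML): the adapted parameters
#     still depend on w0 through the update rule, so the outer loss reaches w0.
w_adapted = w0 - 0.1 * g.detach()
(w_adapted ** 2).sum().backward()
print(w0.grad)   # populated

# (b) Detach the adapted parameters themselves (what track_higher_grads=False
#     amounts to): the path back to w0 is severed, so w0.grad stays None.
w0.grad = None
w_detached = (w0 - 0.1 * g).detach().requires_grad_()
(w_detached ** 2).sum().backward()
print(w0.grad)   # None -- exactly the symptom in the question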
Upvotes: 1
Reputation: 5267
I think I found the solution, though it's hard to confirm with 100% confidence that it's correct since I don't fully understand it. I've done multiple sanity checks, and it does change the behavior of higher and the speed of the code, so I am assuming it does make FO MAML work:
track_higher_grads = True
diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads])
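For context, here is a self-contained sketch of the full pattern; the tiny linear model, random data, and the three inner steps are made up for illustration, while innerloop_ctx, copy_initial_weights, track_higher_grads and grad_callback are the same higher API used above:
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

net = nn.Linear(4, 2)
inner_opt = torch.optim.SGD(net.parameters(), lr=0.1)
meta_opt = torch.optim.Adam(net.parameters(), lr=1e-3)
x_spt, y_spt = torch.randn(8, 4), torch.randint(0, 2, (8,))
x_qry, y_qry = torch.randn(8, 4), torch.randint(0, 2, (8,))

meta_opt.zero_grad()
# Keep track_higher_grads=True so the initialization-to-adapted-parameters graph
# survives, and detach only the inner-loop gradients via grad_callback.
with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False,
                          track_higher_grads=True) as (fnet, diffopt):
    for _ in range(3):  # inner adaptation steps
        spt_loss = F.cross_entropy(fnet(x_spt), y_spt)
        diffopt.step(spt_loss,
                     grad_callback=lambda grads: [g.detach() for g in grads])
    qry_loss = F.cross_entropy(fnet(x_qry), y_qry)
    qry_loss.backward()

# Unlike with track_higher_grads=False, the meta-parameters now have gradients.
assert all(p.grad is not None for p in net.parameters())
meta_opt.step()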
Sanity checks:
Run with track_higher_grads = True but without Eric's grad_callback trick; the meta-gradient norm is:
1.111317753791809
The code is deterministic, so if I run it again it should print the same number:
1.1113194227218628
Close enough! 🙂 Now let's change the seed (from 0 to 42, 142, 1142); the grad norm value should change:
1.5447670221328735
1.1538511514663696
1.8301351070404053
Now returning to seed 0:
1.1113179922103882
Close enough again! 🙂
Now, if Eric's trick works (passing a grad_callback), then the gradient value should change, since it's now using FO and no higher-order info. So I will change my code in steps. First I will leave track_higher_grads = True and use the callback. This gives this gradient norm:
0.09500227868556976
Running it again (to confirm the code is deterministic) I get:
0.09500067681074142
confirming that this combination does something different (i.e. the grad_callback does change the behaviour).
Now, what if I use Eric's callback but set track_higher_grads=False?
AttributeError: 'NoneType' object has no attribute 'norm'
That gives an error: the .grad field is None, so calling .norm() on it fails. So it seems setting track_higher_grads=False is always wrong.
This makes me feel that your solution at least changes the behaviour, though I don't know why it works or why higher's original code doesn't.
--
Now I will check how fast the code runs by reading the output of tqdm. If it's truly doing FO (and not using higher-order grads), then there should be some speed-up. I am running this on my M1 laptop. The combination for the following run is track_higher_grads = True with diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads]), so this should be FO (the faster one) and should finish quicker than the next run, which uses higher-order grads/Hessians:
0.03890747204422951
100%|██████████| 100/100 [06:32<00:00, 3.92s/it, accuracy=0.5092]
Now with track_higher_grads = True and diffopt.step(inner_loss), which uses higher-order grads (the Hessian):
0.08946451544761658
100%|██████████| 100/100 [09:59<00:00, 6.00s/it, accuracy=0.9175]
Since it's taking much longer, I conclude this indeed uses Hessians and is NOT the FO MAML. I assume the difference would be more noticeable if the network were larger (due to the ~quadratic size of the Hessian).
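For intuition on where the speed gap comes from, here is a rough plain-autograd sketch (a toy net, not the Omniglot model, and not a careful benchmark): the second-order variant has to backpropagate through the inner gradient computation (create_graph=True), while the first-order variant detaches it:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Toy net and data, just to illustrate where the cost difference comes from.
net = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
x, y = torch.randn(64, 128), torch.randint(0, 10, (64,))

def adapted_query_loss(first_order: bool) -> torch.Tensor:
    params = list(net.parameters())
    # Inner gradient: create_graph=True keeps the graph needed for second-order
    # terms; for the first-order version we drop it and detach the gradients.
    grads = torch.autograd.grad(F.cross_entropy(net(x), y), params,
                                create_graph=not first_order)
    if first_order:
        grads = [g.detach() for g in grads]
    # One functional SGD step, then a "query" forward pass with the new params.
    w1, b1, w2, b2 = [p - 0.1 * g for p, g in zip(params, grads)]
    return F.cross_entropy(F.linear(F.relu(F.linear(x, w1, b1)), w2, b2), y)

for first_order in (True, False):
    t0 = time.time()
    for _ in range(200):
        net.zero_grad()
        adapted_query_loss(first_order).backward()
    print(f'first_order={first_order}: {time.time() - t0:.2f}s')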
reproducible code:
"""
For correctness see details here:
- SO: https://stackoverflow.com/questions/70961541/what-is-the-official-implementation-of-first-order-maml-using-the-higher-pytorch/74270560#74270560
- gitissue: https://github.com/facebookresearch/higher/issues/102
"""
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import logging
from collections import OrderedDict
import higher # tested with higher v0.2
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
logger = logging.getLogger(__name__)
def conv3x3(in_channels, out_channels, **kwargs):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )
class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self, in_channels, out_features, hidden_size=64):
        super(ConvolutionalNeuralNetwork, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        self.features = nn.Sequential(
            conv3x3(in_channels, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size)
        )

        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs, params=None):
        features = self.features(inputs)
        features = features.view((features.size(0), -1))
        logits = self.classifier(features)
        return logits
def get_accuracy(logits, targets):
    """Compute the accuracy (after adaptation) of MAML on the test/query points

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Outputs/logits of the model on the query points. This tensor has shape
        `(num_examples, num_classes)`.
    targets : `torch.LongTensor` instance
        A tensor containing the targets of the query points. This tensor has
        shape `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Mean accuracy on the query points
    """
    _, predictions = torch.max(logits, dim=-1)
    return torch.mean(predictions.eq(targets).float())
def train(args):
    logger.warning('This script is an example to showcase the data-loading '
                   'features of Torchmeta in conjunction with using higher to '
                   'make models "unrollable" and optimizers differentiable, '
                   'and as such has been very lightly tested.')
    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download,
                       )
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    inner_optimiser = torch.optim.SGD(model.parameters(), lr=args.step_size)
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    # understanding ETA: https://github.com/tqdm/tqdm/issues/40, 00:05<00:45 means 5 seconds have elapsed and a further (estimated) 45 remain. < is used as an ASCII arrow really rather than a less than sign.
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(zip(train_inputs, train_targets,
                                                         test_inputs, test_targets)):
                track_higher_grads = True
                # track_higher_grads = False
                with higher.innerloop_ctx(model, inner_optimiser, track_higher_grads=track_higher_grads, copy_initial_weights=False) as (fmodel, diffopt):
                    train_logit = fmodel(train_input)
                    inner_loss = F.cross_entropy(train_logit, train_target)

                    diffopt.step(inner_loss)
                    # diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads])

                    test_logit = fmodel(test_input)
                    outer_loss += F.cross_entropy(test_logit, test_target)
                    # inspired by https://github.com/facebookresearch/higher/blob/15a247ac06cac0d22601322677daff0dcfff062e/examples/maml-omniglot.py#L165
                    # outer_loss = F.cross_entropy(test_logit, test_target)
                    # outer_loss.backward()

                    with torch.no_grad():
                        accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            # print(list(model.parameters()))
            # print(f"{meta_optimizer.param_groups[0]['params'] is list(model.parameters())}")
            # print(f"{meta_optimizer.param_groups[0]['params'][0].grad is not None=}")
            # print(f"{meta_optimizer.param_groups[0]['params'][0].grad=}")
            print(f"{meta_optimizer.param_groups[0]['params'][0].grad.norm()}")
            assert meta_optimizer.param_groups[0]['params'][0].grad is not None
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
                                '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
if __name__ == '__main__':
    seed = 0
    import random
    import numpy as np
    import torch
    import os
    os.environ["PYTHONHASHSEED"] = str(seed)
    # - make pytorch deterministic
    # makes all ops deterministic no matter what. Note this throws an error if your code has an op that doesn't have a deterministic implementation
    torch.manual_seed(seed)
    # if always_use_deterministic_algorithms:
    torch.use_deterministic_algorithms(True)
    # makes convs deterministic
    torch.backends.cudnn.deterministic = True
    # doesn't allow benchmarking to select fastest algorithms for specific ops
    torch.backends.cudnn.benchmark = False
    # - make python deterministic
    np.random.seed(seed)
    random.seed(seed)

    import argparse

    parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')
    parser.add_argument('--folder', type=str, default=Path('~/data/torchmeta_data').expanduser(),
                        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--num-shots', type=int, default=5,
                        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
                        help='Number of classes per task (N in "N-way", default: 5).')
    parser.add_argument('--step-size', type=float, default=0.4,
                        help='Step-size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
                        help='Number of channels for each convolutional layer (default: 64).')
    parser.add_argument('--output-folder', type=str, default=None,
                        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
                        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
                        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_false',
                        help='Do not download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
                        help='Use CUDA if available.')
    args = parser.parse_args()

    args.device = torch.device('cuda' if args.use_cuda
                               and torch.cuda.is_available() else 'cpu')
    print(f'{args.device=}')
    train(args)
First-order MAML in my real script:
-> it=0: train_loss=4.249784290790558, train_acc=0.24000000208616257
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=0: val_loss=3.680968999862671, val_acc=0.2666666731238365
0% (0 of 70000) | | Elapsed Time: 0:00:00 | ETA: --:--:-- | 0.0 s/itmeta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=1: train_loss=4.253764450550079, train_acc=2.712197299694683
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=1: val_loss=3.5652921199798584, val_acc=0.36666667461395264
0% (1 of 70000) || Elapsed Time: 0:00:08 | ETA: 6 days, 18:55:28 | 0.1 it/smeta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
0% (2 of 70000) || Elapsed Time: 0:00:16 | ETA: 6 days, 18:56:48 | 0.1 it/ssys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=2: train_loss=4.480343401432037, train_acc=3.732449478260403
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=2: val_loss=3.6090375185012817, val_acc=0.19999999552965164
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
0% (3 of 70000) || Elapsed Time: 0:00:25 | ETA: 6 days, 18:46:19 | 0.1 it/ssys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=3: train_loss=2.822919726371765, train_acc=0.3426572134620805
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=3: val_loss=4.102218151092529, val_acc=0.30666667222976685
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f8b2a97b280>
0% (4 of 70000) || Elapsed Time: 0:00:33 | ETA: 6 days, 18:47:29 | 0.1 it/ssys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
Now NOT FO MAML (i.e. the higher-order version):
-> it=0: train_loss=4.590916454792023, train_acc=0.23333333432674408
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=0: val_loss=3.6842236518859863, val_acc=0.2666666731238365
0% (0 of 70000) | | Elapsed Time: 0:00:00 | ETA: --:--:-- | 0.0 s/itmeta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=1: train_loss=4.803018927574158, train_acc=2.596685569748149
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=1: val_loss=3.0977725982666016, val_acc=0.3199999928474426
0% (1 of 70000) || Elapsed Time: 0:00:16 | ETA: 13 days, 1:18:19 | 16.1 s/itmeta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
0% (2 of 70000) || Elapsed Time: 0:00:32 | ETA: 13 days, 1:09:53 | 16.1 s/itsys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=2: train_loss=4.257768213748932, train_acc=2.2006314379501504
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=2: val_loss=7.144366264343262, val_acc=0.30666665732860565
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
0% (3 of 70000) || Elapsed Time: 0:00:48 | ETA: 13 days, 1:00:01 | 16.1 s/itsys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=3: train_loss=4.1194663643836975, train_acc=1.929317718150093
sys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
-> it=3: val_loss=3.4890414476394653, val_acc=0.35333333164453506
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
meta_learner_forward_adapt_batch_of_tasks=<function meta_learner_forward_adapt_batch_of_tasks at 0x7f9fd80db280>
0% (4 of 70000) || Elapsed Time: 0:01:04 | ETA: 13 days, 0:46:34 | 16.1 s/itsys.stdout=<_io.TextIOWrapper name='<stdout>' mode='w' encoding='utf-8'>
os.path.realpath(sys.stdout.name)='/Users/brandomiranda/diversity-for-predictive-success-of-meta-learning/div_src/diversity_src/experiment_mains/<stdout>'
The FO run's ETA is 6 days while the higher-order one's is 13 days, so it's likely correct!
Upvotes: 0