Commit e825d842 authored by hsfzxjy

Fix bug

parent d7f6282c
@@ -19,7 +19,6 @@ from torch.nn import functional as F
 from utils.utils import AverageMeter
 from utils.utils import get_confusion_matrix
 from utils.utils import adjust_learning_rate
-from utils.utils import freeze_layers, open_all_layers
 import utils.distributed as dist
@@ -43,10 +42,6 @@ def train(config, epoch, num_epoch, epoch_iters, base_lr,
     # Training
     model.train()
-    if epoch <= config.TRAIN.FREEZE_EPOCHS:
-        freeze_layers(model, config.TRAIN.FREEZE_LAYERS)
-    else:
-        open_all_layers(model)
     batch_time = AverageMeter()
     ave_loss = AverageMeter()
     tic = time.time()
......
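
The two hunks above drop the layer-freezing warm-up: the first config.TRAIN.FREEZE_EPOCHS epochs no longer train with part of the network frozen. The removed helpers live in utils.utils and are not part of this diff; the following is only a minimal sketch of what such helpers commonly do (toggling requires_grad on named child modules), not the repository's actual implementation:

    import torch.nn as nn

    def freeze_layers(model: nn.Module, layer_names):
        # Put the named child modules in eval mode and stop their gradients,
        # leaving every other layer trainable.
        for name, module in model.named_children():
            if name in layer_names:
                module.eval()
                for p in module.parameters():
                    p.requires_grad = False

    def open_all_layers(model: nn.Module):
        # Re-enable training mode and gradients for the whole model.
        model.train()
        for p in model.parameters():
            p.requires_grad = True
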
@@ -164,6 +164,9 @@ def main():
         pin_memory=True,
         drop_last=True,
         sampler=extra_train_sampler)
+    extra_epoch_iters = np.int(extra_train_dataset.__len__() /
+                               config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
     test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
     test_dataset = eval('datasets.'+config.DATASET.DATASET)(
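
The new extra_epoch_iters gives the extra training set its own iterations-per-epoch count instead of reusing epoch_iters from the main set: dataset length divided by the per-GPU batch size and by the number of GPUs (np.int here is simply an alias for the built-in int). A worked example with made-up numbers, since the real values depend on the config and the dataset:

    # Illustrative numbers only.
    extra_samples = 20000        # would be len(extra_train_dataset)
    batch_size_per_gpu = 3       # config.TRAIN.BATCH_SIZE_PER_GPU (assumed)
    num_gpus = 4                 # len(gpus) (assumed)
    extra_epoch_iters = int(extra_samples / batch_size_per_gpu / num_gpus)
    print(extra_epoch_iters)     # 1666 iterations per pass over the extra set
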
@@ -250,7 +253,9 @@ def main():
         checkpoint = torch.load(model_state_file, map_location={'cuda:0': 'cpu'})
         best_mIoU = checkpoint['best_mIoU']
         last_epoch = checkpoint['epoch']
-        model.module.load_state_dict(checkpoint['state_dict'])
+        dct = checkpoint['state_dict']
+        model.module.model.load_state_dict({k.replace('model.', ''): v for k, v in checkpoint['state_dict'].items() if k.startswith('model.')})
         optimizer.load_state_dict(checkpoint['optimizer'])
         logger.info("=> loaded checkpoint (epoch {})"
                     .format(checkpoint['epoch']))
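
The reworked resume path loads only the checkpoint entries whose keys start with 'model.' into the wrapped network's inner model attribute, stripping that prefix from each key. A standalone sketch of the key-remapping pattern (the dictionary contents are illustrative, not taken from the repository):

    ckpt_state = {'model.conv1.weight': 'w0', 'model.bn1.bias': 'b0',
                  'aux_head.weight': 'w1'}          # made-up keys
    backbone_state = {k[len('model.'):]: v
                      for k, v in ckpt_state.items()
                      if k.startswith('model.')}
    assert backbone_state == {'conv1.weight': 'w0', 'bn1.bias': 'b0'}

The sketch slices the prefix off instead of calling k.replace('model.', ''); both give the same result as long as 'model.' appears only at the start of a key, but replace would also rewrite any later occurrence inside a key.
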
@@ -260,7 +265,7 @@
     start = timeit.default_timer()
     end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
     num_iters = config.TRAIN.END_EPOCH * epoch_iters
-    extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters
+    extra_iters = config.TRAIN.EXTRA_EPOCH * extra_epoch_iters
     for epoch in range(last_epoch, end_epoch):
@@ -273,7 +278,7 @@
         if epoch >= config.TRAIN.END_EPOCH:
             train(config, epoch-config.TRAIN.END_EPOCH,
-                  config.TRAIN.EXTRA_EPOCH, epoch_iters,
+                  config.TRAIN.EXTRA_EPOCH, extra_epoch_iters,
                   config.TRAIN.EXTRA_LR, extra_iters,
                   extra_trainloader, optimizer, model, writer_dict)
         else:
......
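
These last two hunks are the actual fix: the extra-training phase now sizes both its per-epoch loop and its learning-rate horizon from the extra dataset instead of the main one. A small before/after comparison with illustrative numbers (the real config values are not part of this diff):

    # Illustrative values only.
    EXTRA_EPOCH = 20             # config.TRAIN.EXTRA_EPOCH (assumed)
    epoch_iters = 743            # per-epoch steps on the main set (assumed)
    extra_epoch_iters = 1666     # per-epoch steps on the extra set (assumed)
    extra_iters_old = EXTRA_EPOCH * epoch_iters        # 14860 steps
    extra_iters_new = EXTRA_EPOCH * extra_epoch_iters  # 33320 steps
    # With the old value, adjust_learning_rate would treat the extra phase as
    # 14860 iterations long even though 33320 are actually run, so the decay
    # schedule would be exhausted well before extra training finished
    # (or, if the extra set were smaller, never reach its end point).
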