Commit 9a23aaa5 authored by brianhhu's avatar brianhhu

Only load state dict if pretrained=True

parent 5cc4b714
......@@ -117,10 +117,10 @@ def Model(pretrained=False, **kwargs):
# Load state dict
state_dict = torch.load('./resnet18/converted_pytorch.pt')
# Replace with keys from Pytorch model
for i, (key, val) in enumerate(list(state_dict.items())):
state_dict[list(model.state_dict())[i]] = val
del state_dict[key]
# Replace with keys from Pytorch model
for i, (key, val) in enumerate(list(state_dict.items())):
state_dict[list(model.state_dict())[i]] = val
del state_dict[key]
model.load_state_dict(state_dict)
model.load_state_dict(state_dict)
return model
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment