The two functions, `load_state` and `load_from_mobilenet`, are utility functions for loading pre-trained weights into a PyTorch model. They match the parameters in the checkpoint file against the model's parameters by name and size, load the ones that are compatible, and print a warning for every model parameter that could not be loaded from the checkpoint.
load_state
def load_state(net, checkpoint):
    """Load the checkpoint parameters that fit ``net`` and keep the rest.

    Every entry of ``net.state_dict()`` is looked up by name in
    ``checkpoint['state_dict']``. An entry is taken from the checkpoint only
    when it exists there and its ``size()`` equals the model's; otherwise the
    model's current value is kept and a warning is printed. The assembled
    state is then passed to ``net.load_state_dict``.

    Args:
        net: Object exposing ``state_dict()`` and ``load_state_dict()``
            (e.g. a ``torch.nn.Module``).
        checkpoint: Mapping with a ``'state_dict'`` entry that maps parameter
            names to values exposing ``size()``.

    Raises:
        KeyError: If ``checkpoint`` is a dict without a ``'state_dict'`` key.
    """
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        if target_key not in source_state:
            # Nothing to load: keep the value the model already has.
            new_target_state[target_key] = target_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))
            continue

        source_value = source_state[target_key]
        if source_value.size() == target_value.size():
            new_target_state[target_key] = source_value
        else:
            # The parameter exists but cannot be used; say so explicitly
            # instead of reporting it as missing.
            new_target_state[target_key] = target_value
            print('[WARNING] Size mismatch for pre-trained parameters for {}: '
                  'checkpoint {} vs model {}'.format(
                      target_key, source_value.size(), target_value.size()))

    # new_target_state holds exactly the model's keys, in the model's order.
    net.load_state_dict(new_target_state)
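Purpose: loads the state dictionary (`state_dict`) from a checkpoint into the model (`net`).

- `source_state`: The state dictionary from the checkpoint.
- `target_state`: The state dictionary of the model (`net`).

For each parameter of the model (`target_key`):

- If the parameter exists in the checkpoint (`source_state`) and its size matches, it is loaded from the checkpoint.
- Otherwise, the model's current value is kept and a warning is printed.

Finally, the assembled state dictionary (`new_target_state`) is loaded into the model using `net.load_state_dict`.

load_from_mobilenet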
def load_from_mobilenet(net, checkpoint):
    """Load parameters into ``net`` from a checkpoint keyed with ``module.model``.

    Works like a name-and-size matched load, except that the checkpoint key
    is derived from the model key by replacing every occurrence of
    ``'model'`` with ``'module.model'``. Model keys that do not contain
    ``'model'`` are looked up unchanged. Entries that are absent from the
    checkpoint, or whose ``size()`` differs, keep the model's current value
    and a warning is printed.

    NOTE(review): the ``module.`` prefix is presumably the one added when a
    model is saved while wrapped in ``torch.nn.DataParallel`` -- confirm
    against the checkpoint actually used.

    Args:
        net: Object exposing ``state_dict()`` and ``load_state_dict()``
            (e.g. a ``torch.nn.Module``).
        checkpoint: Mapping with a ``'state_dict'`` entry that maps parameter
            names to values exposing ``size()``.

    Raises:
        KeyError: If ``checkpoint`` is a dict without a ``'state_dict'`` key.
    """
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        # str.replace is a no-op when 'model' is absent, so no guard is needed.
        source_key = target_key.replace('model', 'module.model')

        if source_key not in source_state:
            # Nothing to load: keep the value the model already has.
            new_target_state[target_key] = target_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))
            continue

        source_value = source_state[source_key]
        if source_value.size() == target_value.size():
            new_target_state[target_key] = source_value
        else:
            # The parameter exists but cannot be used; say so explicitly
            # instead of reporting it as missing.
            new_target_state[target_key] = target_value
            print('[WARNING] Size mismatch for pre-trained parameters for {}: '
                  'checkpoint {} vs model {}'.format(
                      target_key, source_value.size(), target_value.size()))

    # Keys of new_target_state are the model's own names, not the checkpoint's.
    net.load_state_dict(new_target_state)
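This function works like `load_state`, except that every occurrence of `model` in a model parameter name is replaced with `module.model` before the parameter is looked up in the checkpoint. This accounts for the `module.` prefix that parameter names carry when the checkpoint was saved from a model wrapped in `DataParallel`.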