The two functions, load_state and load_from_mobilenet, are utilities for loading pre-trained weights into a PyTorch model. Each one matches the parameters in the checkpoint's state_dict against the model's own parameters by name and shape, copies the ones that match, keeps the model's current value for the rest, and prints a warning for every parameter it could not load.


1. load_state

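import collections

def load_state(net, checkpoint):
    source_state = checkpoint['state_dict']  # weights stored in the checkpoint
    target_state = net.state_dict()          # weights currently in the model
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        # Use the checkpoint tensor only if the name exists and the shape matches.
        if target_key in source_state and source_state[target_key].size() == target_value.size():
            new_target_state[target_key] = source_state[target_key]
        else:
            # Otherwise keep the model's current value and report the parameter.
            new_target_state[target_key] = target_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    # new_target_state contains every model key, so the load has no missing keys.
    net.load_state_dict(new_target_state)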

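Purpose: Load the pre-trained weights in checkpoint['state_dict'] into net, copying only the parameters whose name and shape both match and leaving every other parameter at its current value.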

How It Works:

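  1. Extract State Dictionaries: source_state is read from checkpoint['state_dict'], and target_state from net.state_dict().
  2. Iterate Over Model Parameters: for each key in target_state, the checkpoint tensor is used if the key exists in source_state and both tensors have the same size. Otherwise the model's current tensor is kept and a warning is printed. The warning says "Not found" for shape mismatches as well, and checkpoint keys that the model lacks are ignored without a message.
  3. Load the Updated State Dictionary: new_target_state holds exactly the model's keys, so net.load_state_dict accepts it without missing or unexpected keys.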

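Use Case: Initializing a model from a checkpoint that only partly matches it, for example when some layers were added or resized after the checkpoint was saved.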
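
A minimal usage sketch, assuming PyTorch is installed. The two nn.Sequential models are hypothetical toy networks, not part of the original code: they share a first layer but differ in output width, so the first layer should be loaded while the last layer is kept and reported. load_state is repeated here only so the sketch runs on its own.

```python
import collections

import torch
from torch import nn


def load_state(net, checkpoint):
    # Same logic as the function above, repeated so the sketch is self-contained.
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        if target_key in source_state and source_state[target_key].size() == target_value.size():
            new_target_state[target_key] = source_state[target_key]
        else:
            new_target_state[target_key] = target_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))
    net.load_state_dict(new_target_state)


# Hypothetical toy models: identical first layer, different output width.
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))
head_before = net[2].weight.detach().clone()

checkpoint = {'state_dict': pretrained.state_dict()}
load_state(net, checkpoint)  # warns for 2.weight and 2.bias (shape mismatch)

# The matching layer now carries the checkpoint weights; the resized head is untouched.
first_layer_loaded = torch.equal(net[0].weight, pretrained[0].weight)
head_unchanged = torch.equal(net[2].weight, head_before)
```

Because the mismatched layer keeps its existing values, it still has to be trained after loading.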


2. load_from_mobilenet

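import collections

def load_from_mobilenet(net, checkpoint):
    source_state = checkpoint['state_dict']  # weights stored in the checkpoint
    target_state = net.state_dict()          # weights currently in the model
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        # Map the model's key to the name used in the checkpoint:
        # every occurrence of 'model' becomes 'module.model'.
        k = target_key
        if 'model' in k:
            k = k.replace('model', 'module.model')
        # Use the checkpoint tensor only if the mapped name exists and the shape matches.
        if k in source_state and source_state[k].size() == target_value.size():
            new_target_state[target_key] = source_state[k]
        else:
            # Otherwise keep the model's current value and report the parameter.
            new_target_state[target_key] = target_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    net.load_state_dict(new_target_state)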

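Purpose: Same as load_state, except that any model key containing 'model' is looked up in the checkpoint with 'model' replaced by 'module.model'. A 'module.' prefix is what torch.nn.DataParallel adds to parameter names, so this presumably lets a MobileNet checkpoint saved from a wrapped model load into an unwrapped net.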
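
The key renaming can be looked at in isolation. The sketch below uses a hypothetical helper, to_checkpoint_key, that mirrors the rename step inside load_from_mobilenet; the key names are illustrative and are not taken from a real checkpoint.

```python
def to_checkpoint_key(target_key):
    # Hypothetical helper mirroring the rename inside load_from_mobilenet.
    # str.replace rewrites every occurrence of the substring, not just a prefix.
    if 'model' in target_key:
        return target_key.replace('model', 'module.model')
    return target_key


# Illustrative key names only.
examples = ['model.0.0.weight', 'cpm.0.weight', 'submodel.conv.weight']
mapped = {key: to_checkpoint_key(key) for key in examples}
for key, value in mapped.items():
    print('{} -> {}'.format(key, value))
```

The third key shows that the substring match also fires in the middle of a name, so the rename is only safe when 'model' appears in keys exactly where the checkpoint has 'module.model'.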