nigroup/nideep

View on GitHub
nideep/nets/net_merge.py

Summary

Maintainability
D
1 day
Test Coverage
'''
Created on Dec 18, 2015

@author: kashefy
'''
from sets import Set
import numpy as np
from google.protobuf import text_format
from nideep.proto.proto_utils import Parser

def is_singular_layer_type(layer_type_name):
    """
    Tell whether a layer type name denotes a data (input) layer.

    The check is case-insensitive: the name must contain 'data' and must
    not contain 'output'. merge_indep_net_spec() uses this to decide which
    layers are left un-suffixed and kept only once in the merged network.
    """
    lowered = layer_type_name.lower()
    if 'data' not in lowered:
        return False
    return 'output' not in lowered

def suffix_fmt_idx(idx):
    """
    Build the default name suffix for the network at position idx.

    The index is zero-padded to at least two digits, e.g. 3 -> '_nidx_03'.
    """
    zero_padded = '%02d' % (idx,)
    return '_nidx_' + zero_padded

try:
    _text_type = unicode  # Python 2: keep assigning unicode to string fields
except NameError:
    _text_type = str  # Python 3: str is already the text type


def _suffix_blob_names(blobs, shared_names, suffix):
    """Append suffix, in place, to every blob name that is not in shared_names."""
    # Assigning by index keeps the blob order untouched, which matters for
    # layers whose bottoms are positional (e.g. the label has to come last).
    for pos, blob_name in enumerate(list(blobs)):
        if blob_name not in shared_names:
            blobs[pos] = _text_type(blob_name + suffix)


def merge_indep_net_spec(net_specs, suffix_fmt=suffix_fmt_idx):
    """
    Merge independent network specifications into one prototxt string.

    Every layer that is not a data layer (see is_singular_layer_type) gets
    a per-network suffix appended to its name and to each of its bottom/top
    blob names, except for blobs produced by a data layer, which keep their
    name so that all sub-networks read from the same inputs. Data layers
    are kept for the first network only and removed from all others.

    NOTE: the messages in net_specs are modified in place.
    NOTE: shared blob names are collected from the data layers of *all*
    networks, yet only the first network's data layers survive. A later
    network that relies on a data top the first one does not provide will
    end up with a bottom that has no producer.

    :param net_specs: list of network specification messages; each must
        expose a repeated ``layer`` field with name/type/bottom/top/param.
    :param suffix_fmt: callable mapping the network's index in net_specs
        to the suffix string to append.
    :return: text-format serialization of all networks concatenated. A
        leading ``name:`` line is dropped for all but the first network.
    """
    # Tops of data layers are shared across sub-networks and never suffixed.
    shared_tops = set()
    for net in net_specs:
        for layer in net.layer:
            if is_singular_layer_type(layer.type):
                shared_tops.update(layer.top)

    for idx, net in enumerate(net_specs):
        suffix = suffix_fmt(idx)
        redundant = []
        for layer in net.layer:
            if is_singular_layer_type(layer.type):
                if idx > 0:
                    redundant.append(layer)
                continue

            layer.name += suffix
            # A zero learning-rate multiplier means the parameter is frozen.
            if any(p.lr_mult == 0 for p in layer.param):
                print("LAYER WITH FIXED WEIGHTS. MAKE SHARED?")

            _suffix_blob_names(layer.bottom, shared_tops, suffix)
            _suffix_blob_names(layer.top, shared_tops, suffix)

        # Removal is deferred so net.layer is not modified while iterating.
        for layer in redundant:
            net.layer.remove(layer)  # Data layers of first net only

    parts = []
    for idx, net in enumerate(net_specs):
        text = text_format.MessageToString(net)
        if idx > 0 and text.startswith("name:"):
            # Drop the leading name line of all but the first network.
            text = text.partition('\n')[2]
        parts.append(text)

    return ''.join(parts)

if __name__ == '__main__':

    # Hard-coded locations of the two source networks and the merged result.
    fpath_net_1 = '/home/kashefy/models/dark/mnist/tx/tx1.prototxt'
    fpath_net_2 = '/home/kashefy/models/dark/mnist/tx/tx2.prototxt'
    fpath_dst = "/home/kashefy/models/dark/mnist/tx/m.prototxt"

    # Parse each source file with its own Parser instance.
    parsed_nets = []
    for fpath_src in (fpath_net_1, fpath_net_2):
        parsed_nets.append(Parser().from_net_params_file(fpath_src))
    merged_proto = merge_indep_net_spec(parsed_nets)

    with open(fpath_dst, 'w') as dst_file:
        dst_file.write(merged_proto)

    # Load the merged definition with caffe in TRAIN and then TEST phase;
    # presumably a sanity check that the merged prototxt is a valid net.
    import caffe
    for phase in (caffe.TRAIN, caffe.TEST):
        n = caffe.Net(fpath_dst, phase)