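# ixxi-dante/nw2vec: julia/an2vec-generative.jl
#
# Generate a synthetic attributed graph (stochastic block model plus node features
# correlated with its communities) and train an an2vec VAE on it, or profile its loss.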
include("utils.jl")
include("generative.jl")
include("vae.jl")
using .Utils
using .Generative
using .VAE

using LinearAlgebra
using Flux
import BSON
using ArgParse
using Profile
import JLD
using NPZ
using StatsBase


# Parameters

# Output file for `--profile` runs: the raw profile data (`li`, `lidict` from
# `Profile.retrieve()`) is written there with `JLD.@save`.
const profile_losses_filename = "an2vec-losses.jlprof"

"""
    parse_cliargs()

Parse and validate the command-line arguments in `ARGS`.

Return the `Dict` produced by ArgParse, augmented with two derived entries:

- `"initb"`: bias initialiser for the VAE layers, `s -> zeros(Float32, s)` when
  `bias` is true, `VAE.Layers.nobias` otherwise;
- `"label-distribution"`: output distribution matching `featuretype`
  (`VAE.Categorical` for `"colors"`, `VAE.Bernoulli` for `"sbm"`, `VAE.Normal` for
  `"lowrank"`).

Throws `ArgumentError` if the arguments are inconsistent: `correlation` outside
`[0, 1]`, unknown `featuretype`, or missing/mismatched feature-model options.
"""
function parse_cliargs()
    parsed = parse_args(ARGS, _cliargs_settings())
    _check_cliargs(parsed)
    parsed["initb"] = parsed["bias"] ? (s) -> zeros(Float32, s) : VAE.Layers.nobias
    parsed["label-distribution"] = _label_distribution(parsed["featuretype"])
    return parsed
end

"""Build the ArgParse settings describing this script's command-line interface."""
function _cliargs_settings()
    parse_settings = ArgParseSettings()
    @add_arg_table parse_settings begin
        "-l"
            help = "number of structural communities"
            arg_type = Int
            # default = 10
            required = true
        "-k"
            help = "size of each structural community"
            arg_type = Int
            # default = 10
            required = true
        "--p_in"
            help = "structural intra-community connection probability"
            arg_type = Float64
            default = 0.4
        "--p_out"
            help = "structural extra-community connection probability"
            arg_type = Float64
            default = 0.01
        "--correlation"
            help = "correlation between features and structural communities"
            arg_type = Float64
            required = true
        "--featuretype"
            help = "generative model for the features; one of \"colors\", \"sbm\", \"lowrank\""
            arg_type = String
            required = true
        "--fsbm_l"
            help = "number of sbmfeatures communities (ignore when using colors)"
            arg_type = Int
        "--fsbm_k"
            help = "size of each sbmfeatures community (ignore when using colors)"
            arg_type = Int
        "--fsbm_p_in"
            help = "sbmfeatures intra-community connection probability (ignore when using colors)"
            arg_type = Float64
        "--fsbm_p_out"
            help = "sbmfeatures extra-community connection probability (ignore when using colors)"
            arg_type = Float64
        "--fsbm_seed"
            help = "sbmfeatures graph seed"
            arg_type = Int
            default = -1
        "--flowrank_dim"
            help = "Dimension for low rank clustered hyperplane features"
            arg_type = Int
        "--flowrank_rank"
            help = "Rank for low rank clustered hyperplane features"
            arg_type = Int
        "--flowrank_noisescale"
            help = "Scale of noise around each community centroid"
            arg_type = Float64
        "--flowrank_catdiag"
            help = """
                concatenate I_N to features and possibly labels too;
                if provided, must be either "input" for catdiag on encoder input and not on decoder output, or "both" for input and output"""
            arg_type = String
        "--gseed"
            help = "seed for generation of the graph"
            arg_type = Int
            default = -1
        "--diml1enc"
            help = "dimension of intermediary encoder layer"
            arg_type = Int
            # default = 10
            required = true
        "--diml1dec"
            help = "dimension of intermediary decoder layer"
            arg_type = Int
            # default = 10
            required = true
        "--dimxiadj"
            help = "embedding dimensions for adjacency"
            arg_type = Int
            # default = 2
            required = true
        "--dimxifeat"
            help = "embedding dimensions for features"
            arg_type = Int
            # default = 2
            required = true
        "--overlap"
            help = "overlap of adjacency and feature embeddings"
            arg_type = Int
            # default = 1
            required = true
        "--bias"
            help = "activate/deactivate bias in the VAE"
            arg_type = Bool
            required = true
        "--sharedl1"
            help = "share/unshare encoder first layer across features and adjacency"
            arg_type = Bool
            required = true
        "--decadjdeep"
            help = "deep/shallow adjacency decoder"
            arg_type = Bool
            required = true
        "--nepochs"
            help = "number of epochs to train for"
            arg_type = Int
            default = 1000
        "--savehistory"
            help = "file to save the training history (as npz)"
            arg_type = String
            required = true
        "--saveweights"
            help = "file to save the final model weights and creation parameters (as Bson)"
            arg_type = String
        "--savedataset"
            help = "file to save the training dataset (as Bson)"
            arg_type = String
        "--profile"
            help = """
                profile n loss runs instead of training the model;
                overrides nepochs and save* options; results are saved to "$(profile_losses_filename)"."""
            arg_type = Int
    end
    return parse_settings
end

"""
Check the cross-option constraints that ArgParse cannot express, throwing an
`ArgumentError` that names the offending flag(s) on failure.
"""
function _check_cliargs(parsed)
    # Explicit checks rather than `@assert`: this is user input, and asserts may be
    # compiled out and produce unhelpful messages.
    correlation = parsed["correlation"]
    0 <= correlation <= 1 ||
        throw(ArgumentError("--correlation must be in [0, 1], got $correlation"))

    featuretype = parsed["featuretype"]
    if featuretype == "colors"
        # Colours need no extra options.
    elseif featuretype == "sbm"
        _require_cliargs(parsed, ["fsbm_l", "fsbm_k", "fsbm_p_in", "fsbm_p_out"], featuretype)
        # The feature SBM must have exactly one node per structural node.
        nnodes = parsed["l"] * parsed["k"]
        fnnodes = parsed["fsbm_l"] * parsed["fsbm_k"]
        nnodes == fnnodes || throw(ArgumentError(
            "--fsbm_l * --fsbm_k ($fnnodes) must equal -l * -k ($nnodes)"))
    elseif featuretype == "lowrank"
        _require_cliargs(parsed, ["flowrank_rank", "flowrank_dim", "flowrank_noisescale"], featuretype)
        catdiag = parsed["flowrank_catdiag"]
        catdiag === nothing || catdiag in ("input", "both") || throw(ArgumentError(
            "--flowrank_catdiag must be \"input\" or \"both\", got \"$catdiag\""))
    else
        throw(ArgumentError(
            "--featuretype must be one of \"colors\", \"sbm\", \"lowrank\", got \"$featuretype\""))
    end
    return nothing
end

"""Throw an `ArgumentError` if any option listed in `required` was not given."""
function _require_cliargs(parsed, required, featuretype)
    absent = filter(name -> parsed[name] === nothing, required)
    if !isempty(absent)
        flags = join(("--" * name for name in absent), ", ")
        throw(ArgumentError("--featuretype $featuretype requires $flags"))
    end
    return nothing
end

"""Return the VAE output distribution used to decode features of type `featuretype`."""
function _label_distribution(featuretype)
    if featuretype == "colors"
        return VAE.Categorical
    elseif featuretype == "sbm"
        return VAE.Bernoulli
    else
        # `_check_cliargs` guarantees this is the only remaining option.
        @assert featuretype == "lowrank"
        return VAE.Normal
    end
end


"""
    dataset(args) -> (g, features, labels)

Define the graph and features.

Build a stochastic block model graph with `args["l"]` communities of `args["k"]` nodes
each (`Generative.make_sbm`), then node features whose generative model is chosen by
`args["featuretype"]`. The features are correlated with the structural communities
according to `args["correlation"]`:

- `"colors"`: colours from `Generative.make_colors`, encoded with `Utils.onehotmaxbatch`;
- `"sbm"`: features from `Generative.make_sbmfeatures` (a second SBM);
- `"lowrank"`: clustered hyperplane features of size `(flowrank_dim, l * k)` from
  `Generative.make_clusteredhyperplane`, optionally with an `l * k` identity block
  appended below (`args["flowrank_catdiag"]` is `"input"` or `"both"`).

Return the graph, the encoder input `features` and the decoder target `labels`, both
with element type `Float32`. The two are the same array except for `"lowrank"` with
`flowrank_catdiag == "input"`, where only `features` carries the identity block.

Expects `args` as validated by `parse_cliargs`; invalid combinations trip the internal
assertions below.
"""
function dataset(args)
    l, k, p_in, p_out, gseed, correlation = args["l"], args["k"], args["p_in"], args["p_out"], args["gseed"], args["correlation"]
    nnodes = l * k

    g, communities = Generative.make_sbm(l, k, p_in, p_out, gseed = gseed)

    features, labels = if args["featuretype"] == "colors"
        colors = Generative.make_colors(communities, correlation)
        onehot = Array{Float32}(Utils.onehotmaxbatch(colors))
        onehot, onehot
    elseif args["featuretype"] == "sbm"
        fsbm_l, fsbm_k, fsbm_p_in, fsbm_p_out, fsbm_seed = args["fsbm_l"], args["fsbm_k"], args["fsbm_p_in"], args["fsbm_p_out"], args["fsbm_seed"]
        sbmfeatures, _ = Generative.make_sbmfeatures(fsbm_l, fsbm_k, fsbm_p_in, fsbm_p_out, correlation; gseed = fsbm_seed)
        sbmdense = Array{Float32}(sbmfeatures)
        sbmdense, sbmdense
    else
        @assert args["featuretype"] == "lowrank"
        flowrank_dim, flowrank_rank, flowrank_noisescale = args["flowrank_dim"], args["flowrank_rank"], args["flowrank_noisescale"]
        hyperplane = Generative.make_clusteredhyperplane(Float32, (flowrank_dim, nnodes), flowrank_rank, flowrank_noisescale, communities, correlation)
        # Sanity check on the generator: the features must have the requested rank.
        @assert rank(hyperplane) == flowrank_rank

        catdiag = args["flowrank_catdiag"]
        if catdiag === nothing
            hyperplane, hyperplane
        else
            # Append an `nnodes × nnodes` identity block (adds `nnodes` rows).
            withdiag = vcat(hyperplane, Matrix{Float32}(I, nnodes, nnodes))
            if catdiag == "input"
                # Identity only on the encoder input; the decoder target stays plain.
                withdiag, hyperplane
            else
                @assert catdiag == "both"
                withdiag, withdiag
            end
        end
    end

    @assert eltype(features) == Float32
    @assert eltype(labels) == Float32
    return g, features, labels
end


"""
    main()

Script entry point.

Parse the CLI arguments, build the dataset and the VAE with its losses, then either:

- profile the loss (`--profile n`), saving the results to `profile_losses_filename`
  and ignoring every `save*` option; or
- train the VAE with `VAE.train_vae!`, print the final value of each tracked loss and
  save the requested results (training history always; weights and dataset only if
  `--saveweights` / `--savedataset` were given).
"""
function main()
    args = parse_cliargs()
    savehistory, saveweights, savedataset, profilen = args["savehistory"], args["saveweights"], args["savedataset"], args["profile"]
    if profilen !== nothing
        println("Profiling $profilen loss runs. Ignoring any \"save*\" arguments.")
    else
        saveweights === nothing && println("Warning: will not save model weights after training")
        savedataset === nothing && println("Warning: will not save the training dataset after training")
    end

    println("Making the dataset")
    g, _features, labels = dataset(args)
    # Sizes are taken before normalisation, from the raw generated arrays.
    feature_size = size(_features, 1)
    label_size = size(labels, 1)
    features = normaliser(_features)(_features)

    println("Making the model")
    enc, sampleξ, dec, paramsenc, paramsdec = VAE.make_vae(
        g = g, feature_size = feature_size, label_size = label_size, args = args)
    losses, loss = VAE.make_losses(
        g = g, labels = labels, args = args,
        enc = enc, sampleξ = sampleξ, dec = dec,
        paramsenc = paramsenc, paramsdec = paramsdec)

    if profilen !== nothing
        _profile_loss(profilen, loss, features)
        return
    end

    println("Training...")
    paramsvae = Flux.Params()
    push!(paramsvae, paramsenc..., paramsdec...)
    history = VAE.train_vae!(
        args = args, features = features,
        losses = losses, loss = loss,
        paramsvae = paramsvae)
    println("Final losses:")
    for (name, trace) in history
        println("  $name = $(trace[end])")
    end

    _save_results(args, history, paramsvae, g, labels)
end

"""
Profile `Utils.repeat_fn(n, loss, features)` and save the raw profile data (`li`,
`lidict`) to `profile_losses_filename` with JLD.
"""
function _profile_loss(n, loss, features)
    println("Profiling loss runs...")
    Utils.repeat_fn(1, loss, features)  # Trigger compilation so it is not profiled
    Profile.clear()
    Profile.init(n = 10000000)
    @profile Utils.repeat_fn(n, loss, features)
    li, lidict = Profile.retrieve()
    println("Saving profile results to \"$(profile_losses_filename)\"")
    # The local names `li` and `lidict` become the keys in the saved file.
    JLD.@save profile_losses_filename li lidict
    return nothing
end

"""
Save the training results: the history always goes to `args["savehistory"]` (npz);
the weights plus `args` go to `args["saveweights"]` and the graph plus labels go to
`args["savedataset"]` (BSON), each only when the corresponding path was given.
"""
function _save_results(args, history, paramsvae, g, labels)
    savehistory, saveweights, savedataset = args["savehistory"], args["saveweights"], args["savedataset"]

    println("Saving training history to \"$savehistory\"")
    npzwrite(savehistory, Dict{String, Any}(history))

    # `BSON.@save` uses the local variable names (`weights`, `args`, `g`, `labels`)
    # as keys in the saved files, so they must not be renamed.
    if saveweights !== nothing
        println("Saving final model weights and creation parameters to \"$saveweights\"")
        weights = Tracker.data.(paramsvae)
        BSON.@save saveweights weights args
    else
        println("Not saving model weights or creation parameters")
    end
    if savedataset !== nothing
        println("Saving training dataset to \"$savedataset\"")
        BSON.@save savedataset g labels
    else
        println("Not saving training dataset")
    end
    return nothing
end

# Entry point. Note that this runs unconditionally, so `include`-ing this file also
# executes `main()`, which parses `ARGS`.
main()