ixxi-dante/nw2vec

View on GitHub
projects/correctness/embeddings-generate.ipynb

Summary

Maintainability
Test Coverage
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Args for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dict{String,Any} with 11 entries:\n",
       "  \"feature-distribution\" => Bernoulli\n",
       "  \"diml1\"                => 32\n",
       "  \"nepochs\"              => 200\n",
       "  \"overlap\"              => 16\n",
       "  \"blurring\"             => 0.0\n",
       "  \"dimxiadj\"             => 16\n",
       "  \"dataset\"              => \"../../datasets/gae-benchmarks/cora.npz\"\n",
       "  \"profile\"              => nothing\n",
       "  \"saveweights\"          => nothing\n",
       "  \"savehistory\"          => \"/dev/null\"\n",
       "  \"dimxifeat\"            => 16"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "args = Dict(\n",
    "    \"dataset\" => \"../../datasets/gae-benchmarks/cora.npz\",\n",
    "    \"blurring\" => 0.0,\n",
    "    \"feature-distribution\" => Bernoulli,\n",
    "    \"diml1\" => 32,\n",
    "    \"dimxiadj\" => 16,\n",
    "    \"dimxifeat\" => 16,\n",
    "    \"overlap\" => 16,\n",
    "    \"nepochs\" => 200,\n",
    "    \"savehistory\" => \"/dev/null\",\n",
    "    \"saveweights\" => nothing,\n",
    "    \"profile\" => nothing\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: replacing module Utils.\n",
      "WARNING: replacing module Layers.\n",
      "WARNING: replacing module Generative.\n",
      "WARNING: replacing module Dataset.\n",
      "WARNING: using Utils.scale_center in module Main conflicts with an existing identifier.\n",
      "WARNING: using Utils.randn_like in module Main conflicts with an existing identifier.\n",
      "WARNING: using Utils.threadedlogitbinarycrossentropy in module Main conflicts with an existing identifier.\n",
      "WARNING: using Utils.adjacency_matrix_diag in module Main conflicts with an existing identifier.\n",
      "WARNING: redefining constant supported_feature_distributions\n",
      "WARNING: redefining constant feature_distributions_dict\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "train_vae!"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "include(\"../../julia/utils.jl\")\n",
    "include(\"../../julia/layers.jl\")\n",
    "include(\"../../julia/generative.jl\")\n",
    "include(\"../../julia/dataset.jl\")\n",
    "using .Utils\n",
    "using .Layers\n",
    "using .Generative\n",
    "using .Dataset\n",
    "\n",
    "using Flux\n",
    "using LightGraphs\n",
    "using ProgressMeter\n",
    "using Statistics\n",
    "using Distributions\n",
    "using Random\n",
    "import BSON\n",
    "using ArgParse\n",
    "using Profile\n",
    "import JLD\n",
    "using NPZ\n",
    "using PyCall\n",
    "\n",
    "\n",
    "# Parameters\n",
    "const klscale = 1f-3\n",
    "const regscale = 1f-3\n",
    "const profile_losses_filename = \"an2vec-losses.jlprof\"\n",
    "const supported_feature_distributions = [Bernoulli, Categorical, Normal]\n",
    "const feature_distributions_dict = Dict(lowercase(repr(d)) => d for d in supported_feature_distributions)\n",
    "\n",
    "\n",
    "\"\"\"Load the adjacency matrix and features.\"\"\"\n",
    "function dataset(args)\n",
    "    adjfeatures = npzread(args[\"dataset\"])\n",
    "\n",
    "    features = transpose(adjfeatures[\"features\"])\n",
    "\n",
    "    # Make sure we have a non-weighted graph\n",
    "    @assert Set(adjfeatures[\"adjdata\"]) == Set([1])\n",
    "\n",
    "    # Remove any diagonal elements in the matrix\n",
    "    rows = adjfeatures[\"adjrow\"]\n",
    "    cols = adjfeatures[\"adjcol\"]\n",
    "    nondiagindices = findall(rows .!= cols)\n",
    "    rows = rows[nondiagindices]\n",
    "    cols = cols[nondiagindices]\n",
    "    # Make sure indices start at 0\n",
    "    @assert minimum(rows) == minimum(cols) == 0\n",
    "\n",
    "    # Construct the graph\n",
    "    edges = LightGraphs.SimpleEdge.(1 .+ rows, 1 .+ cols)\n",
    "    g = SimpleGraphFromIterator(edges)\n",
    "\n",
    "    # Check sizes for sanity\n",
    "    @assert size(g, 1) == size(g, 2) == size(features, 2)\n",
    "    g, convert(Array{Float32}, features), convert(Array{Float32}, scale_center(features))\n",
    "end\n",
    "\n",
    "\n",
    "\"\"\"Make the model.\"\"\"\n",
    "function make_vae(;g, feature_size, args)\n",
    "    diml1, dimξadj, dimξfeat, overlap = args[\"diml1\"], args[\"dimxiadj\"], args[\"dimxifeat\"], args[\"overlap\"]\n",
    "\n",
    "    # Encoder\n",
    "    l1 = Layers.GC(g, feature_size, diml1, Flux.relu, initb = Layers.nobias)\n",
    "    lμ = Layers.Apply(Layers.VOverlap(overlap),\n",
    "        Layers.GC(g, diml1, dimξadj, initb = Layers.nobias),\n",
    "        Layers.GC(g, diml1, dimξfeat, initb = Layers.nobias))\n",
    "    llogσ = Layers.Apply(Layers.VOverlap(overlap),\n",
    "        Layers.GC(g, diml1, dimξadj, initb = Layers.nobias),\n",
    "        Layers.GC(g, diml1, dimξfeat, initb = Layers.nobias))\n",
    "    enc(x) = (h = l1(x); (lμ(h), llogσ(h)))\n",
    "    encparams = Flux.params(l1, lμ, llogσ)\n",
    "\n",
    "    # Sampler\n",
    "    sampleξ(μ, logσ) = μ .+ exp.(logσ) .* randn_like(μ)\n",
    "\n",
    "    # Decoder\n",
    "    decadj = Layers.Bilin()\n",
    "    decfeat, decparams = if args[\"feature-distribution\"] == Normal\n",
    "        println(\"Info: using Gaussian feature decoder\")\n",
    "        decfeatl1 = Dense(dimξfeat, diml1, Flux.relu, initb = Layers.nobias)\n",
    "        decfeatlμ = Dense(diml1, feature_size, initb = Layers.nobias)\n",
    "        decfeatllogσ = Dense(diml1, feature_size, initb = Layers.nobias)\n",
    "        decfeat(ξ) = (h = decfeatl1(ξ); (decfeatlμ(h), decfeatllogσ(h)))\n",
    "        decfeat, Flux.params(decadj, decfeatl1, decfeatlμ, decfeatllogσ)\n",
    "    else\n",
    "        println(\"Info: using non-Gaussian feature decoder\")\n",
    "        decfeat = Chain(\n",
    "            Dense(dimξfeat, diml1, Flux.relu, initb = Layers.nobias),\n",
    "            Dense(diml1, feature_size, initb = Layers.nobias),\n",
    "        )\n",
    "        decfeat, Flux.params(decadj, decfeat)\n",
    "    end\n",
    "    dec(ξ) = (decadj(ξ[1:dimξadj, :]), decfeat(ξ[end-dimξfeat+1:end, :]))\n",
    "\n",
    "    enc, sampleξ, dec, encparams, decparams\n",
    "end\n",
    "\n",
    "\n",
    "\"\"\"Define the function compting AUC and AP scores for model predictions (adjacency only)\"\"\"\n",
    "function make_perf_scorer(;enc, sampleξ, dec, greal::SimpleGraph, test_true_edges, test_false_edges)\n",
    "    # Convert test edge arrays to indices\n",
    "    test_true_indices = CartesianIndex.(test_true_edges[:, 1], test_true_edges[:, 2])\n",
    "    test_false_indices = CartesianIndex.(test_false_edges[:, 1], test_false_edges[:, 2])\n",
    "\n",
    "    # Prepare ground truth values for test edges\n",
    "    Areal = Array(adjacency_matrix(greal))\n",
    "    real_true = Areal[test_true_indices]\n",
    "    @assert real_true == ones(length(test_true_indices))\n",
    "    real_false = Areal[test_false_indices]\n",
    "    @assert real_false == zeros(length(test_false_indices))\n",
    "    real_all = vcat(real_true, real_false)\n",
    "\n",
    "    metrics = pyimport(\"sklearn.metrics\")\n",
    "\n",
    "    function perf(x)\n",
    "        μ = enc(x)[1]\n",
    "        Alogitpred = dec(μ)[1].data\n",
    "        pred_true = Utils.threadedσ(Alogitpred[test_true_indices])\n",
    "        pred_false = Utils.threadedσ(Alogitpred[test_false_indices])\n",
    "        pred_all = vcat(pred_true, pred_false)\n",
    "\n",
    "        metrics[:roc_auc_score](real_all, pred_all), metrics[:average_precision_score](real_all, pred_all)\n",
    "    end\n",
    "\n",
    "    perf\n",
    "end\n",
    "\n",
    "\n",
    "\"\"\"Define the model losses.\"\"\"\n",
    "function make_losses(;g, labels, feature_size, args, enc, sampleξ, dec, paramsenc, paramsdec)\n",
    "    feature_distribution = args[\"feature-distribution\"]\n",
    "    dimξadj, dimξfeat, overlap = args[\"dimxiadj\"], args[\"dimxifeat\"], args[\"overlap\"]\n",
    "    Adiag = Array{Float32}(adjacency_matrix_diag(g))\n",
    "    densityA = Float32(mean(adjacency_matrix(g)))\n",
    "    densitylabels = Float32(mean(labels))\n",
    "\n",
    "    # TODO check normalisation constants\n",
    "\n",
    "    # Kullback-Leibler divergence\n",
    "    Lkl(μ, logσ) = sum(Utils.threadedklnormal(μ, logσ))\n",
    "    κkl = Float32(size(g, 1) * (dimξadj - overlap + dimξfeat))\n",
    "\n",
    "    # Adjacency loss\n",
    "    Ladj(logitApred) = (\n",
    "        sum(threadedlogitbinarycrossentropy(logitApred, Adiag, pos_weight = (1f0 / densityA) - 1))\n",
    "        / (2 * (1 - densityA))\n",
    "    )\n",
    "    κadj = Float32(size(g, 1)^2 * log(2))\n",
    "\n",
    "    # Features loss\n",
    "    Lfeat(logitFpred, ::Type{Bernoulli}) = (\n",
    "        sum(threadedlogitbinarycrossentropy(logitFpred, labels, pos_weight = (1f0 / densitylabels) - 1))\n",
    "        / (1 - densitylabels)\n",
    "    )\n",
    "    κfeat_bernoulli = Float32(prod(size(labels)) * log(2))\n",
    "    κfeat(::Type{Bernoulli}) = κfeat_bernoulli\n",
    "\n",
    "    Lfeat(unormFpred, ::Type{Categorical}) = - softmaxcategoricallogprob(unormFpred, labels)\n",
    "    κfeat_categorical = Float32(size(g, 1) * log(feature_size))\n",
    "    κfeat(::Type{Categorical}) = κfeat_categorical\n",
    "\n",
    "    Lfeat(Fpreds, ::Type{Normal}) = ((μ, logσ) = Fpreds; sum(Utils.threadednormallogprobloss(μ, logσ, labels)))\n",
    "    κfeat_normal = Float32(prod(size(labels)) * (log(2π) + mean(labels.^2)) / 2)\n",
    "    κfeat(::Type{Normal}) = κfeat_normal\n",
    "\n",
    "    # Total loss\n",
    "    function losses(x)\n",
    "        μ, logσ = enc(x)\n",
    "        logitApred, unormfeatpred = dec(sampleξ(μ, logσ))\n",
    "        Dict(\"kl\" => klscale * Lkl(μ, logσ) / κkl,\n",
    "            \"adj\" => Ladj(logitApred) / κadj,\n",
    "            \"feat\" => Lfeat(unormfeatpred, feature_distribution) / κfeat(feature_distribution),\n",
    "            \"reg\" => regscale * Utils.regularizer(paramsdec))\n",
    "    end\n",
    "\n",
    "    function loss(x)\n",
    "        sum(values(losses(x)))\n",
    "    end\n",
    "\n",
    "    losses, loss\n",
    "end\n",
    "\n",
    "\n",
    "\"\"\"Profile runs of a function\"\"\"\n",
    "function profile_fn(n::Int64, fn, args...)\n",
    "    for i = 1:n\n",
    "        fn(args...)\n",
    "    end\n",
    "end\n",
    "\n",
    "\n",
    "\"\"\"Train a VAE.\"\"\"\n",
    "function train_vae!(;args, features, losses, loss, perf, paramsvae)\n",
    "    nepochs = args[\"nepochs\"]\n",
    "\n",
    "    history = Dict(name => zeros(nepochs) for name in keys(losses(features)))\n",
    "    history[\"total loss\"] = zeros(nepochs)\n",
    "    #history[\"auc\"] = zeros(nepochs)\n",
    "    #history[\"ap\"] = zeros(nepochs)\n",
    "\n",
    "    opt = ADAM(0.01)\n",
    "    @showprogress for i = 1:nepochs\n",
    "        Flux.train!(loss, paramsvae, [(features,)], opt)\n",
    "\n",
    "        lossparts = losses(features)\n",
    "        for (name, value) in lossparts\n",
    "            history[name][i] = value.data\n",
    "        end\n",
    "        history[\"total loss\"][i] = sum(values(lossparts)).data\n",
    "        #history[\"auc\"][i], history[\"ap\"][i] = perf(features)\n",
    "    end\n",
    "\n",
    "    history\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the dataset\n",
      "Making the model\n",
      "Info: using non-Gaussian feature decoder\n",
      "Training...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|████████████████████████████████████████▊|  ETA: 0:00:02\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final losses and performance metrics:\n",
      "  adj = 0.6321921944618225\n",
      "  feat = 0.8452295064926147\n",
      "  kl = 0.004350213799625635\n",
      "  reg = 0.061194006353616714\n",
      "  total loss = 1.5429660081863403\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:05:11\u001b[39m\n"
     ]
    }
   ],
   "source": [
    "println(\"Loading the dataset\")\n",
    "g, labels, features = dataset(args)\n",
    "gtrain, test_true_edges, test_false_edges = Dataset.make_blurred_test_set(g, args[\"blurring\"])\n",
    "feature_size = size(features, 1)\n",
    "\n",
    "println(\"Making the model\")\n",
    "enc, sampleξ, dec, paramsenc, paramsdec = make_vae(\n",
    "    g = gtrain, feature_size = feature_size, args = args)\n",
    "losses, loss = make_losses(\n",
    "    g = gtrain, labels = labels, feature_size = feature_size, args = args,\n",
    "    enc = enc, sampleξ = sampleξ, dec = dec,\n",
    "    paramsenc = paramsenc, paramsdec = paramsdec)\n",
    "perf = make_perf_scorer(\n",
    "    enc = enc, sampleξ = sampleξ, dec = dec,\n",
    "    greal = g, test_true_edges = test_true_edges, test_false_edges = test_false_edges)\n",
    "\n",
    "println(\"Training...\")\n",
    "paramsvae = Flux.Params()\n",
    "push!(paramsvae, paramsenc..., paramsdec...)\n",
    "history = train_vae!(\n",
    "    args = args, features = features,\n",
    "    losses = losses, loss = loss, perf = nothing,\n",
    "    paramsvae = paramsvae)\n",
    "\n",
    "println(\"Final losses and performance metrics:\")\n",
    "for (name, values) in history\n",
    "    println(\"  $name = $(values[end])\")\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "embμ = enc(features)[1].data\n",
    "npzwrite(\"cora.emb.npz\", embμ')"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": "cb42669768154a9699807ff8a30ce041",
   "lastKernelId": "61ed97eb-7117-498d-994b-c53a9c4f1c68"
  },
  "kernelspec": {
   "display_name": "Julia 1.1.0",
   "language": "julia",
   "name": "julia-1.1"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.1.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}