julia/feature-gradients.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m\u001b[1mActivating\u001b[22m\u001b[39m environment at `~/Code/Research/nw2vec/Project.toml`\n"
]
}
],
"source": [
"using Pkg; Pkg.activate(\"..\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]\n",
"└ @ Base loading.jl:1240\n",
"┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/LightGraphs/Xm08G.ji for LightGraphs [093fc24a-ae57-5d10-9952-331d41423f4d]\n",
"└ @ Base loading.jl:1240\n",
"┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Distributions/xILW0.ji for Distributions [31c24e10-a181-5473-b8eb-7969acd0382f]\n",
"└ @ Base loading.jl:1240\n",
"┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/AbstractPlotting/6fydZ.ji for AbstractPlotting [537997a7-5e4e-5d89-9595-2241ea00577e]\n",
"└ @ Base loading.jl:1240\n",
"┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/Makie/iZ1Bl.ji for Makie [ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a]\n",
"└ @ Base loading.jl:1240\n",
"WARNING: using Makie.AbstractPlotting in module Main conflicts with an existing identifier.\n",
"┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/CairoMakie/9mSey.ji for CairoMakie [13f3f980-e62b-5c42-98c6-ff1f3baf88f0]\n",
"└ @ Base loading.jl:1240\n",
"┌ Info: Recompiling stale cache file /home/sl/.julia/compiled/v1.2/DataFrames/AR9oZ.ji for DataFrames [a93c6f00-e57d-5684-b7b6-d8193f3e46c0]\n",
"└ @ Base loading.jl:1240\n",
"┌ Info: Precompiling Combinatorics [861a8166-3701-5b0c-9a16-15d98fcdc6aa]\n",
"└ @ Base loading.jl:1242\n"
]
}
],
"source": [
"include(\"layers.jl\")\n",
"include(\"utils.jl\")\n",
"include(\"vae.jl\")\n",
"using .Utils, .Layers, .VAE\n",
"\n",
"using Flux\n",
"using LightGraphs\n",
"using Colors\n",
"using AbstractPlotting\n",
"using Makie\n",
"using CairoMakie\n",
"using BSON: @load\n",
"using NPZ\n",
"using Random\n",
"using StatsBase\n",
"using Distributions\n",
"using LinearAlgebra\n",
"using DataFrames\n",
"using Combinatorics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading the dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dataset (generic function with 1 method)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"    dataset(args; rng = Random.GLOBAL_RNG)\n",
"\n",
"Load the graph dataset stored in the `.npz` file at `args[\"dataset\"]` and return\n",
"`(g, features, classes)`, where `features` (converted to `Float32`) and `classes`\n",
"have one column per node of `g`.\n",
"\n",
"A fraction `1 - args[\"forced-correlation\"]` of the nodes get their feature and\n",
"label columns shuffled among themselves, weakening the correlation between the\n",
"network and the features. Pass `rng` to make that shuffle reproducible.\n",
"\"\"\"\n",
"function dataset(args; rng = Random.GLOBAL_RNG)\n",
"    data = npzread(args[\"dataset\"])\n",
"\n",
"    features = convert(Array{Float32}, transpose(data[\"features\"]))\n",
"    # `permutedims` returns an owned array; a lazy `transpose` would make the in-place\n",
"    # shuffle below write straight back into `data[\"labels\"]`.\n",
"    classes = permutedims(data[\"labels\"])\n",
"\n",
"    # Make sure we have a non-weighted graph\n",
"    @assert Set(data[\"adjdata\"]) == Set([1])\n",
"\n",
"    # Remove any diagonal (self-loop) entries from the adjacency triplets\n",
"    rows = data[\"adjrow\"]\n",
"    cols = data[\"adjcol\"]\n",
"    nondiagindices = findall(rows .!= cols)\n",
"    rows = rows[nondiagindices]\n",
"    cols = cols[nondiagindices]\n",
"    # Make sure indices are 0-based (they are shifted to 1-based just below)\n",
"    @assert minimum(rows) == minimum(cols) == 0\n",
"\n",
"    # Construct the graph\n",
"    edges = LightGraphs.SimpleEdge.(1 .+ rows, 1 .+ cols)\n",
"    g = SimpleGraphFromIterator(edges)\n",
"\n",
"    # Check sizes for sanity\n",
"    @assert nv(g) == size(g, 1) == size(g, 2) == size(features, 2)\n",
"\n",
"    # Randomize to the level requested\n",
"    nnodes = nv(g)\n",
"    correlation = args[\"forced-correlation\"]\n",
"    @assert 0 <= correlation <= 1 \"forced-correlation must be in [0, 1], got $correlation\"\n",
"    nshuffle = round(Int, (1 - correlation) * nnodes)\n",
"    idx = StatsBase.sample(rng, 1:nnodes, nshuffle, replace = false)\n",
"    shuffledidx = shuffle(rng, idx)\n",
"    features[:, idx] = features[:, shuffledidx]\n",
"    classes[:, idx] = classes[:, shuffledidx]\n",
"\n",
"    @assert eltype(features) == Float32\n",
"    g, features, classes\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"make_dataset (generic function with 1 method)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"    make_dataset()\n",
"\n",
"Build a small hand-made 9-node graph and its `Float32` feature matrix\n",
"(a constant all-ones row stacked on top of a 9×9 identity).\n",
"\"\"\"\n",
"function make_dataset()\n",
"    nnodes = 9\n",
"    # Node 1 links to 2, 3, 4 and 5; node 5 links to 6, 7 and 8; node 8 links to 9\n",
"    edgelist = [(1, 2), (1, 3), (1, 4), (1, 5), (5, 6), (5, 7), (5, 8), (8, 9)]\n",
"    graph = SimpleGraph(nnodes)\n",
"    for (src, dst) in edgelist\n",
"        add_edge!(graph, src, dst)\n",
"    end\n",
"\n",
"    # Row 1 is constant; the remaining rows one-hot encode each node's id\n",
"    constant_row = ones(Float32, (1, nnodes))\n",
"    identity_rows = Array(Diagonal(ones(Float32, nnodes)))\n",
"    features = vcat(constant_row, identity_rows)\n",
"\n",
"    return graph, features\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting model state"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"plotweights (generic function with 1 method)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"Return the rows of `a` selected by `dims`, as a vector of row vectors.\"\"\"\n",
"adims(a, dims) = [a[row, :] for row in dims]\n",
"\n",
"\"\"\"\n",
"    plotstate(; enc, vae, x, refx, g, dims, colors)\n",
"\n",
"Compose a side-by-side view of the model's output for input `x`.\n",
"\n",
"The left column holds an empty spacer scene, a heatmap of σ of the predicted\n",
"adjacency logits and a heatmap of the softmax of the predicted feature scores.\n",
"The right column holds a scatter of the encoder means along `dims` (2 or 3 of\n",
"them), the true adjacency matrix of `g` and a heatmap of `refx`.\n",
"\"\"\"\n",
"function plotstate(;enc, vae, x, refx, g, dims, colors)\n",
"    @assert length(dims) in [2, 3]\n",
"    embμ, _ = enc(x)\n",
"    logitÂ, unormF̂ = vae(x)\n",
"    nrows, ncols = size(x, 1), size(x, 2)\n",
"\n",
"    spacer = Scene()\n",
"    predadj = heatmap(σ.(logitÂ).data, colorrange = (0, 1))\n",
"    predfeat = heatmap(1:nrows, 1:ncols, softmax(unormF̂).data, colorrange = (0, 1))\n",
"    predicted = vbox(spacer, predadj, predfeat, sizes = [.45, .45, .1])\n",
"\n",
"    embedding = scatter(adims(embμ, dims)..., color = colors, markersize = Utils.markersize(embμ))\n",
"    trueadj = heatmap(Array(adjacency_matrix(g)), colorrange = (0, 1))\n",
"    truefeat = heatmap(1:nrows, 1:ncols, refx, colorrange = (0, 1))\n",
"    reference = vbox(embedding, trueadj, truefeat, sizes = [.45, .45, .1])\n",
"\n",
"    return hbox(predicted, reference)\n",
"end\n",
"\n",
"\"\"\"\n",
"    plotweights(layers...)\n",
"\n",
"Stack one row per layer: a heatmap of its weight matrix `W` next to its `repr`.\n",
"\"\"\"\n",
"function plotweights(layers...)\n",
"    labeltheme = Theme(align = (:left, :bottom), raw = true, camera = campixel!)\n",
"    rows = map(layer -> hbox(heatmap(layer.W.data), text(labeltheme, repr(layer))), layers)\n",
"    return vbox(rows...)\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model parameters"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dict{String,Any} with 11 entries:\n",
" \"diml1dec\" => 32\n",
" \"sharedl1\" => false\n",
" \"label-distribution\" => Bernoulli\n",
" \"diml1enc\" => 32\n",
" \"bias\" => false\n",
" \"overlap\" => 8\n",
" \"initb\" => nobias\n",
" \"decadjdeep\" => true\n",
" \"dimxiadj\" => 16\n",
" \"forced-correlation\" => 1.0\n",
" \"dimxifeat\" => 16"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hyper-parameters read by `dataset` and passed to `VAE.make_vae`.\n",
"# NOTE: the `@load ... weights args` in the model-loading cell below assigns the\n",
"# `args` stored in the weights file, replacing this dict from that point on.\n",
"args = Dict(\n",
" #\"dataset\" => \"../data/twitter/git=679a9eb593-csv_to_npz-mt=5-tmw=3-w2v_dim=50-w2v_iter=10-cho=True-nclusters=10,20,50,80,100/dataset=retweetsrange-nclusters=10.npz\",\n",
" \"forced-correlation\" => 1.0, # default\n",
" \"label-distribution\" => VAE.label_distributions[\"bernoulli\"],\n",
" \"diml1enc\" => 32,\n",
" \"diml1dec\" => 32,\n",
" \"dimxiadj\" => 16,\n",
" \"dimxifeat\" => 16,\n",
" \"overlap\" => 8,\n",
" \"bias\" => false,\n",
" \"sharedl1\" => false,\n",
" \"decadjdeep\" => true,\n",
" \"initb\" => VAE.Layers.nobias\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the model and plot its state"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Use the hand-made toy graph; swap in the commented `dataset` call to load the npz data.\n",
"#g, _features, _ = dataset(args)\n",
"g, _features = make_dataset()\n",
"# The raw features also serve as the labels (`label_size` is passed to `VAE.make_vae`).\n",
"labels = _features\n",
"feature_size = size(_features, 1)\n",
"label_size = feature_size\n",
"# `features` is the normalised version of `_features` (see `Utils.normaliser`).\n",
"fnormalise = Utils.normaliser(_features)\n",
"features = fnormalise(_features);"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Info: using unshared l1 encoder\n",
"Info: using deep adjacency decoder\n",
"Info: using boolean feature decoder\n"
]
}
],
"source": [
"# Load the trained weights together with the `args` saved alongside them\n",
"# (this replaces the `args` dict defined in the parameters cell).\n",
"@load \"../data/twitter/git=679a9eb593-an2vec-diml1enc=32-diml1dec=32-dimxiadj=16-dimxifeat=16-overlap=0,8,16-bias=false-sharedl1=false-decadjdeep=true-nepochs=200/dataset=retweetsrange-nclusters=10-ld=bernoulli-dimxi=24-weights.bson\" weights args\n",
"\n",
"# Rebuild the model for graph `g`, then copy the loaded weights into its parameters\n",
"enc, sampleξ, dec, paramsenc, paramsdec = VAE.make_vae(\n",
" g = g, feature_size = feature_size, label_size = label_size, args = args)\n",
"# Full pass: encode, draw ξ via `sampleξ`, decode\n",
"vae(x) = dec(sampleξ(enc(x)...))\n",
"\n",
"paramsvae = Tracker.Params()\n",
"push!(paramsvae, paramsenc..., paramsdec...)\n",
"loadparams!(paramsvae, weights)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GLMakie.Screen(...)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Training history saved in the same run directory as the weights loaded above\n",
"history = npzread(\"../data/twitter/git=679a9eb593-an2vec-diml1enc=32-diml1dec=32-dimxiadj=16-dimxifeat=16-overlap=0,8,16-bias=false-sharedl1=false-decadjdeep=true-nepochs=200/dataset=retweetsrange-nclusters=10-ld=bernoulli-dimxi=24-history.npz\")\n",
"\n",
"# One row per series: its values plotted against their index, next to a text label\n",
"theme = Theme(align = (:left, :bottom), raw = true, camera = campixel!)\n",
"scene = vbox([hbox(lines(1:length(history[name]), history[name], color = color), text(theme, name))\n",
" for (name, color) in [\n",
" (\"total loss\", :blue),\n",
" (\"kl\", :red),\n",
" (\"reg\", :red),\n",
" (\"adj\", :green),\n",
" (\"feat\", :green),\n",
" #(\"ap\", :cyan),\n",
" #(\"auc\", :cyan)\n",
" ]]...)\n",
"#Makie.save(\"training-history.png\", scene)\n",
"display(scene)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GLMakie.Screen(...)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Plot the model state on the normalised `features` (the same input the gradient\n",
"# analysis below feeds to the encoder), with the raw `labels` as reference.\n",
"# All nodes are drawn in a single colour.\n",
"display(\n",
"    plotstate(enc = enc, vae = vae, x = features, refx = labels,\n",
"              g = g, dims = 1:3, colors = \"black\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Distributions of gradients"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ae (generic function with 1 method)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Deterministic autoencoder: decode the encoder's first output (its mean) directly,\n",
"# with no sampling step in between\n",
"ae(x) = dec(first(enc(x)))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"make_onehot (generic function with 1 method)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"    make_onehot(coord, dims)\n",
"\n",
"Return a `Float32` array of size `dims` that is zero everywhere except for a 1 at `coord`.\n",
"\"\"\"\n",
"function make_onehot(coord, dims)\n",
"    result = fill(zero(Float32), dims)\n",
"    setindex!(result, one(Float32), coord)\n",
"    return result\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### w.r.t. features"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.7865321 … 0.73746455 0.5788719; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.7637189] (tracked), getfield(Tracker, Symbol(\"##21#23\")){getfield(Tracker, Symbol(\"##18#19\")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}}(Core.Box((Float32[1.8973668 1.8973668 … 1.8973664 1.8973664; 1.8973668 -0.4743417 … -0.4743416 -0.4743416; … ; -0.4743417 -0.4743417 … 1.8973664 -0.4743416; -0.4743417 -0.4743417 … -0.4743416 1.8973664] (tracked),)), getfield(Tracker, Symbol(\"##18#19\")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}(Params([Float32[1.8973668 1.8973668 … 1.8973664 1.8973664; 1.8973668 -0.4743417 … -0.4743416 -0.4743416; … ; -0.4743417 -0.4743417 … 1.8973664 -0.4743416; -0.4743417 -0.4743417 … -0.4743416 1.8973664] (tracked)]), Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.7865321 … 0.73746455 0.5788719; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.7637189] (tracked))))"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# `Apred` is σ of the first decoder output of the deterministic `ae` (the adjacency\n",
"# logits, cf. `logitÂ` in `plotstate`); `back(Δ)` returns the vector-Jacobian product,\n",
"# i.e. the gradient of `sum(Δ .* Apred)` w.r.t. the input `features`.\n",
"Apred, back = Tracker.forward(x -> σ.(ae(x)[1]), features)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pre-compute all-pairs shortest-path distances (in hops) between nodes."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9×9 Array{Int64,2}:\n",
" 0 1 1 1 1 2 2 2 3\n",
" 1 0 2 2 2 3 3 3 4\n",
" 1 2 0 2 2 3 3 3 4\n",
" 1 2 2 0 2 3 3 3 4\n",
" 1 2 2 2 0 1 1 1 2\n",
" 2 3 3 3 1 0 2 2 3\n",
" 2 3 3 3 1 2 0 2 3\n",
" 2 3 3 3 1 2 2 0 1\n",
" 3 4 4 4 2 3 3 1 0"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# All-pairs shortest-path lengths in `g`: `distances[u, v]` is the hop count from `u` to `v`\n",
"distances = floyd_warshall_shortest_paths(g).dists"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each candidate link `(u, v)`, collect the gradients of its predicted probability `Apred[u, v]` w.r.t. the features of every node within `nhops` hops of `u` or `v`."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"nnodes = size(Apred, 1)\n",
"@assert size(Apred) == (nnodes, nnodes)\n",
"nfeatures = size(features, 1)\n",
"nhops = 2\n",
"# Sensitivity passed to `back`: a Float32 one-hot on the (u, v) entry of `Apred`,\n",
"# updated in place for each pair (`coord` remembers which entry is currently set).\n",
"onehot = zeros(Float32, (nnodes, nnodes))\n",
"coord = CartesianIndex(1, 1)\n",
"\n",
"# One row per (u, v, n): the gradient of Apred[u, v] w.r.t. each feature of node n,\n",
"# plus the u-v distance and the distance from n to the closer of u and v.\n",
"df = DataFrame(;((Symbol(\"grad$i\"), Float32[]) for i in 1:nfeatures)...,\n",
"    u = Int64[], v = Int64[], n = Int64[], dist_uv = Int64[], dist_uorv = Int64[])\n",
"\n",
"# NOTE(review): `@showprogress` (ProgressMeter) is not imported in the imports cell;\n",
"# presumably one of the local modules provides it. TODO confirm.\n",
"@showprogress for (u, v) in combinations(1:nnodes, 2)\n",
"    # Get gradients for the u-v prediction\n",
"    global coord\n",
"    onehot[coord] = 0\n",
"    coord = CartesianIndex(u, v)\n",
"    onehot[coord] = 1\n",
"    @assert onehot == make_onehot(coord, (nnodes, nnodes))\n",
"    grads = Tracker.data(back(onehot)[1])\n",
"\n",
"    # Get the gradient for each node within `nhops` hops of u or v\n",
"    dist_uv = distances[u, v]\n",
"    neighbours_uv = union(neighborhood(g, u, nhops), neighborhood(g, v, nhops))\n",
"    non_neighbours_uv = setdiff(1:nnodes, neighbours_uv)\n",
"\n",
"    # Every neighbour must have at least one non-zero feature gradient...\n",
"    @assert all(sum(grads[:, neighbours_uv] .!= 0, dims = 1) .> 0)\n",
"    # ...and every feature gradient of every non-neighbour must be exactly zero\n",
"    @assert all(grads[:, non_neighbours_uv] .== 0)\n",
"\n",
"    for n in neighbours_uv\n",
"        push!(df, (grads[:, n]..., u, v, n, dist_uv, min(distances[u, n], distances[v, n])))\n",
"    end\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table class=\"data-frame\"><thead><tr><th></th><th>grad1</th><th>grad2</th><th>grad3</th><th>grad4</th><th>grad5</th><th>grad6</th><th>grad7</th><th>grad8</th><th>grad9</th><th>grad10</th><th>u</th><th>v</th><th>n</th><th>dist_uv</th><th>dist_uorv</th></tr><tr><th></th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Float32</th><th>Int64</th><th>Int64</th><th>Int64</th><th>Int64</th><th>Int64</th></tr></thead><tbody><p>276 rows × 15 columns</p><tr><th>1</th><td>0.0645079</td><td>-0.021225</td><td>0.00150802</td><td>0.00318302</td><td>-0.0319299</td><td>0.00660334</td><td>0.00870657</td><td>0.0306246</td><td>-0.00391502</td><td>-0.0416326</td><td>1</td><td>2</td><td>1</td><td>1</td><td>0</td></tr><tr><th>2</th><td>0.0653441</td><td>-0.0255881</td><td>-0.0114106</td><td>0.0108165</td><td>-0.0316851</td><td>-0.00136565</td><td>-0.00199008</td><td>0.0298564</td><td>0.00248509</td><td>-0.0377283</td><td>1</td><td>2</td><td>2</td><td>1</td><td>0</td></tr><tr><th>3</th><td>0.036182</td><td>-0.0138953</td><td>0.00824023</td><td>-0.00308798</td><td>-0.0130371</td><td>0.0059515</td><td>0.00703448</td><td>0.0184967</td><td>-0.00519035</td><td>-0.0281207</td><td>1</td><td>2</td><td>3</td><td>1</td><td>1</td></tr><tr><th>4</th><td>0.0292288</td><td>-0.0099957</td><td>0.00241025</td><td>0.00671407</td><td>-0.00923298</td><td>0.00159172</td><td>0.00227589</td><td>0.0133913</td><td>-0.00464004</td><td>-0.0223713</td><td>1</td><td>2</td><td>4</td><td>1</td><td>1</td></tr><tr><th>5</th><td>0.0139868</td><td>-0.00648846</td><td>-0.00108487</td><td>0.00196581</td><td>-0.00564524</td><td>0.00110447</td><td>0.000616112</td><td>0.00652828</td><td>-0.00118362</td><td>-0.0109085</td><td>1</td><td>2</td><td>5</td><td>1</td><td>1</td></tr><tr><th>6</th><td>0.00515707</td><td>-0.00153295</td><td>-9.54022e-5</td><td>-0.00106444</td><td>-0.00479411</td><td>0.00258529</td><td>0.00279811</td><td>0
.00244054</td><td>-0.000862602</td><td>-0.00403407</td><td>1</td><td>2</td><td>6</td><td>1</td><td>2</td></tr><tr><th>7</th><td>0.00515707</td><td>-0.00153295</td><td>-9.54022e-5</td><td>-0.00106444</td><td>-0.00479411</td><td>0.00258529</td><td>0.00279811</td><td>0.00244054</td><td>-0.000862602</td><td>-0.00403407</td><td>1</td><td>2</td><td>7</td><td>1</td><td>2</td></tr><tr><th>8</th><td>0.00421073</td><td>-0.00125165</td><td>-7.78955e-5</td><td>-0.000869113</td><td>-0.00391437</td><td>0.00211088</td><td>0.00228465</td><td>0.00199269</td><td>-0.000704312</td><td>-0.0032938</td><td>1</td><td>2</td><td>8</td><td>1</td><td>2</td></tr><tr><th>9</th><td>0.0807522</td><td>-0.0227874</td><td>0.0254424</td><td>-0.0167195</td><td>-0.0220161</td><td>0.0101796</td><td>0.0140722</td><td>0.0501589</td><td>-0.00739399</td><td>-0.0722438</td><td>1</td><td>3</td><td>1</td><td>1</td><td>0</td></tr><tr><th>10</th><td>0.0340761</td><td>-0.0103437</td><td>-0.00257617</td><td>0.00722817</td><td>-0.0177132</td><td>0.00453164</td><td>0.00870965</td><td>0.00973111</td><td>-0.00574675</td><td>-0.0162742</td><td>1</td><td>3</td><td>2</td><td>1</td><td>1</td></tr><tr><th>11</th><td>0.0940134</td><td>-0.0318718</td><td>0.0344218</td><td>-0.0327469</td><td>-0.00542446</td><td>0.00659211</td><td>0.00370851</td><td>0.0694634</td><td>-0.00152773</td><td>-0.0996689</td><td>1</td><td>3</td><td>3</td><td>1</td><td>0</td></tr><tr><th>12</th><td>0.0324342</td><td>-0.00973636</td><td>0.00403564</td><td>0.00701846</td><td>-0.0169275</td><td>0.00535893</td><td>0.00938935</td><td>0.0100804</td><td>-0.00863605</td><td>-0.0212517</td><td>1</td><td>3</td><td>4</td><td>1</td><td>1</td></tr><tr><th>13</th><td>0.0163616</td><td>-0.00657916</td><td>-0.0015163</td><td>0.00433485</td><td>-0.00743952</td><td>0.00136546</td><td>0.003353</td><td>0.00500078</td><td>-0.00370737</td><td>-0.00774432</td><td>1</td><td>3</td><td>5</td><td>1</td><td>1</td></tr><tr><th>14</th><td>0.00629887</td><td>-0.00162777</td><td>-0.0
00149417</td><td>0.00192418</td><td>-0.0060904</td><td>0.00131024</td><td>0.00368191</td><td>0.001949</td><td>-0.00250137</td><td>-0.000507388</td><td>1</td><td>3</td><td>6</td><td>1</td><td>2</td></tr><tr><th>15</th><td>0.00629887</td><td>-0.00162777</td><td>-0.000149417</td><td>0.00192418</td><td>-0.0060904</td><td>0.00131024</td><td>0.00368191</td><td>0.001949</td><td>-0.00250137</td><td>-0.000507388</td><td>1</td><td>3</td><td>7</td><td>1</td><td>2</td></tr><tr><th>16</th><td>0.00514301</td><td>-0.00132907</td><td>-0.000121998</td><td>0.00157108</td><td>-0.00497279</td><td>0.0010698</td><td>0.00300627</td><td>0.00159136</td><td>-0.00204236</td><td>-0.000414281</td><td>1</td><td>3</td><td>8</td><td>1</td><td>2</td></tr><tr><th>17</th><td>0.0617292</td><td>-0.0381997</td><td>0.0128525</td><td>0.015317</td><td>-0.00830853</td><td>7.7243e-5</td><td>-0.0030023</td><td>0.0352501</td><td>0.0032523</td><td>-0.0626972</td><td>1</td><td>4</td><td>1</td><td>1</td><td>0</td></tr><tr><th>18</th><td>0.0315697</td><td>-0.0214036</td><td>0.000929661</td><td>0.0126375</td><td>-0.0146739</td><td>0.00134248</td><td>0.00210057</td><td>0.0117977</td><td>-0.00231886</td><td>-0.0223714</td><td>1</td><td>4</td><td>2</td><td>1</td><td>1</td></tr><tr><th>19</th><td>0.0374139</td><td>-0.0274874</td><td>0.014762</td><td>0.000410167</td><td>-0.00619973</td><td>0.00214736</td><td>-1.34762e-6</td><td>0.0223506</td><td>3.35925e-5</td><td>-0.0405679</td><td>1</td><td>4</td><td>3</td><td>1</td><td>1</td></tr><tr><th>20</th><td>0.0606246</td><td>-0.0319523</td><td>0.00271798</td><td>0.0253772</td><td>-0.00479889</td><td>-0.00373943</td><td>-0.0057935</td><td>0.0348819</td><td>0.00547458</td><td>-0.0587476</td><td>1</td><td>4</td><td>4</td><td>1</td><td>0</td></tr><tr><th>21</th><td>0.0144691</td><td>-0.0119399</td><td>0.000725399</td><td>0.00435129</td><td>-0.00615814</td><td>0.00146693</td><td>0.000685575</td><td>0.00692591</td><td>0.000697119</td><td>-0.0146377</td><td>1</td><td>4</td><td>5</td
><td>1</td><td>1</td></tr><tr><th>22</th><td>0.00458313</td><td>-0.005771</td><td>0.00140194</td><td>-0.000148894</td><td>-0.00231272</td><td>0.00167019</td><td>0.000371735</td><td>0.00286897</td><td>0.00138584</td><td>-0.00791148</td><td>1</td><td>4</td><td>6</td><td>1</td><td>2</td></tr><tr><th>23</th><td>0.00458313</td><td>-0.005771</td><td>0.00140194</td><td>-0.000148894</td><td>-0.00231272</td><td>0.00167019</td><td>0.000371735</td><td>0.00286897</td><td>0.00138584</td><td>-0.00791148</td><td>1</td><td>4</td><td>7</td><td>1</td><td>2</td></tr><tr><th>24</th><td>0.00374211</td><td>-0.00471201</td><td>0.00114468</td><td>-0.000121571</td><td>-0.00188833</td><td>0.00136371</td><td>0.000303521</td><td>0.0023425</td><td>0.00113153</td><td>-0.0064597</td><td>1</td><td>4</td><td>8</td><td>1</td><td>2</td></tr><tr><th>25</th><td>0.0135714</td><td>0.000587793</td><td>-0.000455864</td><td>-0.00280042</td><td>-0.0118951</td><td>0.00429923</td><td>0.00695602</td><td>0.0092625</td><td>1.17823e-5</td><td>-0.00800769</td><td>1</td><td>5</td><td>1</td><td>1</td><td>0</td></tr><tr><th>26</th><td>0.00842064</td><td>0.000716397</td><td>-0.0014722</td><td>-0.000474771</td><td>-0.00648818</td><td>0.0022389</td><td>0.00288162</td><td>0.00454633</td><td>0.000500505</td><td>-0.00369009</td><td>1</td><td>5</td><td>2</td><td>1</td><td>1</td></tr><tr><th>27</th><td>0.00979808</td><td>-1.31039e-5</td><td>0.000951256</td><td>-0.00217937</td><td>-0.00753009</td><td>0.00323605</td><td>0.00495289</td><td>0.0065921</td><td>-0.00113181</td><td>-0.00520114</td><td>1</td><td>5</td><td>3</td><td>1</td><td>1</td></tr><tr><th>28</th><td>0.00692593</td><td>-0.000400284</td><td>0.000315747</td><td>0.000504271</td><td>-0.00339045</td><td>0.00123777</td><td>0.00237482</td><td>0.00346345</td><td>-0.000968839</td><td>-0.00436599</td><td>1</td><td>5</td><td>4</td><td>1</td><td>1</td></tr><tr><th>29</th><td>0.00958598</td><td>-0.00149535</td><td>-0.0023221</td><td>-0.00599477</td><td>-0.00405036</td><td>-0.0
0198833</td><td>-0.000829992</td><td>0.0112333</td><td>0.0046394</td><td>-0.0148051</td><td>1</td><td>5</td><td>5</td><td>1</td><td>0</td></tr><tr><th>30</th><td>0.0060058</td><td>0.00156492</td><td>-0.00212839</td><td>-0.00666747</td><td>-0.00397689</td><td>-4.11231e-5</td><td>0.00101687</td><td>0.00900213</td><td>0.00315526</td><td>-0.0106931</td><td>1</td><td>5</td><td>6</td><td>1</td><td>1</td></tr><tr><th>⋮</th><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td><td>⋮</td></tr></tbody></table>"
],
"text/latex": [
"\\begin{tabular}{r|ccccccccccccccc}\n",
"\t& grad1 & grad2 & grad3 & grad4 & grad5 & grad6 & grad7 & grad8 & grad9 & grad10 & u & v & n & dist\\_uv & dist\\_uorv\\\\\n",
"\t\\hline\n",
"\t& Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Float32 & Int64 & Int64 & Int64 & Int64 & Int64\\\\\n",
"\t\\hline\n",
"\t1 & 0.0645079 & -0.021225 & 0.00150802 & 0.00318302 & -0.0319299 & 0.00660334 & 0.00870657 & 0.0306246 & -0.00391502 & -0.0416326 & 1 & 2 & 1 & 1 & 0 \\\\\n",
"\t2 & 0.0653441 & -0.0255881 & -0.0114106 & 0.0108165 & -0.0316851 & -0.00136565 & -0.00199008 & 0.0298564 & 0.00248509 & -0.0377283 & 1 & 2 & 2 & 1 & 0 \\\\\n",
"\t3 & 0.036182 & -0.0138953 & 0.00824023 & -0.00308798 & -0.0130371 & 0.0059515 & 0.00703448 & 0.0184967 & -0.00519035 & -0.0281207 & 1 & 2 & 3 & 1 & 1 \\\\\n",
"\t4 & 0.0292288 & -0.0099957 & 0.00241025 & 0.00671407 & -0.00923298 & 0.00159172 & 0.00227589 & 0.0133913 & -0.00464004 & -0.0223713 & 1 & 2 & 4 & 1 & 1 \\\\\n",
"\t5 & 0.0139868 & -0.00648846 & -0.00108487 & 0.00196581 & -0.00564524 & 0.00110447 & 0.000616112 & 0.00652828 & -0.00118362 & -0.0109085 & 1 & 2 & 5 & 1 & 1 \\\\\n",
"\t6 & 0.00515707 & -0.00153295 & -9.54022e-5 & -0.00106444 & -0.00479411 & 0.00258529 & 0.00279811 & 0.00244054 & -0.000862602 & -0.00403407 & 1 & 2 & 6 & 1 & 2 \\\\\n",
"\t7 & 0.00515707 & -0.00153295 & -9.54022e-5 & -0.00106444 & -0.00479411 & 0.00258529 & 0.00279811 & 0.00244054 & -0.000862602 & -0.00403407 & 1 & 2 & 7 & 1 & 2 \\\\\n",
"\t8 & 0.00421073 & -0.00125165 & -7.78955e-5 & -0.000869113 & -0.00391437 & 0.00211088 & 0.00228465 & 0.00199269 & -0.000704312 & -0.0032938 & 1 & 2 & 8 & 1 & 2 \\\\\n",
"\t9 & 0.0807522 & -0.0227874 & 0.0254424 & -0.0167195 & -0.0220161 & 0.0101796 & 0.0140722 & 0.0501589 & -0.00739399 & -0.0722438 & 1 & 3 & 1 & 1 & 0 \\\\\n",
"\t10 & 0.0340761 & -0.0103437 & -0.00257617 & 0.00722817 & -0.0177132 & 0.00453164 & 0.00870965 & 0.00973111 & -0.00574675 & -0.0162742 & 1 & 3 & 2 & 1 & 1 \\\\\n",
"\t11 & 0.0940134 & -0.0318718 & 0.0344218 & -0.0327469 & -0.00542446 & 0.00659211 & 0.00370851 & 0.0694634 & -0.00152773 & -0.0996689 & 1 & 3 & 3 & 1 & 0 \\\\\n",
"\t12 & 0.0324342 & -0.00973636 & 0.00403564 & 0.00701846 & -0.0169275 & 0.00535893 & 0.00938935 & 0.0100804 & -0.00863605 & -0.0212517 & 1 & 3 & 4 & 1 & 1 \\\\\n",
"\t13 & 0.0163616 & -0.00657916 & -0.0015163 & 0.00433485 & -0.00743952 & 0.00136546 & 0.003353 & 0.00500078 & -0.00370737 & -0.00774432 & 1 & 3 & 5 & 1 & 1 \\\\\n",
"\t14 & 0.00629887 & -0.00162777 & -0.000149417 & 0.00192418 & -0.0060904 & 0.00131024 & 0.00368191 & 0.001949 & -0.00250137 & -0.000507388 & 1 & 3 & 6 & 1 & 2 \\\\\n",
"\t15 & 0.00629887 & -0.00162777 & -0.000149417 & 0.00192418 & -0.0060904 & 0.00131024 & 0.00368191 & 0.001949 & -0.00250137 & -0.000507388 & 1 & 3 & 7 & 1 & 2 \\\\\n",
"\t16 & 0.00514301 & -0.00132907 & -0.000121998 & 0.00157108 & -0.00497279 & 0.0010698 & 0.00300627 & 0.00159136 & -0.00204236 & -0.000414281 & 1 & 3 & 8 & 1 & 2 \\\\\n",
"\t17 & 0.0617292 & -0.0381997 & 0.0128525 & 0.015317 & -0.00830853 & 7.7243e-5 & -0.0030023 & 0.0352501 & 0.0032523 & -0.0626972 & 1 & 4 & 1 & 1 & 0 \\\\\n",
"\t18 & 0.0315697 & -0.0214036 & 0.000929661 & 0.0126375 & -0.0146739 & 0.00134248 & 0.00210057 & 0.0117977 & -0.00231886 & -0.0223714 & 1 & 4 & 2 & 1 & 1 \\\\\n",
"\t19 & 0.0374139 & -0.0274874 & 0.014762 & 0.000410167 & -0.00619973 & 0.00214736 & -1.34762e-6 & 0.0223506 & 3.35925e-5 & -0.0405679 & 1 & 4 & 3 & 1 & 1 \\\\\n",
"\t20 & 0.0606246 & -0.0319523 & 0.00271798 & 0.0253772 & -0.00479889 & -0.00373943 & -0.0057935 & 0.0348819 & 0.00547458 & -0.0587476 & 1 & 4 & 4 & 1 & 0 \\\\\n",
"\t21 & 0.0144691 & -0.0119399 & 0.000725399 & 0.00435129 & -0.00615814 & 0.00146693 & 0.000685575 & 0.00692591 & 0.000697119 & -0.0146377 & 1 & 4 & 5 & 1 & 1 \\\\\n",
"\t22 & 0.00458313 & -0.005771 & 0.00140194 & -0.000148894 & -0.00231272 & 0.00167019 & 0.000371735 & 0.00286897 & 0.00138584 & -0.00791148 & 1 & 4 & 6 & 1 & 2 \\\\\n",
"\t23 & 0.00458313 & -0.005771 & 0.00140194 & -0.000148894 & -0.00231272 & 0.00167019 & 0.000371735 & 0.00286897 & 0.00138584 & -0.00791148 & 1 & 4 & 7 & 1 & 2 \\\\\n",
"\t24 & 0.00374211 & -0.00471201 & 0.00114468 & -0.000121571 & -0.00188833 & 0.00136371 & 0.000303521 & 0.0023425 & 0.00113153 & -0.0064597 & 1 & 4 & 8 & 1 & 2 \\\\\n",
"\t25 & 0.0135714 & 0.000587793 & -0.000455864 & -0.00280042 & -0.0118951 & 0.00429923 & 0.00695602 & 0.0092625 & 1.17823e-5 & -0.00800769 & 1 & 5 & 1 & 1 & 0 \\\\\n",
"\t26 & 0.00842064 & 0.000716397 & -0.0014722 & -0.000474771 & -0.00648818 & 0.0022389 & 0.00288162 & 0.00454633 & 0.000500505 & -0.00369009 & 1 & 5 & 2 & 1 & 1 \\\\\n",
"\t27 & 0.00979808 & -1.31039e-5 & 0.000951256 & -0.00217937 & -0.00753009 & 0.00323605 & 0.00495289 & 0.0065921 & -0.00113181 & -0.00520114 & 1 & 5 & 3 & 1 & 1 \\\\\n",
"\t28 & 0.00692593 & -0.000400284 & 0.000315747 & 0.000504271 & -0.00339045 & 0.00123777 & 0.00237482 & 0.00346345 & -0.000968839 & -0.00436599 & 1 & 5 & 4 & 1 & 1 \\\\\n",
"\t29 & 0.00958598 & -0.00149535 & -0.0023221 & -0.00599477 & -0.00405036 & -0.00198833 & -0.000829992 & 0.0112333 & 0.0046394 & -0.0148051 & 1 & 5 & 5 & 1 & 0 \\\\\n",
"\t30 & 0.0060058 & 0.00156492 & -0.00212839 & -0.00666747 & -0.00397689 & -4.11231e-5 & 0.00101687 & 0.00900213 & 0.00315526 & -0.0106931 & 1 & 5 & 6 & 1 & 1 \\\\\n",
"\t$\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ \\\\\n",
"\\end{tabular}\n"
],
"text/plain": [
"276×15 DataFrame\n",
"│ Row │ grad1 │ grad2 │ grad3 │ grad4 │ grad5 │ grad6 │ grad7 │ grad8 │ grad9 │ grad10 │ u │ v │ n │ dist_uv │ dist_uorv │\n",
"│ │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mFloat32\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │\n",
"├─────┼────────────┼─────────────┼─────────────┼──────────────┼─────────────┼─────────────┼─────────────┼─────────────┼──────────────┼─────────────┼───────┼───────┼───────┼─────────┼───────────┤\n",
"│ 1 │ 0.0645079 │ -0.021225 │ 0.00150802 │ 0.00318302 │ -0.0319299 │ 0.00660334 │ 0.00870657 │ 0.0306246 │ -0.00391502 │ -0.0416326 │ 1 │ 2 │ 1 │ 1 │ 0 │\n",
"│ 2 │ 0.0653441 │ -0.0255881 │ -0.0114106 │ 0.0108165 │ -0.0316851 │ -0.00136565 │ -0.00199008 │ 0.0298564 │ 0.00248509 │ -0.0377283 │ 1 │ 2 │ 2 │ 1 │ 0 │\n",
"│ 3 │ 0.036182 │ -0.0138953 │ 0.00824023 │ -0.00308798 │ -0.0130371 │ 0.0059515 │ 0.00703448 │ 0.0184967 │ -0.00519035 │ -0.0281207 │ 1 │ 2 │ 3 │ 1 │ 1 │\n",
"│ 4 │ 0.0292288 │ -0.0099957 │ 0.00241025 │ 0.00671407 │ -0.00923298 │ 0.00159172 │ 0.00227589 │ 0.0133913 │ -0.00464004 │ -0.0223713 │ 1 │ 2 │ 4 │ 1 │ 1 │\n",
"│ 5 │ 0.0139868 │ -0.00648846 │ -0.00108487 │ 0.00196581 │ -0.00564524 │ 0.00110447 │ 0.000616112 │ 0.00652828 │ -0.00118362 │ -0.0109085 │ 1 │ 2 │ 5 │ 1 │ 1 │\n",
"│ 6 │ 0.00515707 │ -0.00153295 │ -9.54022e-5 │ -0.00106444 │ -0.00479411 │ 0.00258529 │ 0.00279811 │ 0.00244054 │ -0.000862602 │ -0.00403407 │ 1 │ 2 │ 6 │ 1 │ 2 │\n",
"│ 7 │ 0.00515707 │ -0.00153295 │ -9.54022e-5 │ -0.00106444 │ -0.00479411 │ 0.00258529 │ 0.00279811 │ 0.00244054 │ -0.000862602 │ -0.00403407 │ 1 │ 2 │ 7 │ 1 │ 2 │\n",
"│ 8 │ 0.00421073 │ -0.00125165 │ -7.78955e-5 │ -0.000869113 │ -0.00391437 │ 0.00211088 │ 0.00228465 │ 0.00199269 │ -0.000704312 │ -0.0032938 │ 1 │ 2 │ 8 │ 1 │ 2 │\n",
"│ 9 │ 0.0807522 │ -0.0227874 │ 0.0254424 │ -0.0167195 │ -0.0220161 │ 0.0101796 │ 0.0140722 │ 0.0501589 │ -0.00739399 │ -0.0722438 │ 1 │ 3 │ 1 │ 1 │ 0 │\n",
"│ 10 │ 0.0340761 │ -0.0103437 │ -0.00257617 │ 0.00722817 │ -0.0177132 │ 0.00453164 │ 0.00870965 │ 0.00973111 │ -0.00574675 │ -0.0162742 │ 1 │ 3 │ 2 │ 1 │ 1 │\n",
"⋮\n",
"│ 266 │ 0.0412439 │ 0.0124566 │ -0.0120742 │ 0.00936623 │ -0.0590286 │ 0.012509 │ 0.0284646 │ 0.0304896 │ -0.00870682 │ 0.00460341 │ 7 │ 9 │ 5 │ 3 │ 1 │\n",
"│ 267 │ 0.00561691 │ 0.00853805 │ -0.00792131 │ 0.00419491 │ -0.0152246 │ -0.00259941 │ 0.00539298 │ 0.00800614 │ -0.003767 │ 0.0103741 │ 7 │ 9 │ 1 │ 3 │ 2 │\n",
"│ 268 │ 0.00888112 │ 0.0134998 │ -0.0125247 │ 0.00663273 │ -0.0240722 │ -0.00411003 │ 0.00852706 │ 0.0126588 │ -0.00595615 │ 0.0164029 │ 7 │ 9 │ 6 │ 3 │ 2 │\n",
"│ 269 │ 0.10261 │ -0.0519794 │ 0.0833485 │ -0.0359104 │ 0.0279448 │ 0.00737936 │ -0.00128965 │ -0.00660026 │ 0.0381341 │ -0.14625 │ 7 │ 9 │ 8 │ 3 │ 1 │\n",
"│ 270 │ 0.11679 │ -0.0771613 │ 0.114605 │ -0.0506138 │ 0.0582975 │ 0.0131479 │ -0.0101065 │ -0.0207424 │ 0.0526607 │ -0.195522 │ 7 │ 9 │ 9 │ 3 │ 0 │\n",
"│ 271 │ 0.0911527 │ -0.0259464 │ 0.0306541 │ -0.0086948 │ -0.140151 │ 0.0501676 │ 0.0964996 │ 0.026796 │ 0.015733 │ -0.0377764 │ 8 │ 9 │ 8 │ 1 │ 0 │\n",
"│ 272 │ 0.0248515 │ -0.00836338 │ 0.00185364 │ -0.00484979 │ -0.054738 │ 0.0192284 │ 0.0367432 │ 0.00956858 │ 0.00852271 │ -0.00498861 │ 8 │ 9 │ 5 │ 1 │ 1 │\n",
"│ 273 │ 0.105916 │ -0.0374046 │ 0.0401043 │ -0.0102666 │ -0.154377 │ 0.0593004 │ 0.109111 │ 0.0258074 │ 0.018702 │ -0.0492926 │ 8 │ 9 │ 9 │ 1 │ 0 │\n",
"│ 274 │ 0.0036196 │ 0.00355871 │ -0.00161957 │ -0.000241777 │ -0.0109235 │ 0.00135475 │ 0.00574067 │ 0.0044341 │ 0.000358557 │ 0.00191391 │ 8 │ 9 │ 1 │ 1 │ 2 │\n",
"│ 275 │ 0.0057231 │ 0.00562682 │ -0.00256077 │ -0.000382283 │ -0.0172716 │ 0.00214204 │ 0.00907679 │ 0.00701092 │ 0.000566929 │ 0.00302616 │ 8 │ 9 │ 6 │ 1 │ 2 │\n",
"│ 276 │ 0.0057231 │ 0.00562682 │ -0.00256077 │ -0.000382283 │ -0.0172716 │ 0.00214204 │ 0.00907679 │ 0.00701092 │ 0.000566929 │ 0.00302616 │ 8 │ 9 │ 7 │ 1 │ 2 │"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Widen the text display so the wide DataFrame below shows more columns before truncating.\n",
"# NOTE: this mutates ENV for the rest of the kernel session, not just this cell.\n",
"ENV[\"COLUMNS\"] = 300;\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**TODO**\n",
"\n",
"Plot facets of the data in `df` above."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### w.r.t. embeddings"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.78653204 … 0.73746455 0.57887185; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.76371896] (tracked), getfield(Tracker, Symbol(\"##21#23\")){getfield(Tracker, Symbol(\"##18#19\")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}}(Core.Box(([1.1418794037884052 0.8853957612198733 … 0.4948227551326396 -0.021783616590958732; 0.9882851674121734 0.7086697417288617 … 0.2995249605165839 -0.14853460307849584; … ; 1.6789900392252963 0.8560976209638216 … -0.26693793304649993 -0.6054415825060152; -0.191996930972175 -0.3271772580544453 … -0.5457029432684755 -0.8966404562162545] (tracked),)), getfield(Tracker, Symbol(\"##18#19\")){Tracker.Params,TrackedArray{…,Array{Float32,2}}}(Params([[1.1418794037884052 0.8853957612198733 … 0.4948227551326396 -0.021783616590958732; 0.9882851674121734 0.7086697417288617 … 0.2995249605165839 -0.14853460307849584; … ; 1.6789900392252963 0.8560976209638216 … -0.26693793304649993 -0.6054415825060152; -0.191996930972175 -0.3271772580544453 … -0.5457029432684755 -0.8966404562162545] (tracked)]), Float32[0.97156286 0.88115096 … 0.8797154 0.6768692; 0.9004984 0.78653204 … 0.73746455 0.57887185; … ; 0.9144334 0.7934744 … 0.9269219 0.8289111; 0.7496242 0.6469048 … 0.8417549 0.76371896] (tracked))))"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Forward pass from the first encoder output (the embedding mean, judging by the\n",
"# `embμ, emblogσ = enc(features)` unpacking further down) to sigmoid edge probabilities.\n",
"# `back_emb` maps an output sensitivity back to a gradient w.r.t. that embedding input.\n",
"Apred_emb, back_emb = Tracker.forward(enc(features)[1]) do emb\n",
"    σ.(dec(emb)[1])\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tracked 24×9 Array{Float64,2}:\n",
" 0.0506319 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.10493 \n",
" -0.0837624 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.234714 \n",
" 0.0126053 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.028999 \n",
" -0.0347072 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.223553 \n",
" 0.111203 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.0901013 \n",
" 0.0937184 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.0224072 \n",
" -0.024323 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.119372 \n",
" 0.0506258 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.102253 \n",
" -0.00647627 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.0671987 \n",
" 0.0500753 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.00589014\n",
" -0.00719932 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.0790304 \n",
" 0.00181941 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -0.0755756 \n",
" -0.0445624 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0485049 \n",
" 0.00678587 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.15691 \n",
" 0.0739295 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0631873 \n",
" -0.0275066 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0550068 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gradient of the predicted probability Apred_emb[1, 9] w.r.t. the embedding input.\n",
"# The sensitivity (presumably zeros with a 1 at the given index; `onehot` is defined in the\n",
"# project code) takes its shape from Apred_emb instead of a hard-coded (9, 9), so the cell\n",
"# keeps working when the graph size changes. `[1]` picks the gradient for the single\n",
"# argument that was passed to Tracker.forward.\n",
"back_emb(onehot(CartesianIndex(1, 9), size(Apred_emb)))[1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**TODO**\n",
"\n",
"Precompute all pairwise distances between nodes.\n",
"\n",
"Then:\n",
"- for each pair of nodes u, v (linked or not):\n",
"  - get the gradients w.r.t. the embeddings\n",
"  - for both u and v, store the gradient in a DataFrame along with the distance between u and v\n",
"\n",
"Then look at the average gradient, faceted by the distance between u and v."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving the plots"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Encoder outputs (presumably embedding mean and log σ) plus the full VAE outputs:\n",
"# adjacency logits (passed through σ below) and unnormalised feature scores (softmaxed below).\n",
"(embμ, emblogσ), (logitÂ, unormF̂) = enc(features), vae(features);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Adjacency reconstruction: predicted edge probabilities σ(logitÂ), untracked, drawn on a\n",
"# fixed [0, 1] colour scale onto a hard-coded 15000×15000 px canvas.\n",
"scene = Scene(resolution = (15000, 15000))\n",
"heatmap!(scene, Tracker.data(σ.(logitÂ)); colorrange = (0, 1))\n",
"Makie.save(\"Apred.png\", scene);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Adjacency reference: a dense copy of the graph's adjacency matrix, drawn on the same\n",
"# [0, 1] colour scale and canvas size as the reconstruction so the two images compare directly.\n",
"scene = Scene(resolution = (15000, 15000))\n",
"heatmap!(scene, Matrix(adjacency_matrix(g)); colorrange = (0, 1))\n",
"Makie.save(\"Aref.png\", scene);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Feature reconstruction: softmax of the unnormalised feature scores, untracked, laid out on\n",
"# the row/column grid of the reference `_features` and a [0, 1] colour scale.\n",
"scene = Scene(resolution = (150, 15000))\n",
"heatmap!(scene, 1:size(_features, 1), 1:size(_features, 2), Tracker.data(softmax(unormF̂)); colorrange = (0, 1))\n",
"Makie.save(\"Fpred.png\", scene);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Feature reference: the input features on the same grid, colour scale and canvas size as\n",
"# the reconstruction. Uses the mutating `heatmap!` so the plot lands in `scene`; a plain\n",
"# `heatmap(...)` builds a separate scene and the saved `scene` would be empty.\n",
"scene = Scene(resolution = (150, 15000))\n",
"heatmap!(scene, 1:size(_features, 1), 1:size(_features, 2), _features, colorrange = (0, 1))\n",
"Makie.save(\"Fref.png\", scene);"
]
}
],
"metadata": {
"@webio": {
"lastCommId": "5ba504aae1ef4fd98baab9111205e028",
"lastKernelId": "a7b6f62a-f991-4853-8a1a-62f6b154f88e"
},
"kernelspec": {
"display_name": "Julia 1.2.0",
"language": "julia",
"name": "julia-1.2"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}