julia/layers.jl
module Layers
include("utils.jl")
using Flux, LightGraphs, LinearAlgebra, Memoize, .Utils
#
# Helpers
#
"""
    nobias(out::Integer)

Return a length-`out` vector filled with `nothing`, used as a placeholder bias
for Dense and GC layers that should have no bias term.
"""
function nobias(out::Integer)
    placeholders = fill(nothing, (out,))
    return placeholders
end
# Turn an array of `nothing` placeholders (such as the result of `nobias`) into a
# plain, untracked Float32 zero array of the same size instead of a `param`.
# NOTE(review): this is type piracy — neither `Flux.param` nor
# `AbstractArray{Nothing}` is owned by this module — so it affects every caller
# of `Flux.param` in the session.
function Flux.param(n::AbstractArray{Nothing})
    return zeros(Float32, size(n))
end
#
# Helper layers
#
"""
    VOverlap(overlap, reducer)

Layer that stacks two inputs along the first dimension, merging the
`overlap` rows they share with `reducer`.
"""
struct VOverlap{I<:Integer,F}
    overlap::I  # number of rows shared by the two inputs
    reducer::F  # called as reducer(x1_tail, x2_head) on the shared rows
end
Flux.@treelike VOverlap

"""
    VOverlap(overlap)

Construct a `VOverlap` whose shared rows are merged with `mean`.
NOTE(review): `mean` is called with two arrays; presumably a two-argument
method is provided by `Utils` or Flux — confirm.
"""
function VOverlap(overlap)
    return VOverlap(overlap, mean)
end
# Print as `VOverlap(overlap)`, appending the reducer only when it is not the
# default `mean`.
function Base.show(io::IO, o::VOverlap)
    print(io, "VOverlap(", o.overlap)
    if o.reducer != mean
        print(io, ", ", o.reducer)
    end
    print(io, ")")
    return nothing
end
# Stack `x1` on top of `x2` along the first dimension. The last `o.overlap` rows
# of `x1` and the first `o.overlap` rows of `x2` are replaced by a single block,
# `o.reducer(x1_tail, x2_head)`. All slices are views (`@views`), so no copies
# are made before `vcat`.
# NOTE(review): the output has size(x1,1) + size(x2,1) - o.overlap rows only if
# the reducer returns an `o.overlap`-row block — confirm for custom reducers.
@views function (o::VOverlap)(x1, x2)
    k = o.overlap
    head = x1[1:end-k, :]
    merged = o.reducer(x1[end-k+1:end, :], x2[1:k, :])
    tail = x2[1+k:end, :]
    return vcat(head, merged, tail)
end
"""
    Apply(f, args...)

Layer holding a combining function `f` and a tuple of layers `args`. Calling
it on `x` evaluates every layer in `args` on `x` and passes the results, in
order, to `f`.
"""
struct Apply{V,T<:NTuple}
    f::V     # combiner; receives the layer outputs as positional arguments
    args::T  # tuple of layers, each applied to the same input
    function Apply(f, args...)
        return new{typeof(f),typeof(args)}(f, args)
    end
end
# Tree traversal for `Apply`, used by Flux's `params`/`mapleaves` machinery.
# `Flux.@treelike Apply` is deliberately not used: it would expose `args` as a
# single child and rebuild the layer as `Apply(f(a.f), f(a.args))`, which the
# varargs constructor wraps in another tuple and so breaks the layer call.
# Instead, every layer in `args` is its own child and is mapped individually.
# These must extend Flux's functions by qualified name — unqualified
# definitions would create unrelated functions local to this module.
Flux.children(a::Apply) = (a.f, a.args...)
Flux.mapchildren(f, a::Apply) = Apply(f(a.f), f.(a.args)...)
# Print as `Apply(f, arg1, arg2, ...)`.
function Base.show(io::IO, a::Apply)
    print(io, "Apply(", a.f, ", ")
    for (i, arg) in enumerate(a.args)
        i > 1 && print(io, ", ")
        print(io, arg)
    end
    print(io, ")")
    return nothing
end
# Feed `x` to every layer in `a.args` and pass the resulting outputs, in order,
# as positional arguments to the combiner `a.f`.
function (a::Apply)(x)
    outputs = map(layer -> layer(x), a.args)
    return a.f(outputs...)
end
#
# Graph-convolutional layer
#
# Normalize `adjacency_matrix_diag(g)` entrywise as
#   A[i, j] / sqrt(rowsum[i]) / sqrt(colsum[j])
# (i.e. D_out^(-1/2) * A * D_in^(-1/2), with D_out/D_in the diagonal matrices of
# row/column sums) and return the result as an Array{Float32}.
# Results are memoized on `g` by `@memoize`.
# NOTE(review): `adjacency_matrix_diag` is defined outside this file (presumably
# in Utils) and presumably adds self-loops so no row/column sum is zero; a zero
# sum would yield NaN/Inf entries — confirm.
@memoize function adjacency_matrix_diag_norm(g::SimpleGraph)
    A = adjacency_matrix_diag(g)
    colsums = sum(A, dims = 1)  # 1×n row of column sums
    rowsums = sum(A, dims = 2)  # n×1 column of row sums
    # Same left-to-right division order as `@. (A / sqrt(r)) / sqrt(c)`.
    normalized = A ./ sqrt.(rowsums) ./ sqrt.(colsums)
    return Array{Float32}(normalized)
end
"""
    GC(g::SimpleGraph, W, b, σ)
    GC(Anorm::AbstractArray, W, b, σ)

Graph-convolutional layer; called on `x` it returns `σ.(W * x * Anorm .+ b)`.

The first form stores `adjacency_matrix_diag_norm(g)` as `Anorm`. The second
stores an already-normalized matrix unchanged; it is what `Flux.@treelike GC`
needs, because Flux's `mapchildren` rebuilds the layer as
`GC(f(Anorm), f(W), f(b), f(σ))`.
"""
struct GC{S<:AbstractArray,T,U,F}
    Anorm::S  # normalized adjacency matrix (Array{Float32} when built from a graph)
    W::T      # weight matrix
    b::U      # bias added after the graph propagation
    σ::F      # activation applied elementwise
    function GC(g::SimpleGraph, W::T, b::U, σ::F) where {T,U,F}
        Anorm = adjacency_matrix_diag_norm(g)
        return new{typeof(Anorm),T,U,F}(Anorm, W, b, σ)
    end
    # Rebuild from a pre-normalized matrix. Without this method, mapleaves-based
    # transforms of a GC layer (e.g. Flux's cpu/gpu) throw a MethodError, since
    # the graph constructor above is otherwise the only one available.
    function GC(Anorm::S, W::T, b::U, σ::F) where {S<:AbstractArray,T,U,F}
        return new{S,T,U,F}(Anorm, W, b, σ)
    end
end
"""
    GC(g, W, b)

Construct a `GC` layer with the `identity` activation.
"""
GC(g, W, b) = GC(g, W, b, identity)
Flux.@treelike GC
"""
    GC(g::SimpleGraph, in::Integer, out::Integer, σ = identity;
       initW = Flux.glorot_uniform, initb = n -> zeros(Float32, n))

Construct a `GC` layer mapping `in` features to `out` features on graph `g`.
`W` is `param(initW(out, in))` and `b` is `param(initb(out))`. Pass
`initb = nobias` for a layer without a trainable bias.

The default bias is Float32: a bare `zeros` here resolves to `Base.zeros`,
which gives a Float64 bias. That bias would silently promote the otherwise
Float32 layer output (Float32 `Anorm` and `W`) to Float64.
"""
function GC(g::SimpleGraph, in::Integer, out::Integer, σ = identity;
            initW = Flux.glorot_uniform, initb = n -> zeros(Float32, n))
    return GC(g, param(initW(out, in)), param(initb(out)), σ)
end
# Print as `GC(g ~ n, W ~ (in, out), b ~ len[, σ])`. The bias is shown as
# "nothing" whenever it is not a TrackedArray (e.g. a `nobias` placeholder).
function Base.show(io::IO, l::GC)
    nnodes = size(l.Anorm, 1)
    wdims = (size(l.W, 2), size(l.W, 1))
    print(io, "GC(g ~ ", nnodes, ", W ~ ", wdims, ", b ~ ")
    if l.b isa TrackedArray
        print(io, size(l.b, 1))
    else
        print(io, "nothing")
    end
    if l.σ != identity
        print(io, ", ", l.σ)
    end
    print(io, ")")
    return nothing
end
# Graph convolution: σ.(W * x * Anorm .+ b). For a matrix `x`, size(x, 1) must
# equal size(W, 2) and size(x, 2) must equal size(Anorm, 1).
function (a::GC)(x::AbstractArray)
    # Keep the product as a single three-operand `*`: recent Julia versions may
    # pick the association order themselves, so re-parenthesizing could change
    # rounding.
    propagated = a.W * x * a.Anorm
    return a.σ.(propagated .+ a.b)
end
#
# Bilinear layer
#
"""
    Bilin(W, σ)

Bilinear layer computing `σ.(transpose(x) * W * x)` for an input `x`.
"""
struct Bilin{T,F}
    W::T  # weight matrix, or `I` for the identity
    σ::F  # activation applied elementwise to the result
end
Flux.@treelike Bilin

"""
    Bilin(σ = identity)

Bilinear layer with the identity `I` as weight, i.e. it computes
`σ.(transpose(x) * I * x)` and has no weight of its own.
"""
function Bilin(σ = identity)
    return Bilin(I, σ)
end

"""
    Bilin(in::Integer, σ = identity; initW = Flux.glorot_uniform)

Bilinear layer with an `in × in` weight `param(initW(in, in))`.
"""
function Bilin(in::Integer, σ = identity; initW = Flux.glorot_uniform)
    weight = param(initW(in, in))
    return Bilin(weight, σ)
end
# Bilinear form σ.(transpose(x) * W * x). A vector `x` gives a scalar; a matrix
# `x` gives a size(x, 2) × size(x, 2) result.
function (a::Bilin)(x::AbstractArray)
    # Single three-operand `*`, kept intact so the multiplication order is
    # unchanged.
    form = transpose(x) * a.W * x
    return a.σ.(form)
end
end