// session_manager.go
package nimona
import (
"context"
"fmt"
"time"
"github.com/hashicorp/golang-lru/v2/simplelru"
"nimona.io/internal/xsync"
)
// SessionManager manages the dialing and accepting of connections.
// It maintains a cache of the last 100 connections.
type SessionManager struct {
// connCache holds established sessions keyed by the peer's public key; it
// is created with room for 100 entries and closes sessions it evicts.
connCache *simplelru.LRU[connCacheKey, *Session]
// dialer opens outbound connections for Dial.
dialer Dialer
// listener is the source of inbound connections; it may be nil, in which
// case NewSessionManager does not start an accept loop.
listener Listener
// handlers maps a request type to the function that handles it.
handlers map[string]RequestHandlerFunc
// publicKey and privateKey are handed to the session handshake.
publicKey PublicKey
privateKey PrivateKey
// resolver resolves identity aliases that are not cached yet.
resolver Resolver
// aliases caches identity info by alias.
aliases xsync.Map[IdentityAlias, *IdentityInfo]
// providers holds the identities that LookupIdentity queries when an
// identity is not cached. NOTE(review): nothing in this file stores into
// it (the writes are commented out) — confirm it is populated elsewhere.
providers xsync.Map[IdentityAlias, *IdentityInfo]
// identities caches identity info by keygraph ID.
identities xsync.Map[KeygraphID, *IdentityInfo]
}
// RequestHandlerFunc handles a single request read from a session. When it
// returns an error the session loop reports it and moves on to the next
// request.
type RequestHandlerFunc func(context.Context, *Request) error
// connCacheKey is the key type of the session cache.
type connCacheKey struct {
// publicKeyInHex is the peer's public key formatted with the %x verb.
publicKeyInHex string
}
// NewSessionManager creates a SessionManager that dials through the given
// dialer and performs handshakes with the given key pair.
//
// Sessions are kept in an LRU cache of 100 entries, and a session that leaves
// the cache is closed. When listener is non-nil a background goroutine is
// started to accept inbound connections. The resolver is set to a
// ResolverHTTP.
func NewSessionManager(
	dialer Dialer,
	listener Listener,
	publicKey PublicKey,
	privateKey PrivateKey,
) (*SessionManager, error) {
	// closeEvicted runs for every session that is dropped from the cache.
	closeEvicted := func(_ connCacheKey, evicted *Session) {
		if closeErr := evicted.Close(); closeErr != nil {
			// TODO: log error
			fmt.Println("error closing connection on eviction:", closeErr)
		}
	}

	cache, err := simplelru.NewLRU[connCacheKey, *Session](100, closeEvicted)
	if err != nil {
		return nil, fmt.Errorf("error creating connection cache: %w", err)
	}

	mgr := &SessionManager{
		connCache:  cache,
		dialer:     dialer,
		listener:   listener,
		handlers:   map[string]RequestHandlerFunc{},
		publicKey:  publicKey,
		privateKey: privateKey,
		resolver:   &ResolverHTTP{},
	}

	// without a listener there is nothing to accept from
	if listener == nil {
		return mgr, nil
	}

	go func() {
		// nolint:errcheck // TODO: handle error
		mgr.handleConnections()
	}()

	return mgr, nil
}
// Dial returns a session with the peer at the given address.
//
// If a session is already cached under addr.PublicKey it is returned as is.
// Otherwise the address is dialed, the client side of the handshake is
// performed and, when addr carries a public key, the key presented by the
// remote peer is checked against it. A session that passes these steps is
// served in a background goroutine and added to the cache under the remote
// peer's public key.
//
// A session whose handshake or key check fails is closed before the error is
// returned, so that the underlying connection is not leaked.
func (cm *SessionManager) Dial(
	ctx context.Context,
	addr PeerAddr,
) (*Session, error) {
	// check the cache
	if cached, ok := cm.connCache.Get(cm.connCacheKey(addr.PublicKey)); ok {
		return cached, nil
	}

	// dial the address if it is not in the cache.
	conn, err := cm.dialer.Dial(ctx, addr)
	if err != nil {
		return nil, fmt.Errorf("error dialing %s: %w", addr, err)
	}

	// we initiated the connection, so we perform the client side of the
	// handshake; the accepting peer performs the server side
	ses := NewSession(conn, &addr)
	if err := ses.DoClient(cm.publicKey, cm.privateKey); err != nil {
		// nolint:errcheck // the handshake error is the one worth reporting
		ses.Close()
		return nil, fmt.Errorf("error performing handshake: %w", err)
	}

	// if we have been given a public key for the remote peer, check it
	if addr.PublicKey != nil && !ses.PublicKey().Equal(addr.PublicKey) {
		// nolint:errcheck // the mismatch is the error worth reporting
		ses.Close()
		return nil, fmt.Errorf("public key mismatch")
	}

	// start handling messages
	go cm.handleSession(ses)

	// add ses to cache
	cm.connCache.Add(cm.connCacheKey(ses.PublicKey()), ses)

	return ses, nil
}
// connCacheKey builds the session cache key for a peer by formatting its
// public key with the %x verb.
func (cm *SessionManager) connCacheKey(k PublicKey) connCacheKey {
	hexKey := fmt.Sprintf("%x", k)
	return connCacheKey{publicKeyInHex: hexKey}
}
// RequestRecipientFn selects the recipient of a request by filling in one of
// the fields of the requestRecipient it is given. Values are created with
// FromAlias, FromIdentity, or FromPeerAddr.
type RequestRecipientFn func(*requestRecipient)
// requestRecipient describes where a request should be sent. Request checks
// the fields in the order Alias, KeygraphID, PeerAddr and uses the first one
// that is set.
type requestRecipient struct {
// Alias, when non-nil, is resolved to an identity first.
Alias *IdentityAlias
// KeygraphID, when not empty, is resolved to the identity's peer addresses.
KeygraphID KeygraphID
// PeerAddr, when non-nil, is dialed directly.
PeerAddr *PeerAddr
}
// FromAlias returns a recipient option that addresses a request to the
// identity behind the given alias.
func FromAlias(alias IdentityAlias) RequestRecipientFn {
	target := &alias
	return func(rec *requestRecipient) {
		rec.Alias = target
	}
}
// FromIdentity returns a recipient option that addresses a request to the
// identity with the given keygraph ID.
func FromIdentity(id KeygraphID) RequestRecipientFn {
	apply := func(rec *requestRecipient) {
		rec.KeygraphID = id
	}
	return apply
}
// FromPeerAddr returns a recipient option that addresses a request directly
// to the given peer address.
func FromPeerAddr(peerAddr PeerAddr) RequestRecipientFn {
	target := &peerAddr
	return func(rec *requestRecipient) {
		rec.PeerAddr = target
	}
}
// Request sends req to the recipient selected by rfn and returns the
// response.
//
// The recipient is tried in the order alias, keygraph ID, peer address; the
// first one that rfn has set is used. When rfn is nil or sets none of them an
// error is returned.
func (cm *SessionManager) Request(
	ctx context.Context,
	req *Document,
	rfn RequestRecipientFn,
) (*Response, error) {
	rec := &requestRecipient{}
	// calling a nil option would panic; treat it as "nothing selected" so the
	// caller gets the regular error below instead
	if rfn != nil {
		rfn(rec)
	}

	switch {
	case rec.Alias != nil:
		return cm.requestFromAlias(ctx, *rec.Alias, req)
	case !rec.KeygraphID.IsEmpty():
		return cm.requestFromIdentity(ctx, rec.KeygraphID, req)
	case rec.PeerAddr != nil:
		return cm.requestFromPeerAddr(ctx, *rec.PeerAddr, req)
	default:
		return nil, fmt.Errorf("no recipient specified")
	}
}
// requestFromAlias looks up the identity behind alias and forwards the
// request to that identity.
func (cm *SessionManager) requestFromAlias(
	ctx context.Context,
	alias IdentityAlias,
	req *Document,
) (*Response, error) {
	// resolve the alias
	resolved, lookupErr := cm.LookupAlias(alias)
	if lookupErr == nil {
		return cm.requestFromIdentity(ctx, resolved.KeygraphID, req)
	}

	return nil, fmt.Errorf("error looking up alias %s: %w", alias, lookupErr)
}
// requestFromIdentity looks up the identity with the given keygraph ID and
// sends the request to the first peer address listed for it.
//
// An error is returned when the identity cannot be looked up or when it has
// no peer addresses.
func (cm *SessionManager) requestFromIdentity(
	ctx context.Context,
	id KeygraphID,
	req *Document,
) (*Response, error) {
	// resolve the identity
	info, err := cm.LookupIdentity(id)
	if err != nil {
		return nil, fmt.Errorf("error looking up identity %s: %w", id, err)
	}

	// indexing an empty address list would panic
	if len(info.PeerAddresses) == 0 {
		return nil, fmt.Errorf("identity %s has no peer addresses", id)
	}

	return cm.requestFromPeerAddr(ctx, info.PeerAddresses[0], req)
}
// requestFromPeerAddr obtains a session with the peer at addr, dialing it if
// necessary, and sends the request over that session.
func (cm *SessionManager) requestFromPeerAddr(
	ctx context.Context,
	addr PeerAddr,
	req *Document,
) (*Response, error) {
	session, dialErr := cm.Dial(ctx, addr)
	if dialErr == nil {
		return session.Request(ctx, req)
	}

	return nil, fmt.Errorf("error dialing %s: %w", addr, dialErr)
}
// handleConnections accepts inbound connections until the listener returns
// an error, which is then returned wrapped.
//
// For every accepted connection the server side of the handshake is
// performed. A session whose handshake fails is closed and skipped. If a
// session with the same remote public key is already cached it is removed
// first, which closes it through the cache's eviction callback, and the new
// session takes its place.
func (cm *SessionManager) handleConnections() error {
	for {
		conn, err := cm.listener.Accept()
		if err != nil {
			return fmt.Errorf("error accepting connection: %w", err)
		}

		// start a new session, and perform the server side of the handshake
		// this will also perform the key exchange so after this we should
		// know the public key of the remote peer
		ses := NewSession(conn, nil)
		if err := ses.DoServer(cm.publicKey, cm.privateKey); err != nil {
			// TODO: log error
			// the session is unusable, release it rather than leaking it
			// nolint:errcheck // nothing useful to do with a close error here
			ses.Close()
			continue
		}

		// drop any session already cached for this public key; removal
		// triggers the eviction callback which will close the old session.
		// Remove is a no-op when the key is not present.
		key := cm.connCacheKey(ses.PublicKey())
		cm.connCache.Remove(key)

		// start handling messages
		go cm.handleSession(ses)

		// add ses to cache
		cm.connCache.Add(key, ses)
	}
}
// handleSession reads requests from ses and dispatches each one to the
// handler registered for its type, until reading fails. On a read error the
// session is closed and the loop ends. Requests without a handler, and
// handler errors, are reported and skipped.
//
// NOTE(review): the session is closed here but not removed from the
// connection cache — confirm that a later Dial cannot pick up a closed
// session.
func (cm *SessionManager) handleSession(ses *Session) {
	for {
		req, readErr := ses.Read()
		if readErr != nil {
			// TODO log error
			fmt.Println("error reading message:", readErr)
			ses.Close() // TODO handle error
			return
		}

		// get the handler for the message type
		handle, found := cm.handlers[req.Type]
		if !found {
			// TODO log error
			fmt.Println("no handler for message type:", req.Type)
			continue
		}

		// handle the message
		if handleErr := handle(context.Background(), req); handleErr != nil {
			// TODO log error
			fmt.Println("error handling message:", handleErr)
		}
	}
}
// RegisterHandler sets the handler for requests of the given type, replacing
// any handler registered for that type before.
//
// NOTE(review): the handlers map is written here and read by the session
// loops without any locking visible in this file — confirm handlers are only
// registered before sessions are served.
func (cm *SessionManager) RegisterHandler(
msgType string,
handler RequestHandlerFunc,
) {
cm.handlers[msgType] = handler
}
// PeerAddr returns the address this manager can be reached at: the
// listener's transport and address combined with the manager's public key.
//
// A manager created without a listener has no transport or address to
// report; only the public key is set in that case.
func (cm *SessionManager) PeerAddr() PeerAddr {
	// NewSessionManager allows a nil listener, which must not be dereferenced
	if cm.listener == nil {
		return PeerAddr{
			PublicKey: cm.publicKey,
		}
	}

	local := cm.listener.PeerAddr()
	return PeerAddr{
		Transport: local.Transport,
		Address:   local.Address,
		PublicKey: cm.publicKey,
	}
}
// Close closes all connections in the connection cache. It always returns
// nil.
//
// NOTE(review): the listener is not touched here — confirm whether the
// accept loop is expected to be stopped elsewhere.
func (cm *SessionManager) Close() error {
// purging runs the eviction callback for every cached session, and that
// callback closes the session
cm.connCache.Purge()
return nil
}
// LookupAlias returns the identity info for the given alias. Cached entries
// are returned directly; otherwise the alias is resolved through the
// manager's resolver and the result is cached under both the alias and the
// identity's keygraph ID.
func (cm *SessionManager) LookupAlias(alias IdentityAlias) (*IdentityInfo, error) {
	cached, found := cm.aliases.Load(alias)
	if found {
		return cached, nil
	}

	resolved, resolveErr := cm.resolver.ResolveIdentityAlias(alias)
	if resolveErr != nil {
		return nil, fmt.Errorf("unable to resolve provider alias: %w", resolveErr)
	}

	// remember the result for later lookups by alias or by identity
	cm.aliases.Store(alias, resolved)
	cm.identities.Store(resolved.KeygraphID, resolved)

	// TODO add recursive lookup for user identities

	// TODO(geoah): fix "use"
	// if resolved.KeygraphID.Use == "provider" {
	// 	cm.providers.Store(alias, resolved)
	// }

	return resolved, nil
}
// LookupIdentity returns the identity info for the given keygraph ID.
//
// Cached entries are returned directly. Otherwise each known provider is
// asked for the identity's document, trying the provider's peer addresses in
// turn with a budget of one second per provider. The first document that
// decodes into an IdentityInfo wins; it is cached under the keygraph ID and,
// when it carries an alias hostname, under its alias as well.
//
// An error is returned when no provider yields a usable document.
func (cm *SessionManager) LookupIdentity(id KeygraphID) (*IdentityInfo, error) {
	if info, ok := cm.identities.Load(id); ok {
		return info, nil
	}

	var identityInfo *IdentityInfo
	cm.providers.Range(func(_ IdentityAlias, providerInfo *IdentityInfo) bool {
		ctx, cf := context.WithTimeout(context.Background(), time.Second)
		defer cf()
		for _, addr := range providerInfo.PeerAddresses {
			rctx := &RequestContext{}
			doc, err := RequestDocument(ctx, rctx, cm, id.DocumentID(), FromPeerAddr(addr))
			if err != nil {
				continue
			}
			// decode into a fresh value: the result pointer starts out nil
			// and must not be decoded into, and a failed decode must not
			// leave a half-populated result behind
			candidate := &IdentityInfo{}
			if err := candidate.FromDocument(doc); err != nil {
				continue
			}
			identityInfo = candidate
			// found it, stop iterating over providers
			return false
		}
		return true
	})

	if identityInfo == nil {
		return nil, fmt.Errorf("unable to resolve identity %s", id)
	}

	cm.identities.Store(id, identityInfo)

	// TODO verify the alias is indeed correct before storing or returning it
	if identityInfo.Alias.Hostname != "" {
		cm.aliases.Store(identityInfo.Alias, identityInfo)
	}

	// TODO(geoah): fix "use"
	// if identityInfo.KeygraphID.Use == "provider" {
	// 	cm.providers.Store(identityInfo.Alias, identityInfo)
	// }

	return identityInfo, nil
}