trema/trema-edge

View on GitHub
src/lib/secure_channel.c

Summary

Maintainability
Test Coverage
/*
 * Author: Kazushi SUGYO, Yasunobu Chiba
 *
 * Copyright (C) 2008-2013 NEC Corporation
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License, version 2, as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */


#include <arpa/inet.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "buffer.h"
#include "checks.h"
#include "log.h"
#include "message_queue.h"
#include "openflow.h"
#include "openflow_switch_interface.h"
#include "safe_event_handler.h"
#include "safe_timer.h"
#include "secure_channel.h"
#include "wrapper.h"


/*
 * Lifecycle states of the single secure channel connection.
 * The numeric values double as indices into connection_state_string below,
 * so both definitions must list the states in the same order.
 */
enum connection_state {
  INIT,
  CONNECTING,
  CONNECTED,
  DISCONNECTED,
};

/*
 * Printable state names for log messages, indexed by enum connection_state.
 * Each row holds 13 bytes: the longest name ( "DISCONNECTED" ) plus its
 * terminating NUL. The table is only ever read, hence const.
 */
static const char connection_state_string[][ 13 ] = { "INIT", "CONNECTING", "CONNECTED", "DISCONNECTED" };


/*
 * Everything this module tracks about its one connection to the remote peer.
 */
typedef struct {
  uint32_t ip; // peer IPv4 address; converted with htonl() right before connect()
  uint16_t port; // peer TCP port; converted with htons() right before connect()
  int fd; // socket descriptor, or -1 while no socket is held
  int state; // one of enum connection_state
  connected_handler connected_callback; // optional ( may be NULL ); called by connected()
  disconnected_handler disconnected_callback; // optional ( may be NULL ); called by disconnected()
} secure_channel_connection;


// The single connection managed by this module; starts with no socket ( fd = -1 ) in state INIT.
static secure_channel_connection connection = { 0, 0, -1, INIT, NULL, NULL };
// Set by init_secure_channel() and cleared by finalize_secure_channel().
static bool secure_channel_initialized = false;

// Outgoing messages waiting to be written to the socket.
static message_queue *send_queue = NULL;
// Complete incoming messages waiting to be dispatched.
static message_queue *recv_queue = NULL;

// Capacity requested for fragment_buf.
// NOTE(review): the "+ sizeof( struct ofp_packet_in ) - 2" slack on top of the largest
// 16-bit message length is not explained in this file - confirm the intent.
static const size_t RECEIVE_BUFFER_SIZE = UINT16_MAX + sizeof( struct ofp_packet_in ) - 2;
// Bytes read from the socket that do not yet form a complete message; allocated on first receive.
static buffer *fragment_buf = NULL;


/*
 * Moves the connection to `state` if that is a legal successor of the
 * current state and logs the transition at debug level.
 *
 * Accepted transitions:
 *   INIT         -> CONNECTING
 *   CONNECTING   -> CONNECTED, CONNECTING
 *   CONNECTED    -> DISCONNECTED, INIT
 *   DISCONNECTED -> CONNECTED
 *
 * Any other request is reported with error() and leaves connection.state
 * unchanged.
 *
 * @param state requested new state ( a value of enum connection_state ).
 */
static void
transit_state( int state ) {
  switch ( connection.state ) {
  case INIT:
  {
    if ( state != CONNECTING ) {
      goto invalid_transition;
    }
    // connection.state is intentionally not assigned here: it is updated
    // once, after the debug message below, so that the log shows the real
    // previous state instead of "CONNECTING -> CONNECTING".
  }
  break;

  case CONNECTING:
  {
    if ( state != CONNECTED && state != CONNECTING ) {
      goto invalid_transition;
    }
  }
  break;

  case CONNECTED:
  {
    if ( state != DISCONNECTED && state != INIT ) {
      goto invalid_transition;
    }
  }
  break;

  case DISCONNECTED:
  {
    if ( state != CONNECTED ) {
      goto invalid_transition;
    }
  }
  break;

  default:
  {
    error( "Invalid state ( %d ).", connection.state );
    return;
  }
  break;
  }

  // Every accepted `state` was compared against an enum constant above, so
  // both values are valid indices into connection_state_string.
  debug( "State transition: %s -> %s.",
         connection_state_string[ connection.state ],
         connection_state_string[ state ] );

  connection.state = state;

  return;

invalid_transition:
  error( "Invalid state transition ( %d -> %d ).", connection.state, state );
}


/*
 * Releases the current socket, if any, and resets the connection to
 * "no socket, state INIT". The state is assigned directly and does not go
 * through transit_state().
 */
static void
clear_connection() {
  if ( connection.fd >= 0 ) {
    // Unregister the descriptor from the event handler while it is still
    // open. Once close() has run, the descriptor number may be handed out
    // again by the kernel, so nothing may refer to it afterwards.
    set_readable_safe( connection.fd, false );
    set_writable_safe( connection.fd, false );
    delete_fd_handler_safe( connection.fd );
    close( connection.fd );
  }

  connection.fd = -1;
  connection.state = INIT;
}


// Forward declaration: recv_from_secure_channel() calls reconnect(), which is defined further down.
static void reconnect( void *user_data );


static void
disconnected() {
  transit_state( DISCONNECTED );

  if ( connection.disconnected_callback != NULL ) {
    connection.disconnected_callback();
  }

  clear_connection();
}


static bool
recv_message_from_secure_channel() {
  assert( recv_queue != NULL );

  if ( recv_queue->length == 0 ) {
    return false;
  }

  buffer *message = dequeue_message( recv_queue );
  handle_secure_channel_message( message ); // FIXME: handle error properly
  free_buffer( message );

  return true;
}


static void
recv_from_secure_channel( int fd, void *user_data ) {
  UNUSED( fd );
  UNUSED( user_data );

  // all queued messages should be processed before receiving new messages from remote
  if ( recv_queue->length > 0 ) {
    return;
  }

  if ( fragment_buf == NULL ) {
    fragment_buf = alloc_buffer_with_length( RECEIVE_BUFFER_SIZE );
  }

  size_t remaining_length = RECEIVE_BUFFER_SIZE - fragment_buf->length;
  char *recv_buf = ( char * ) fragment_buf->data + fragment_buf->length;
  ssize_t recv_length = read( connection.fd, recv_buf, remaining_length );
  if ( recv_length < 0 ) {
    if ( errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK ) {
      return;
    }
    error( "Receive error ( errno = %s [%d] ).", strerror( errno ), errno );
    return;
  }
  if ( recv_length == 0 ) {
    debug( "Connection closed by peer." );
    disconnected();
    reconnect( NULL );
    return;
  }
  fragment_buf->length += ( size_t ) recv_length;

  size_t read_total = 0;
  while ( fragment_buf->length >= sizeof( struct ofp_header ) ) {
    struct ofp_header *header = fragment_buf->data;
    uint16_t message_length = ntohs( header->length );
    if ( message_length > fragment_buf->length ) {
      break;
    }
    buffer *message = alloc_buffer_with_length( message_length );
    char *p = append_back_buffer( message, message_length );
    memcpy( p, fragment_buf->data, message_length );
    remove_front_buffer( fragment_buf, message_length );
    enqueue_message( recv_queue, message );
    read_total += message_length;
  }

  // remove headroom manually for next call
  if ( read_total > 0 ) {
    memmove( ( char * ) fragment_buf->data - read_total, fragment_buf->data, fragment_buf->length );
    fragment_buf->data = ( char * ) fragment_buf->data - read_total;
  }

  while ( recv_message_from_secure_channel() == true );
}


/*
 * Write callback for the connected socket ( registered by connected() ).
 *
 * Writes queued messages to the socket in order until the queue is empty or
 * the socket stops accepting data. Write notification is switched off on
 * entry and switched back on only when data is left to send after a
 * transient error ( EINTR, EAGAIN, EWOULDBLOCK ) or a short write. Any other
 * write error is logged and the function returns with the notification
 * still off and the message still queued.
 *
 * @param fd        unused; connection.fd is used instead.
 * @param user_data unused.
 */
static void
flush_send_queue( int fd, void *user_data ) {
  UNUSED( fd );
  UNUSED( user_data );

  assert( send_queue != NULL );
  assert( connection.fd >= 0 );

  debug( "Flushing send queue ( length = %d ).", send_queue->length );

  set_writable_safe( connection.fd, false );

  buffer *head = peek_message( send_queue );
  while ( head != NULL ) {
    const ssize_t sent = write( connection.fd, head->data, head->length );

    if ( sent < 0 ) {
      const bool transient = ( errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK );
      if ( !transient ) {
        error( "Failed to send a message to secure channel ( errno = %s [%d] ).",
               strerror( errno ), errno );
        return;
      }
      // Try again once the socket reports it is writable.
      set_writable_safe( connection.fd, true );
      return;
    }

    const size_t sent_bytes = ( size_t ) sent;
    if ( sent_bytes < head->length ) {
      // Short write: keep the unsent tail at the head of the queue.
      remove_front_buffer( head, sent_bytes );
      set_writable_safe( connection.fd, true );
      return;
    }

    // The whole message went out; drop it and look at the next one.
    buffer *done = dequeue_message( send_queue );
    free_buffer( done );
    head = peek_message( send_queue );
  }
}


static void
connected() {
  transit_state( CONNECTED );

  set_fd_handler_safe( connection.fd, recv_from_secure_channel, NULL, flush_send_queue, NULL );
  set_readable_safe( connection.fd, true );
  set_writable_safe( connection.fd, false );

  if ( connection.connected_callback != NULL ) {
    connection.connected_callback();
  }
}


// Forward declaration: try_connect() is defined after its first caller.
static bool try_connect( void );

/*
 * Starts a new connection attempt; on failure logs an error and resets the
 * connection with clear_connection().
 *
 * @param user_data unused ( the signature lets this be used as a timer callback ).
 */
static void
reconnect( void *user_data ) {
  UNUSED( user_data );

  if ( try_connect() ) {
    return;
  }

  error( "Failed to reconnect." );
  clear_connection();
}


/*
 * Abandons the current connection attempt and schedules reconnect() to run
 * once, five seconds from now.
 *
 * The socket is closed and connection.fd is reset to -1. Leaving the stale
 * descriptor number in place would make a later clear_connection() close it
 * a second time, by which point the number may belong to an unrelated
 * descriptor.
 */
static void
backoff() {
  if ( connection.fd >= 0 ) {
    close( connection.fd );
    connection.fd = -1;
  }

  // it_interval = 0 ( no repetition ), it_value = 5 seconds.
  struct itimerspec spec = { { 0, 0 }, { 5, 0 } };
  add_timer_event_callback_safe( &spec, reconnect, NULL );
}


/*
 * Callback that learns the outcome of a non-blocking connect().
 *
 * It is registered by try_connect() ( and by itself while the attempt is
 * still in progress ) and armed with set_writable_safe(). It reads the
 * socket's pending error with SO_ERROR and then:
 *   - 0: completes the connection via connected();
 *   - EINTR, EAGAIN, ECONNREFUSED, ENETUNREACH, ETIMEDOUT: warns and retries
 *     later via backoff();
 *   - EINPROGRESS: re-registers itself and keeps waiting;
 *   - anything else, or a failing getsockopt(): logs an error and releases
 *     the socket via clear_connection().
 *
 * The parameter list matches the other fd callbacks in this file
 * ( recv_from_secure_channel, flush_send_queue ), so the event handler does
 * not call it through a mismatched function pointer type.
 *
 * @param fd        unused; connection.fd is used instead.
 * @param user_data unused.
 */
static void
check_connected( int fd, void *user_data ) {
  UNUSED( fd );
  UNUSED( user_data );

  debug( "Checking a connection ( fd = %d ip = %#x, port = %u ).", connection.fd, connection.ip, connection.port );

  // assert( secure_channel_initialized );
  assert( connection.fd >= 0 );

  set_writable_safe( connection.fd, false );
  delete_fd_handler_safe( connection.fd );

  int err = 0;
  // Size of the variable that receives the option value ( not of error() ).
  socklen_t length = sizeof( err );
  int ret = getsockopt( connection.fd, SOL_SOCKET, SO_ERROR, &err, &length );
  if ( ret < 0 ) {
    error( "Failed to retrieve error code ( fd = %d, ret = %d, errno = %s [%d] ).",
           connection.fd, ret, strerror( errno ), errno );
    // The handler was removed above, so nothing would ever look at this
    // socket again; release it instead of leaking it.
    clear_connection();
    return;
  }

  switch( err ) {
    case 0:
      connected();
      break;

    case EINTR:
    case EAGAIN:
    case ECONNREFUSED:
    case ENETUNREACH:
    case ETIMEDOUT:
      warn( "Failed to connect ( fd = %d, errno = %s [%d] ).", connection.fd, strerror( err ), err );
      backoff();
      return;

    case EINPROGRESS:
      set_fd_handler_safe( connection.fd, NULL, NULL, check_connected, NULL );
      set_writable_safe( connection.fd, true );
      break;

    default:
      error( "Failed to connect ( fd = %d, errno = %s [%d] ).", connection.fd, strerror( err ), err );
      clear_connection();
      return;
  }
}


/*
 * Starts a non-blocking TCP connection to connection.ip / connection.port.
 *
 * A new socket is created with TCP_NODELAY and O_NONBLOCK set, stored in
 * connection.fd, and connect() is issued. Unless connect() fails outright,
 * check_connected() is registered and armed with set_writable_safe() to
 * pick up the result.
 *
 * @return false if the socket could not be created or configured, or if
 *         connect() failed with an error that is not retried; in every
 *         failure path the socket created here is closed again.
 *         true if the attempt is in progress, or if it failed with a
 *         retryable error and a retry was scheduled via backoff().
 */
static bool
try_connect() {
  assert( connection.state != CONNECTED );

  int fd = socket( PF_INET, SOCK_STREAM, 0 );
  if ( fd < 0 ) {
    error( "Failed to create a socket ( ret = %d, errno = %s [%d] ).",
           fd, strerror( errno ), errno );
    return false;
  }

  int flag = 1;
  int ret = setsockopt( fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof( flag ) );
  if ( ret < 0 ) {
    error( "Failed to set socket options ( fd = %d, ret = %d, errno = %s [%d] ).",
           fd, ret, strerror( errno ), errno );

    // The socket is not stored anywhere yet; close it so it does not leak.
    close( fd );
    return false;
  }

  ret = fcntl( fd, F_SETFL, O_NONBLOCK );
  if ( ret < 0 ) {
    error( "Failed to enable non-blocking mode ( fd = %d, ret = %d, errno = %s [%d] ).",
           fd, ret, strerror( errno ), errno );

    close( fd );
    return false;
  }

  connection.fd = fd;

  struct sockaddr_in addr;
  memset( &addr, 0, sizeof( struct sockaddr_in ) );
  addr.sin_family = AF_INET;
  addr.sin_port = htons( connection.port );
  addr.sin_addr.s_addr = htonl( connection.ip );

  transit_state( CONNECTING );

  ret = connect( connection.fd, ( struct sockaddr * ) &addr, sizeof( struct sockaddr_in ) );
  if ( ret < 0 ) {
    switch( errno ) {
      case EINTR:
      case EAGAIN:
      case ECONNREFUSED:
      case ENETUNREACH:
      case ETIMEDOUT:
        warn( "Failed to connect ( fd = %d, ret = %d, errno = %s [%d] ).",
              connection.fd, ret, strerror( errno ), errno );
        backoff();
        // A retry has been scheduled, so this still counts as success.
        return true;

      case EINPROGRESS:
        // Expected for a non-blocking socket; the result is picked up below.
        break;

      default:
        error( "Failed to connect ( fd = %d, ret = %d, errno = %s [%d] ).",
               connection.fd, ret, strerror( errno ), errno );
        clear_connection();
        return false;
    }
  }

  set_fd_handler_safe( connection.fd, NULL, NULL, ( event_fd_callback ) check_connected, NULL );
  set_writable_safe( connection.fd, true );

  return true;
}


/*
 * Initializes the secure channel and starts connecting to the peer.
 *
 * Must not be called while the channel is already initialized ( asserted ).
 *
 * @param ip                    peer IPv4 address ( converted with htonl() when connecting ).
 * @param port                  peer TCP port ( converted with htons() when connecting ).
 * @param connected_callback    called once the connection is established; may be NULL.
 * @param disconnected_callback called when the connection is lost; may be NULL.
 * @return true if the connection attempt was started and the send/receive
 *         queues were created; false if try_connect() failed, in which case
 *         the connection has been reset and the channel stays uninitialized.
 */
bool
init_secure_channel( uint32_t ip, uint16_t port, connected_handler connected_callback, disconnected_handler disconnected_callback ) {
  assert( !secure_channel_initialized );

  connection.fd = -1;
  connection.ip = ip;
  connection.port = port;
  connection.connected_callback = connected_callback;
  connection.disconnected_callback = disconnected_callback;

  if ( !try_connect() ) {
    clear_connection();
    return false;
  }

  send_queue = create_message_queue();
  recv_queue = create_message_queue();
  secure_channel_initialized = true;

  return true;
}


bool
finalize_secure_channel() {
  assert( secure_channel_initialized );

  clear_connection();

  if ( send_queue != NULL ) {
    delete_message_queue( send_queue );
    send_queue = NULL;
  }
  if ( recv_queue != NULL ) {
    delete_message_queue( recv_queue );
    recv_queue = NULL;
  }

  secure_channel_initialized = false;

  return true;
}


/*
 * Queues a copy of `message` for transmission over the secure channel.
 *
 * The message is duplicated, so the caller keeps ownership of `message`.
 * If the send queue was empty, write notification is switched on for the
 * socket so that flush_send_queue() gets called.
 *
 * @param message non-NULL, non-empty buffer to send ( asserted ).
 * @return true if the message was queued; false if the channel is not in
 *         the CONNECTED state or holds no socket.
 */
bool
send_message_to_secure_channel( buffer *message ) {
  assert( send_queue != NULL );
  assert( message != NULL );
  assert( message->length > 0 );

  if ( connection.state != CONNECTED || connection.fd == -1 ) {
    return false;
  }
  // Buffer lengths are handled as size_t throughout this file, so the value
  // is printed with %zu; %d would read it as an int.
  debug( "Enqueuing a message to send queue ( queue length = %d, message length = %zu ).",
        send_queue->length, ( size_t ) message->length );

  if ( send_queue->length == 0 ) {
    set_writable_safe( connection.fd, true );
  }

  buffer *duplicated = duplicate_buffer( message );
  enqueue_message( send_queue, duplicated );

  return true;
}


/*
 * Local variables:
 * c-basic-offset: 2
 * indent-tabs-mode: nil
 * End:
 */