akissa/sachannelupdate

View on GitHub
sachannelupdate/transports.py

Summary

Maintainability
C
1 day
Test Coverage
# -*- coding: utf-8 -*-
# vim: ai ts=4 sts=4 et sw=4
# sachannelupdate - Utility for pushing updates to Spamassassin update channels
# Copyright (C) 2015  Andrew Colin Kissa <andrew@topdog.za.net>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
sachannelupdate: Transports
"""
import os

from Queue import Queue
from pwd import getpwnam
from getpass import getuser
from urlparse import urlparse

from paramiko.util import load_host_keys
from paramiko import Transport, SFTPClient, PKey, PasswordRequiredException, \
    SSHException

from sachannelupdate.exceptions import SaChannelUpdateTransportError


def get_key_files(kfiles, dirname, names):
    """Queue the SSH key files found among *names* in *dirname*.

    The signature matches the ``os.path.walk`` visitor convention
    ``(arg, dirname, names)``.  Every entry that is a regular file whose
    name ends in ``_rsa`` or ``_dsa`` is joined with *dirname* and passed to
    ``kfiles.put``.  Nothing is returned; results are delivered through
    *kfiles*.
    """
    for name in names:
        fullname = os.path.join(dirname, name)
        # Both conditions must hold.  The previous unparenthesised
        # ``a and b or c`` accepted any '*_dsa' entry, even one that is not
        # a regular file.
        if os.path.isfile(fullname) and fullname.endswith(('_rsa', '_dsa')):
            kfiles.put(fullname)


def get_ssh_keys(sshdir):
    """Get SSH keys

    Walk *sshdir* recursively and collect every regular file whose name
    ends in ``_rsa`` or ``_dsa``, the conventional names for RSA/DSA
    private keys.  Files such as ``*.pub`` do not match.

    Returns a Queue of absolute key file paths, in ``os.walk`` order.
    """
    keys = Queue()
    suffixes = ('_rsa', '_dsa')
    for root, _, files in os.walk(os.path.abspath(sshdir)):
        for filename in files:
            fullname = os.path.join(root, filename)
            # Both tests must hold.  The old unparenthesised
            # ``a and b or c`` queued any '*_dsa' entry, even a dangling
            # symlink that is not a regular file.
            if os.path.isfile(fullname) and fullname.endswith(suffixes):
                keys.put(fullname)
    return keys


def get_remote_path(remote_location):
    """Extract the path component of the *remote_location* URL.

    ``sftp://host:22/srv/updates`` yields ``/srv/updates``; a location
    without a path component yields an empty string.
    """
    return urlparse(remote_location).path


def get_ssh_dir(config, username):
    """Locate the SSH directory to read keys and known_hosts from.

    Resolution order:
      1. ``config['ssh_config_dir']``, returned unchanged when truthy
         (it is not checked for existence);
      2. ``~/.ssh`` of the running process, if it is a directory;
      3. ``.ssh`` under *username*'s passwd home, if it is a directory.
    Returns None when neither fallback directory exists.
    """
    configured = config.get('ssh_config_dir')
    if configured:
        return configured

    candidate = os.path.expanduser('~/.ssh')
    if os.path.isdir(candidate):
        return candidate

    # Only consult the passwd database when the process home has no .ssh.
    candidate = os.path.join(getpwnam(username).pw_dir, '.ssh')
    if os.path.isdir(candidate):
        return candidate
    return None


def get_local_user(username):
    """Map *username* to a local account name.

    Returns *username* itself when the local passwd database knows it;
    otherwise falls back to the login name of the current process.
    """
    try:
        getpwnam(username)
    except KeyError:
        # No such local account.
        return getuser()
    return username


def get_host_keys(hostname, sshdir):
    """Return the known host key recorded for *hostname*, or None.

    The key is read from ``<sshdir>/known_hosts``.  None is returned when
    *sshdir* is None, the file cannot be read (IOError), or *hostname* has
    no entry.  If several key types are recorded for the host, the first
    one yielded by the mapping is returned.
    """
    # get_ssh_dir() may return None; os.path.join(None, ...) would raise a
    # TypeError/AttributeError rather than the IOError handled below.
    if sshdir is None:
        return None

    try:
        host_keys = load_host_keys(os.path.join(sshdir, 'known_hosts'))
    except IOError:
        return None

    if hostname not in host_keys:
        return None

    keys_for_host = host_keys[hostname]
    # next(iter(...)) works whether keys() returns a list or a view.
    return keys_for_host[next(iter(keys_for_host))]


def _parse_remote(remote):
    """Split a remote location URL into ``(hostname, port, username)``.

    *port* defaults to 22 and *username* is None when the URL carries no
    ``user@`` part.  Raises SaChannelUpdateTransportError when no hostname
    can be parsed.
    """
    parts = urlparse(remote)
    # hostname/port (rather than splitting netloc on ':') cope with a
    # 'user@' prefix and with bracketed IPv6 literals.
    hostname = parts.hostname
    if not hostname:
        raise SaChannelUpdateTransportError(
            "Invalid remote location: %s" % remote)
    return hostname, parts.port or 22, parts.username


def _load_private_key(filename):
    """Load *filename* with the paramiko key class implied by its suffix."""
    # PKey is abstract: from_private_key_file only works when called on a
    # concrete subclass, so pick DSSKey for '*_dsa' and RSAKey otherwise.
    import paramiko
    if filename.endswith('_dsa'):
        key_class = getattr(paramiko, 'DSSKey', None)
        if key_class is None:
            # DSSKey is absent from paramiko releases without DSA support.
            raise SSHException("DSA keys are not supported: %s" % filename)
    else:
        key_class = paramiko.RSAKey
    return key_class.from_private_key_file(filename)


def _authenticate(transport, username, keys):
    """Try the key files queued in *keys* until one authenticates.

    Returns True once the transport is authenticated, or False if the queue
    is exhausted first.  Key files that cannot be read or decrypted, and
    keys the server rejects, are skipped.
    """
    while not keys.empty():
        keyfile = keys.get()
        try:
            transport.auth_publickey(username, _load_private_key(keyfile))
        except (SSHException, IOError, OSError):
            # PasswordRequiredException (encrypted key) is an SSHException
            # subclass; rejected keys also surface as SSHException.
            continue
        if transport.is_authenticated():
            return True
    return False


def get_sftp_conn(config):
    """Make a SFTP connection, returns sftp client and connection objects

    ``config['remote_location']`` is parsed as a URL of the form
    ``scheme://[user@]host[:port]/path``.  The remote username is
    ``config['remote_username']``, else the URL user, else the local login
    name.

    A single transport is opened.  The server host key is checked against
    ``known_hosts`` when an entry exists.  The ``*_rsa`` / ``*_dsa`` key
    files from the SSH directory are then tried in turn.

    Returns a ``(SFTPClient, Transport)`` tuple; both are left open for the
    caller.  Raises SaChannelUpdateTransportError on any failure.
    """
    try:
        hostname, port, url_user = _parse_remote(
            config.get('remote_location'))
        username = config.get('remote_username') or url_user or getuser()
        luser = get_local_user(username)
        sshdir = get_ssh_dir(config, luser)
        if sshdir is None:
            raise SaChannelUpdateTransportError("No SSH directory found")

        hostkey_name = hostname
        if port != 22:
            # known_hosts records non-standard ports as '[host]:port'.
            hostkey_name = '[%s]:%d' % (hostname, port)
        hostkey = get_host_keys(hostkey_name, sshdir)
        keys = get_ssh_keys(sshdir)

        transport = Transport((hostname, port))
        try:
            # A Transport can only be started once, so negotiate (and check
            # the host key) a single time, then try each key with
            # auth_publickey.
            # NOTE(review): when the host has no known_hosts entry, hostkey
            # is None and the server key is accepted without verification.
            transport.connect(hostkey=hostkey)
            if not _authenticate(transport, username, keys):
                raise SaChannelUpdateTransportError("SFTP connection failed")
            sftp = SFTPClient.from_transport(transport)
            if sftp is None:
                raise SaChannelUpdateTransportError("SFTP connection failed")
        except BaseException:
            # Release the socket and transport thread before propagating.
            transport.close()
            raise
        return sftp, transport
    except SaChannelUpdateTransportError:
        raise
    except Exception as msg:
        raise SaChannelUpdateTransportError(msg)