Konano/arknights-mower

View on GitHub
arknights_mower/utils/device/adb_client/session.py

Summary

Maintainability
A
25 mins
Test Coverage
from __future__ import annotations

import socket
import struct
import time

from ... import config
from ...log import logger
from .socket import Socket


class Session(object):
    """ Session between ADB client and ADB server

    Owns one ``Socket`` connected to the ADB server and exposes the service
    requests built on top of it. ``request`` and ``device`` return the
    session itself so calls can be chained.
    """

    # A timed-out request is retried with a longer timeout until the timeout
    # would exceed this many seconds.
    _MAX_TIMEOUT = 60
    # Seconds added to the socket timeout after every timed-out attempt.
    _TIMEOUT_STEP = 5
    # Largest payload carried by one sync ``DATA`` packet.
    _SYNC_CHUNK_SIZE = 65536

    def __init__(self, server: tuple[str, int] | None = None, timeout: int | None = None) -> None:
        """ open a connection to the ADB server

        :param server: ``(host, port)`` of the ADB server; defaults to
            ``(config.ADB_SERVER_IP, config.ADB_SERVER_PORT)``
        :param timeout: socket timeout in seconds; defaults to
            ``config.ADB_SERVER_TIMEOUT``
        """
        if server is None:
            server = (config.ADB_SERVER_IP, config.ADB_SERVER_PORT)
        if timeout is None:
            timeout = config.ADB_SERVER_TIMEOUT
        self.server = server
        self.timeout = timeout
        # Serial of the device selected through device(); None until then.
        self.device_id = None
        self.sock = Socket(self.server, self.timeout)

    def __enter__(self) -> Session:
        """ enter the context, yielding the session itself """
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        """ leave the context; nothing is released and exceptions propagate """
        pass

    @staticmethod
    def _check_address(device: str) -> None:
        """ raise ValueError unless device looks like ``host:port`` """
        host, sep, port = device.partition(':')
        # isascii() keeps int() from seeing non-ASCII digits that isdigit()
        # alone would accept.
        well_formed = (
            sep == ':'
            and host != ''
            and ':' not in port
            and port.isascii()
            and port.isdigit()
            and 1 <= int(port) <= 65535
        )
        if not well_formed:
            raise ValueError('wrong format of device address')

    def request(self, cmd: str, reconnect: bool = False) -> Session:
        """ make a service request to ADB server, consult ADB sources for available services

        The request is always attempted at least once. Each time it times
        out, ``self.timeout`` is raised by 5 seconds (the change persists on
        the session), a fresh socket is opened and the request is retried,
        until the timeout would exceed 60 seconds.

        :param cmd: service string, sent prefixed with its length as four
            hexadecimal digits
        :param reconnect: re-select ``self.device_id`` on the fresh socket
            before retrying
        :return: the session itself
        :raises socket.timeout: when the retry budget is exhausted
        """
        payload = cmd.encode()
        packet = b'%04X%b' % (len(payload), payload)
        while True:
            try:
                self.sock.send(packet).check_okay()
                return self
            except socket.timeout as exc:
                logger.warning(f'socket.timeout: {self.timeout}s, +{self._TIMEOUT_STEP}s')
                self.timeout += self._TIMEOUT_STEP
                if self.timeout > self._MAX_TIMEOUT:
                    # Budget spent: fail without opening another socket.
                    raise socket.timeout(f'server: {self.server}') from exc
                self.sock = Socket(self.server, self.timeout)
                if reconnect:
                    self.device(self.device_id)

    def response(self, recv_all: bool = False) -> bytes:
        """ receive response

        :param recv_all: use ``Socket.recv_all`` instead of
            ``Socket.recv_response``
        """
        if recv_all:
            return self.sock.recv_all()
        return self.sock.recv_response()

    def exec(self, cmd: str) -> bytes:
        """ exec: cmd

        :raises ValueError: if cmd is empty
        """
        if not cmd:
            raise ValueError('no command specified for exec')
        return self.request('exec:' + cmd, True).response(True)

    def shell(self, cmd: str) -> bytes:
        """ shell: cmd

        :raises ValueError: if cmd is empty
        """
        if not cmd:
            raise ValueError('no command specified for shell')
        return self.request('shell:' + cmd, True).response(True)

    def host(self, cmd: str) -> bytes:
        """ host: cmd

        :raises ValueError: if cmd is empty
        """
        if not cmd:
            raise ValueError('no command specified for host')
        # host: services are answered by the ADB server itself, so a retry
        # must not switch the fresh connection to a device transport first.
        return self.request('host:' + cmd).response()

    def run(self, cmd: str, recv_all: bool = False) -> bytes:
        """ run command

        :param cmd: complete service string, sent unchanged
        :param recv_all: forwarded to ``response``
        :raises ValueError: if cmd is empty
        """
        if not cmd:
            raise ValueError('no command specified')
        # NOTE(review): a host:* command passed here still re-selects the
        # device transport on reconnect - confirm callers only pass device
        # services.
        return self.request(cmd, True).response(recv_all)

    def device(self, device_id: str | None = None) -> Session:
        """ switch to a device

        :param device_id: device serial, or None to request
            ``host:transport-any``
        :return: the session itself
        """
        self.device_id = device_id
        if device_id is None:
            return self.request('host:transport-any')
        return self.request('host:transport:' + device_id)

    def connect(self, device: str, throw_error: bool = False) -> None:
        """ connect device [ip:port]

        :param throw_error: raise when the reply contains ``unable`` or
            ``cannot``
        :raises ValueError: if device is not a ``host:port`` address
        :raises RuntimeError: if throw_error is set and the reply reports a failure
        """
        self._check_address(device)
        resp = self.request(f'host:connect:{device}').response()
        logger.debug(f'adb connect {device}: {repr(resp)}')
        if throw_error and (b'unable' in resp or b'cannot' in resp):
            raise RuntimeError(repr(resp))

    def disconnect(self, device: str, throw_error: bool = False) -> None:
        """ disconnect device [ip:port]

        :param throw_error: raise when the reply contains ``unable`` or
            ``cannot``
        :raises ValueError: if device is not a ``host:port`` address
        :raises RuntimeError: if throw_error is set and the reply reports a failure
        """
        self._check_address(device)
        resp = self.request(f'host:disconnect:{device}').response()
        logger.debug(f'adb disconnect {device}: {repr(resp)}')
        if throw_error and (b'unable' in resp or b'cannot' in resp):
            raise RuntimeError(repr(resp))

    def devices_list(self) -> list[tuple[str, str]]:
        """ returns list of devices that the adb server knows

        Each line of the ``host:devices`` reply is split on tabs into one
        tuple.
        """
        resp = self.request('host:devices').response().decode(errors='ignore')
        devices = [tuple(line.split('\t')) for line in resp.splitlines()]
        logger.debug(devices)
        return devices

    def push(self, target_path: str, target: bytes, mode=0o100755, mtime: int | None = None) -> None:
        """ push data to device

        Opens the ``sync:`` service and sends one ``SEND`` packet, the
        content as ``DATA`` packets of at most 65536 bytes, and a final
        ``DONE`` packet.

        :param target_path: destination path on the device
        :param target: file content
        :param mode: file mode, sent in decimal after the path
        :param mtime: modification time packed as an unsigned 32-bit
            integer; defaults to the current time
        :raises RuntimeError: if the reply is not ``OKAY``; the message
            carried by a ``FAIL`` reply is included when available
        """
        self.request('sync:', True)
        request = b'%s,%d' % (target_path.encode(), mode)
        # sendall for every packet, so none of them can be cut short by a
        # partial write.
        self.sock.sendall(b'SEND' + struct.pack('<I', len(request)) + request)
        chunk_size = self._SYNC_CHUNK_SIZE
        for start in range(0, len(target), chunk_size):
            chunk = target[start:start + chunk_size]
            self.sock.sendall(b''.join((b'DATA', struct.pack('<I', len(chunk)), chunk)))
        if mtime is None:
            mtime = int(time.time())
        self.sock.sendall(b'DONE' + struct.pack('<I', mtime))
        status = self.sock.recv_exactly(8)
        if status[:4] == b'OKAY':
            return
        detail = ''
        if status[:4] == b'FAIL' and len(status) == 8:
            # A FAIL header carries the length of the message that follows.
            (length,) = struct.unpack('<I', status[4:8])
            if length:
                try:
                    detail = self.sock.recv_exactly(length).decode(errors='replace')
                except OSError:
                    # The message is diagnostic only; still report the failure.
                    detail = ''
        raise RuntimeError(f'push failed: {detail}' if detail else 'push failed')