# common/utils.py
import os
import re
import sys
import time
import json
import socket
import random
import string
import platform
import subprocess
from urllib.parse import scheme_chars
from ipaddress import IPv4Address, ip_address
from distutils.version import LooseVersion
from pathlib import Path
from stat import S_IXUSR
import requests
import tenacity
from dns.resolver import Resolver
from common.database import Database
from common.domain import Domain
from common.records import Record, RecordCollection
from config import settings
from config.log import logger
# Pool of desktop browser User-Agent strings that gen_fake_header() picks
# from when random User-Agents are enabled.
user_agents = [
    # Chrome 76 on Windows / macOS / Linux
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
    'AppleWebKit/537.36 (KHTML, like Gecko) '
    'Chrome/76.0.3809.100 Safari/537.36',
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_6) '
    'AppleWebKit/537.36 (KHTML, like Gecko) '
    'Chrome/76.0.3809.100 Safari/537.36',
    'Mozilla/5.0 (X11; Linux x86_64) '
    'AppleWebKit/537.36 (KHTML, like Gecko) '
    'Chrome/76.0.3809.100 Safari/537.36',
    # Firefox 68 on Windows / macOS / Linux
    'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:54.0) '
    'Gecko/20100101 Firefox/68.0',
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10.13; rv:61.0) '
    'Gecko/20100101 Firefox/68.0',
    'Mozilla/5.0 (X11; Linux i586; rv:31.0) '
    'Gecko/20100101 Firefox/68.0',
]

# Strict dotted-quad IPv4 matcher: four octets, each in the range 0-255.
IP_RE = re.compile(
    r'^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}'
    r'([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$'
)

# Matches an optional "scheme:" prefix followed by "//" at the start of a URL.
SCHEME_RE = re.compile('^([{}]+:)?//'.format(scheme_chars))
def gen_random_ip():
    """
    Generate a random, globally routable IPv4 address.

    Random 32-bit values are drawn until one is reported as global by
    ``IPv4Address.is_global``.

    :return str: the address in dotted-decimal notation
    """
    address = IPv4Address(random.randint(0, 2 ** 32 - 1))
    while not address.is_global:
        address = IPv4Address(random.randint(0, 2 ** 32 - 1))
    return address.exploded
def gen_fake_header():
    """
    Generate fake request headers

    Starts from a copy of ``settings.request_default_headers`` (or an empty
    dict when that setting is not a dict), optionally sets a random
    ``User-Agent`` and always sets ``Accept-Encoding``.

    :return dict: request headers
    """
    default_headers = settings.request_default_headers
    # Validate before copying: a non-dict setting (e.g. None) would otherwise
    # raise on ``.copy()`` before the fallback could ever apply.
    if isinstance(default_headers, dict):
        headers = default_headers.copy()
    else:
        headers = dict()
    if settings.enable_random_ua:
        headers['User-Agent'] = random.choice(user_agents)
    headers['Accept-Encoding'] = 'gzip, deflate'
    return headers
def get_random_header():
    """
    Get random header

    :return: the dict built by gen_fake_header(), or None if that result
             is not a dict
    """
    candidate = gen_fake_header()
    return candidate if isinstance(candidate, dict) else None
def get_random_proxy():
    """
    Get random proxy

    :return: a random entry of ``settings.request_proxy_pool``, or None
             when the pool is empty
    """
    try:
        chosen = random.choice(settings.request_proxy_pool)
    except IndexError:
        # random.choice raises IndexError on an empty sequence
        chosen = None
    return chosen
def get_proxy():
    """
    Get proxy

    :return: a random proxy when ``settings.enable_request_proxy`` is truthy,
             otherwise None
    """
    if not settings.enable_request_proxy:
        return None
    return get_random_proxy()
def split_list(ls, size):
    """
    Split a list into consecutive chunks of at most ``size`` items

    A ``size`` of 0 returns ``ls`` itself, unsplit.

    :param list ls: list
    :param int size: size
    :return list: result

    >>> split_list([1, 2, 3, 4], 3)
    [[1, 2, 3], [4]]
    """
    if size == 0:
        return ls
    chunks = []
    for start in range(0, len(ls), size):
        chunks.append(ls[start:start + size])
    return chunks
def match_main_domain(domain):
    """
    Normalise ``domain`` (lower-case, stripped) and match it with Domain

    :param domain: candidate domain text
    :return: result of ``Domain(...).match()``, or None for non-str input
    """
    if isinstance(domain, str):
        normalized = domain.lower().strip()
        return Domain(normalized).match()
    return None
def read_target_file(target):
    """
    Read domains from a target file, one candidate per line

    Lines for which match_main_domain() gives no result are skipped.
    Duplicates are dropped and the order of first appearance is kept.

    :param target: path of the file to read (decoded as UTF-8, undecodable
                   bytes ignored)
    :return list: unique matched domains in file order
    """
    with open(target, encoding='utf-8', errors='ignore') as file:
        matched = (match_main_domain(line) for line in file)
        domains = [domain for domain in matched if domain]
    # dict.fromkeys keeps first-occurrence order and de-duplicates in O(n),
    # unlike sorting a set by list.index, which is quadratic.
    return list(dict.fromkeys(domains))
def get_from_target(target):
    """
    Build the domain set for a single ``target`` argument

    Exits the process with status 1 if ``target`` looks like a ``.txt`` file,
    since files of domains belong in the ``targets`` parameter.

    :param target: a single domain name
    :return set: the matched domain, or an empty set when ``target`` is not
                 a str or does not match
    """
    domains = set()
    if isinstance(target, str):
        if target.endswith('.txt'):
            logger.log('FATAL', 'Use targets parameter for multiple domain names')
            # sys.exit is the programmatic API; the bare exit() builtin is
            # injected by the site module and may be absent.
            sys.exit(1)
        domain = match_main_domain(target)
        if not domain:
            return domains
        domains.add(domain)
    return domains
def get_from_targets(targets):
    """
    Load domains from the file named by ``targets``

    :param targets: path of a file listing domains
    :return: the list produced by read_target_file() when ``targets`` is an
             existing regular file, otherwise an empty set
    """
    empty = set()
    if not isinstance(targets, str):
        return empty
    try:
        path = Path(targets)
    except Exception as e:
        logger.log('ERROR', e.args)
        return empty
    if not (path.exists() and path.is_file()):
        return empty
    return read_target_file(targets)
def get_domains(target, targets=None):
    """
    Collect the domains to process from ``target`` and ``targets``

    :param target: a single domain name
    :param targets: path of a file listing domains
    :return list: merged domains; when a targets file contributed, its
                  domains come first in file order and any domain that came
                  only from ``target`` is placed after them
    """
    logger.log('DEBUG', f'Getting domains')
    target_domains = get_from_target(target)
    targets_domains = get_from_targets(targets)
    domains = list(target_domains.union(targets_domains))
    if targets_domains:
        # Sort by the original position in the targets file. A domain that
        # came only from ``target`` has no position there; looking it up with
        # list.index would raise ValueError, so it is ranked last instead.
        position = {name: index for index, name in enumerate(targets_domains)}
        last = len(position)
        domains = sorted(domains, key=lambda name: position.get(name, last))
    if not domains:
        logger.log('ERROR', f'Did not get a valid domain name')
    logger.log('DEBUG', f'The obtained domains \n{domains}')
    return domains
def check_dir(dir_path):
    """
    Make sure ``dir_path`` exists, creating it (and its parents) if missing

    :param dir_path: path object exposing ``exists()`` and ``mkdir()``
    """
    if dir_path.exists():
        return
    logger.log('INFOR', f'{dir_path} does not exist, directory will be created')
    dir_path.mkdir(parents=True, exist_ok=True)
def check_path(path, name, fmt):
    """
    Check the result output path

    :param path: save path (a str); any other value selects the default
                 location under ``settings.result_save_dir``
    :param name: export name
    :param fmt: save format, used as the file extension
    :return: the resolved save path as a ``Path``
    """
    filename = f'{name}.{fmt}'
    default_path = settings.result_save_dir.joinpath(filename)
    if not isinstance(path, str):
        path = default_path
    else:
        # repr() turns control characters back into backslash escapes;
        # replace the backslashes in the path with forward slashes
        slashed = repr(path).replace('\\', '/')
        # strip the quote characters that repr() added
        path = slashed.replace('\'', '')
    path = Path(path)
    if path.is_dir():  # the input is a directory
        path = path.joinpath(filename)
    parent_dir = path.parent
    if not parent_dir.exists():
        logger.log('ALERT', f'{parent_dir} does not exist, directory will be created')
        parent_dir.mkdir(parents=True, exist_ok=True)
    if path.exists():
        logger.log('ALERT', f'The {path} exists and will be overwritten')
    return path
def check_format(fmt):
    """
    Check the export format

    :param fmt: requested export format
    :return: ``fmt`` when it is supported ('csv' or 'json'), otherwise 'csv'
    """
    supported = ['csv', 'json', ]
    if fmt not in supported:
        logger.log('ALERT', f'Does not support {fmt} format')
        logger.log('ALERT', 'So use csv format by default')
        return 'csv'
    return fmt
def load_json(path):
    """
    Load and parse a JSON file

    :param path: path of the JSON file
    :return: the decoded JSON value
    """
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with open(path, encoding='utf-8') as fp:
        return json.load(fp)
def save_to_db(name, data, module):
    """
    Save request results to database

    The table ``name`` is dropped and recreated before the data is saved.

    :param str name: table name
    :param list data: data to be saved
    :param str module: module name
    """
    db = Database()
    try:
        db.drop_table(name)
        db.create_table(name)
        db.save_db(name, data, module)
    finally:
        # Release the handle even if one of the operations above fails
        db.close()
def save_to_file(path, data):
    """
    Save data to a file

    ``str`` data is written as UTF-8 text with no newline translation
    (characters that cannot be encoded are dropped); any other data is
    written in binary mode.

    :param path: save path
    :param data: data to be saved (str or bytes-like)
    :return: True on success, False if the write failed (the error is logged)
    """
    try:
        if isinstance(data, str):
            # Explicit UTF-8: callers prepend a BOM for CSV output, which a
            # non-UTF locale encoding would silently drop.
            with open(path, 'w', encoding='utf-8', errors='ignore',
                      newline='') as file:
                file.write(data)
        else:
            with open(path, 'wb') as file:
                file.write(data)
    except Exception as e:
        # Covers failures of both the text and the binary write, so the
        # function honours its boolean contract in every case.
        logger.log('ERROR', e.args)
        return False
    return True
def check_response(method, resp):
    """
    Check a response and log details of an abnormal one

    A response is normal when its status code is 200 and its body is
    non-empty. For an abnormal response that declares a JSON content type,
    the decoded JSON body is logged as well.

    :param method: request method (used in the log line only)
    :param resp: response object
    :return: True for a normal response, otherwise False
    """
    if resp.status_code == 200 and resp.content:
        return True
    logger.log('ALERT', f'{method} {resp.url} {resp.status_code} - '
                        f'{resp.reason} {len(resp.content)}')
    content_type = resp.headers.get('Content-Type')
    has_json_body = content_type and 'json' in content_type and resp.content
    if not has_json_body:
        return False
    try:
        msg = resp.json()
    except Exception as e:
        logger.log('DEBUG', e.args)
    else:
        logger.log('ALERT', msg)
    return False
def mark_subdomain(old_data, now_data):
    """
    Mark newly found subdomains and return the marked data set

    Each item gets ``item['new'] = 1`` when its subdomain is absent from
    ``old_data`` (or when there is no old data at all), otherwise 0.
    The item dicts of ``now_data`` are updated in place; only the
    containing list is copied.

    :param list old_data: previous subdomain data
    :param list now_data: current subdomain data
    :return: marked subdomain data
    :rtype: list
    """
    marked = now_data.copy()
    # First collection: everything is new
    if not old_data:
        for entry in marked:
            entry['new'] = 1
        return marked
    # Later collections: compare against the previously known subdomains
    known = {previous.get('subdomain') for previous in old_data}
    for entry in marked:
        entry['new'] = 0 if entry.get('subdomain') in known else 1
    return marked
def remove_invalid_string(string):
    """
    Strip control characters that Excel cells cannot store directly

    Removes \\x00-\\x08, \\x0b-\\x0c and \\x0e-\\x1f; tab, line feed and
    carriage return are kept.

    :param str string: input text
    :return str: text without the illegal characters
    """
    illegal = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')
    return illegal.sub(r'', string)
def export_all_results(path, name, fmt, datas):
    """
    Export the full result rows of all main domains to one file

    The 'header' and 'response' keys are removed from every row (the row
    dicts are modified in place). CSV content is prefixed with a BOM.

    :param path: export path
    :param name: export file name (without extension)
    :param fmt: export format
    :param datas: iterable of row dicts
    """
    path = check_path(path, name, fmt)
    logger.log('ALERT', f'The subdomain result for all main domains: {path}')
    records = []
    for row in datas:
        for dropped in ('header', 'response'):
            row.pop(dropped, None)
        records.append(Record(row.keys(), row.values()))
    collection = RecordCollection(iter(records))
    content = collection.export(fmt)
    if fmt == 'csv':
        content = '\ufeff' + content
    save_to_file(path, content)
def export_all_subdomains(alive, path, name, datas):
    """
    Export the unique subdomains of all main domains to a txt file

    :param alive: when truthy, only rows with a truthy 'alive' value are kept
    :param path: export path
    :param name: export file name (without extension)
    :param datas: iterable of row dicts
    """
    path = check_path(path, name, 'txt')
    logger.log('ALERT', f'The txt subdomain result for all main domains: {path}')
    subdomains = set()
    for row in datas:
        subdomain = row.get('subdomain')
        if alive and not row.get('alive'):
            continue
        subdomains.add(subdomain)
    save_to_file(path, '\n'.join(subdomains))
def export_all(alive, fmt, path, datas):
    """
    Export all result data

    Writes the full results in the chosen format plus a txt list of
    subdomains, both named ``all_subdomain_result_<timestamp>``.

    :param bool alive: only export alive subdomains in the txt list
    :param str fmt: export file format
    :param str path: export file path
    :param list datas: result data to be exported
    """
    fmt = check_format(fmt)
    name = f'all_subdomain_result_{get_timestring()}'
    export_all_results(path, name, fmt, datas)
    export_all_subdomains(alive, path, name, datas)
def dns_resolver():
    """
    Build a DNS resolver configured from the settings

    :return: a Resolver using the configured nameservers, timeout and
             lifetime
    """
    configured = Resolver()
    configured.nameservers = settings.resolver_nameservers
    configured.timeout = settings.resolver_timeout
    configured.lifetime = settings.resolver_lifetime
    return configured
def dns_query(qname, qtype):
    """
    Query the DNS records of a domain name

    :param str qname: domain name to query
    :param str qtype: record type to query
    :return: the resolver's answer, or None if the query raised
    """
    logger.log('TRACE', f'Try to query {qtype} record of {qname}')
    resolver = dns_resolver()
    try:
        answer = resolver.query(qname, qtype)
    except Exception as e:
        logger.log('TRACE', e.args)
        logger.log('TRACE', f'Query {qtype} record of {qname} failed')
        return None
    logger.log('TRACE', f'Query {qtype} record of {qname} succeeded')
    return answer
def get_timestamp():
    """Return the current Unix time truncated to whole seconds."""
    now = time.time()
    return int(now)
def get_timestring():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    local_now = time.localtime(time.time())
    return time.strftime('%Y%m%d_%H%M%S', local_now)
def get_classname(classobj):
    """Return the name of the class that ``classobj`` is an instance of."""
    klass = classobj.__class__
    return klass.__name__
def python_version():
    """Return the interpreter's full version string (``sys.version``)."""
    version_text = sys.version
    return version_text
def calc_alive(data):
    """
    Count the items whose 'alive' value equals 1

    :param data: iterable of dict-like items
    :return int: number of alive items
    """
    return sum(1 for item in data if item.get('alive') == 1)
def count_alive(name):
    """
    Count the alive subdomains stored in table ``name``

    :param name: table name
    :return: the scalar value of the database's ``count_alive`` result
    """
    db = Database()
    try:
        result = db.count_alive(name)
        return result.scalar()
    finally:
        # Release the handle even if the query fails
        db.close()
def get_subdomains(data):
    """
    Collect the 'subdomain' value of every item

    :param data: iterable of dict-like items
    :return set: the distinct subdomain values (None for items lacking
                 the key)
    """
    return {item.get('subdomain') for item in data}
def set_id_none(data):
    """
    Reset the 'id' of every item to None

    The items are modified in place; the returned list is new.

    :param data: iterable of dict-like items
    :return list: the same items, each with ``item['id'] = None``
    """
    def _reset(entry):
        entry['id'] = None
        return entry

    return [_reset(entry) for entry in data]
def get_filtered_data(data):
    """
    Keep the items whose 'resolve' value is not 1

    :param data: iterable of dict-like items
    :return list: items with ``item.get('resolve') != 1``
    """
    return [item for item in data if item.get('resolve') != 1]
def get_sample_banner(headers):
    """
    Build a simple banner from selected response headers

    Joins the non-empty values of 'Server', 'Via' and 'X-Powered-By',
    in that order, with commas.

    :param headers: mapping of response headers
    :return str: banner text, empty when none of the headers is set
    """
    parts = []
    for field in ('Server', 'Via', 'X-Powered-By'):
        value = headers.get(field)
        if value:
            parts.append(value)
    return ','.join(parts)
def check_ip_public(ip_list):
    """
    Tell whether every address in ``ip_list`` is globally routable

    Checking stops at the first non-global address.

    :param ip_list: iterable of IP address strings
    :return int: 1 if all addresses are global (or the list is empty),
                 otherwise 0
    """
    all_global = all(ip_address(ip_str).is_global for ip_str in ip_list)
    return 1 if all_global else 0
def ip_is_public(ip_str):
    """
    Tell whether a single IP address is globally routable

    :param ip_str: IP address string
    :return int: 1 if the address is global, otherwise 0
    """
    return 1 if ip_address(ip_str).is_global else 0
def get_request_count():
    """
    Return the request count: 16 per CPU reported by the OS

    :return int: ``cpu_count * 16``, treating an undeterminable CPU count
                 as one CPU
    """
    # os.cpu_count() returns None when the count cannot be determined
    return (os.cpu_count() or 1) * 16
def uniq_dict_list(dict_list):
    """
    Keep only the entries that occur exactly once in ``dict_list``

    Note that an entry occurring more than once is dropped entirely,
    not reduced to a single copy.

    :param list dict_list: list of entries
    :return list: the entries whose count in ``dict_list`` is 1
    """
    return [entry for entry in dict_list if dict_list.count(entry) == 1]
def delete_file(*paths):
    """
    Delete the given files, logging (not raising) any failure

    :param paths: path objects exposing ``unlink()``
    """
    for target in paths:
        try:
            target.unlink()
        except Exception as e:
            logger.log('ERROR', e.args)
@tenacity.retry(stop=tenacity.stop_after_attempt(3),
                wait=tenacity.wait_fixed(2))
def check_net():
    """
    Check Internet access by requesting well-known sites

    Up to three attempts are made, each against a randomly chosen URL with
    freshly chosen headers and proxy. An attempt succeeds only on an
    HTTP 200 response.

    :return bool: True as soon as one attempt succeeds, False after three
                  failed attempts
    """
    urls = ['https://www.baidu.com', 'https://www.bing.com',
            'https://www.cloudflare.com', 'https://www.akamai.com/',
            'https://www.fastly.com/', 'https://www.amazon.com/']
    timeout = settings.request_timeout_second
    verify = settings.request_ssl_verify
    max_times = 3
    # One session for all attempts, closed on exit (previously two sessions
    # were created per attempt and none was ever closed).
    with requests.Session() as session:
        session.trust_env = False  # ignore proxy settings from the environment
        for times in range(1, max_times + 1):
            url = random.choice(urls)
            logger.log('DEBUG', f'Trying to access {url}')
            header = get_random_header()
            proxy = get_proxy()
            try:
                rsp = session.get(url, headers=header, proxies=proxy,
                                  timeout=timeout, verify=verify)
            except Exception as e:
                logger.log('ERROR', e.args)
                logger.log('ALERT', f'Unable to access Internet, retrying for the {times}th time')
            else:
                if rsp.status_code == 200:
                    logger.log('DEBUG', 'Access to Internet OK')
                    return True
    logger.log('ALERT', 'Access to Internet failed')
    return False
def check_dep():
    """
    Check the runtime environment

    Exits the process with status 1 unless running on CPython 3.6 or newer.
    """
    logger.log('INFOR', 'Checking dependent environment')
    implementation = platform.python_implementation()
    if implementation != 'CPython':
        logger.log('FATAL', f'OneForAll only passed the test under CPython')
        sys.exit(1)
    # Compare sys.version_info directly instead of parsing the version
    # string with distutils' LooseVersion: distutils is deprecated (PEP 632)
    # and removed in Python 3.12.
    if sys.version_info < (3, 6):
        logger.log('FATAL', 'OneForAll requires Python 3.6 or higher')
        sys.exit(1)
def get_net_env():
    """
    Check the network environment

    :return: the result of check_net(), or False if it raised
    """
    logger.log('INFOR', 'Checking network environment')
    try:
        return check_net()
    except Exception as e:
        logger.log('DEBUG', e.args)
        logger.log('ALERT', 'Please check your network environment.')
        return False
def _version_key(version):
    """Return the numeric components of a version tag as a tuple of ints."""
    return tuple(int(part) for part in re.findall(r'\d+', str(version)))


def check_version(local):
    """
    Check GitHub for a newer release and log the outcome

    :param local: version tag of the running program (e.g. 'v0.4.3')
    :return: None
    """
    logger.log('INFOR', 'Checking for the latest version')
    api = 'https://api.github.com/repos/shmilylty/OneForAll/releases/latest'
    header = get_random_header()
    proxy = get_proxy()
    timeout = settings.request_timeout_second
    verify = settings.request_ssl_verify
    try:
        # The session is closed as soon as the request has completed
        with requests.Session() as session:
            session.trust_env = False
            resp = session.get(url=api, headers=header, proxies=proxy,
                               timeout=timeout, verify=verify)
        resp_json = resp.json()
        latest = resp_json['tag_name']
    except Exception as e:
        logger.log('ALERT', 'An error occurred while checking the latest version')
        logger.log('DEBUG', e.args)
        return
    # Compare numerically: as plain strings 'v0.10.0' sorts before 'v0.9.0'.
    latest_key = _version_key(latest)
    local_key = _version_key(local)
    if latest_key and local_key:
        outdated = latest_key > local_key
    else:
        # No digits in one of the tags: fall back to the plain comparison
        outdated = latest > local
    if outdated:
        change = resp_json.get("body")
        logger.log('ALERT', f'The current version is {local} '
                            f'but the latest version is {latest}')
        logger.log('ALERT', f'The {latest} version mainly has the following changes')
        logger.log('ALERT', change)
    else:
        logger.log('INFOR', f'The current version {local} is already the latest version')
def get_main_domain(domain):
    """
    Get the registered (main) domain of ``domain``

    :param domain: domain name
    :return: result of ``Domain(domain).registered()``, or None for
             non-str input
    """
    if isinstance(domain, str):
        return Domain(domain).registered()
    return None
def call_massdns(massdns_path, dict_path, ns_path, output_path, log_path,
                 query_type='A', process_num=1, concurrent_num=10000,
                 quiet_mode=False):
    """
    Run massdns and wait for it to finish

    :param massdns_path: path of the massdns executable
    :param dict_path: path of the file with the names to resolve
    :param ns_path: path of the resolvers file
    :param output_path: path of the output file
    :param log_path: path of the error log
    :param query_type: DNS record type to query
    :param process_num: value for ``--processes``
    :param concurrent_num: value for ``--hashmap-size``
    :param quiet_mode: pass ``--quiet`` when truthy
    """
    logger.log('DEBUG', 'Start running massdns')
    status_format = settings.brute_status_format
    socket_num = settings.brute_socket_num
    resolve_num = settings.brute_resolve_num
    # Build an argument list and run without a shell, so paths containing
    # spaces or shell metacharacters reach massdns unchanged.
    cmd = [str(massdns_path)]
    if quiet_mode:
        cmd.append('--quiet')
    cmd += [
        '--status-format', str(status_format),
        '--processes', str(process_num),
        '--socket-count', str(socket_num),
        '--hashmap-size', str(concurrent_num),
        '--resolvers', str(ns_path),
        '--resolve-count', str(resolve_num),
        '--type', str(query_type),
        '--flush',
        '--output', 'J',
        '--outfile', str(output_path),
        '--root',
        '--error-log', str(log_path),
        str(dict_path),
        '--filter', 'OK',
        '--sndbuf', '0',
        '--rcvbuf', '0',
    ]
    logger.log('DEBUG', f'Run command {" ".join(cmd)}')
    subprocess.run(args=cmd, shell=False)
    logger.log('DEBUG', f'Finished massdns')
def get_massdns_path(massdns_dir):
    """
    Locate the massdns executable for the current platform

    ``settings.brute_massdns_path`` wins when set. Otherwise the bundled
    binary is selected by OS and machine type and made executable for its
    owner. The process exits if no such binary exists.

    :param massdns_dir: directory (path object) holding the bundled binaries
    :return: path of the massdns executable
    """
    path = settings.brute_massdns_path
    if path:
        return path
    system = platform.system().lower()
    machine = platform.machine().lower()
    name = f'massdns_{system}_{machine}'
    if system == 'windows':
        name = f'massdns.exe'
        if machine == 'amd64':
            massdns_dir = massdns_dir.joinpath('windows', 'x64')
        else:
            massdns_dir = massdns_dir.joinpath('windows', 'x86')
    path = massdns_dir.joinpath(name)
    # Check existence first: chmod on a missing file raises
    # FileNotFoundError, which used to pre-empt the message below.
    if not path.exists():
        logger.log('FATAL', 'There is no massdns for this platform or architecture')
        logger.log('INFOR', 'Please try to compile massdns yourself '
                            'and specify the path in the configuration')
        # NOTE(review): exit status 0 is kept from the original code even
        # though this is a fatal condition - confirm whether callers rely on it.
        sys.exit(0)
    # Add the owner-execute bit while keeping the existing permission bits
    # (chmod(S_IXUSR) alone would clear read and write access).
    path.chmod(path.stat().st_mode | S_IXUSR)
    return path
def is_subname(name):
    """
    Tell whether ``name`` consists only of subdomain-name characters

    Allowed characters are lower-case ASCII letters, digits, '.' and '-'.

    :param name: text to check
    :return bool: True if every character is allowed (also for empty input)
    """
    allowed = string.ascii_lowercase + string.digits + '.-'
    return all(char in allowed for char in name)
def ip_to_int(ip):
    """
    Convert an IPv4 address to its integer value

    :param ip: IPv4 address (str) or an int, which is returned unchanged
    :return int: integer form of the address, or 0 if it cannot be parsed
                 (the error is logged)
    """
    if isinstance(ip, int):
        return ip
    try:
        parsed = IPv4Address(ip)
    except Exception as e:
        logger.log('ERROR', e.args)
        return 0
    else:
        return int(parsed)
def match_subdomains(domain, html, distinct=True, fuzzy=True):
    """
    Use regexp to match subdomains

    In fuzzy mode every occurrence of ``domain`` with any number of leading
    labels is matched. Otherwise a match must directly follow one of the
    characters ``> " ' = ,`` (optionally with an http/https scheme, which
    is stripped from the result). Results are lower-cased.

    :param str domain: main domain
    :param str html: response html text
    :param bool distinct: deduplicate results or not (default True)
    :param bool fuzzy: fuzzy match subdomain or not (default True)
    :return set/list: result set or list; an empty set when nothing matches
    """
    logger.log('TRACE', f'Use regexp to match subdomains in the response body')
    # Zero or more DNS labels, each followed by a dot
    labels = r'(?:[a-z0-9](?:[a-z0-9\-]{0,61}[a-z0-9])?\.){0,}'
    # Escape the whole domain so any regex metacharacter in it is literal
    escaped = re.escape(domain)
    if fuzzy:
        regexp = labels + escaped
    else:
        regexp = r'(?:\>|\"|\'|\=|\,)(?:http\:\/\/|https\:\/\/)?' \
                 + labels + escaped
    result = re.findall(regexp, html, re.I)
    if not result:
        return set()
    if fuzzy:
        deal = (s.lower() for s in result)
    else:
        scheme = r'(?:http://|https://)'
        # s[1:] drops the delimiter character the match started with
        deal = (re.sub(scheme, '', s[1:].lower()) for s in result)
    if distinct:
        return set(deal)
    return list(deal)
def check_random_subdomain(subdomains):
    """
    Log a prompt to verify the first non-empty generated subdomain

    Logs an alert instead when ``subdomains`` is empty.

    :param subdomains: generated subdomains
    """
    if not subdomains:
        logger.log('ALERT', f'The generated dictionary is empty')
        return
    subdomain = next((item for item in subdomains if item), None)
    if subdomain:
        logger.log('ALERT', f'Please check whether {subdomain} is correct or not')
def get_url_resp(url):
    """
    Request ``url`` with GET and return the response

    :param url: URL to request
    :return: the response object, or None if the request raised
    """
    logger.log('INFOR', f'Attempting to request {url}')
    timeout = settings.request_timeout_second
    verify = settings.request_ssl_verify
    try:
        # The session is closed once the request has completed; the request
        # is not streamed, so the body is already read by then.
        with requests.Session() as session:
            session.trust_env = False
            resp = session.get(url, params=None, timeout=timeout,
                               verify=verify)
    except Exception as e:
        logger.log('ALERT', f'Error request {url}')
        logger.log('DEBUG', e.args)
        return None
    return resp
def decode_resp_text(resp):
    """
    Decode the body of a response to text

    Strict UTF-8 is tried first, then strict GB18030; as a last resort the
    body is decoded with the default codec, replacing undecodable bytes.

    :param resp: response object exposing ``content``
    :return str: decoded text, empty for an empty body
    """
    raw = resp.content
    if not raw:
        return str('')
    for codec in ('utf-8', 'gb18030'):
        try:
            return str(raw, encoding=codec, errors='strict')
        except (LookupError, TypeError, UnicodeError):
            continue
    # Neither strict decoding worked: decode leniently
    return str(raw, errors='replace')
def sort_by_subdomain(data):
    """
    Return the items sorted by their 'subdomain' value

    :param data: iterable of dict-like items
    :return list: new list in ascending subdomain order
    """
    def subdomain_of(item):
        return item.get('subdomain')

    return sorted(data, key=subdomain_of)
def looks_like_ip(maybe_ip):
    """
    Does the given str look like an IP address?

    :param maybe_ip: text to check
    :return bool: True if ``socket.inet_aton`` accepts the text (or, when
                  that call cannot handle the input, if IP_RE matches it),
                  otherwise False
    """
    # Empty input has no first character to inspect
    if not maybe_ip or not maybe_ip[0].isdigit():
        return False
    try:
        socket.inet_aton(maybe_ip)
    except (AttributeError, UnicodeError):
        # inet_aton could not process the value: fall back to the regex and
        # always answer with a bool rather than an implicit None
        return bool(IP_RE.match(maybe_ip))
    except socket.error:
        return False
    return True
def deal_data(domain):
    """
    Clean the stored data of ``domain``

    Runs the database's ``remove_invalid`` and ``deduplicate_subdomain``
    operations for the domain, in that order.

    :param domain: domain whose data is cleaned
    """
    db = Database()
    try:
        db.remove_invalid(domain)
        db.deduplicate_subdomain(domain)
    finally:
        # Release the handle even if one of the operations above fails
        db.close()
def get_data(domain):
    """
    Fetch the stored data of ``domain``

    :param domain: domain whose data is fetched
    :return: the database result converted with ``as_dict()``
    """
    db = Database()
    try:
        return db.get_data(domain).as_dict()
    finally:
        # Release the handle even if the query or conversion fails
        db.close()
def clear_data(domain):
    """
    Drop the table of ``domain``

    :param domain: domain whose table is dropped
    """
    db = Database()
    try:
        db.drop_table(domain)
    finally:
        # Release the handle even if dropping the table fails
        db.close()
def get_ns_path(in_china=None, enable_wildcard=None, ns_ip_list=None):
    """
    Choose the nameserver list file to use

    When ``enable_wildcard`` and ``ns_ip_list`` are both truthy, the given
    addresses are written (one per line) to
    ``settings.authoritative_dns_path`` and that path is returned.
    Otherwise the bundled list in ``settings.data_storage_dir`` is used:
    'nameservers_cn.txt' if ``in_china`` is truthy, else 'nameservers.txt'.

    :param in_china: select the China nameserver list
    :param enable_wildcard: allow the authoritative nameserver override
    :param ns_ip_list: authoritative nameserver addresses
    :return: path of the nameserver file
    """
    data_dir = settings.data_storage_dir
    filename = 'nameservers_cn.txt' if in_china else 'nameservers.txt'
    path = data_dir.joinpath(filename)
    if not (enable_wildcard and ns_ip_list):
        return path
    path = settings.authoritative_dns_path
    save_to_file(path, '\n'.join(ns_ip_list))
    return path
def init_table(domain):
    """
    Recreate the table of ``domain`` from scratch

    The table is dropped and then created again.

    :param domain: domain whose table is reset
    """
    db = Database()
    try:
        db.drop_table(domain)
        db.create_table(domain)
    finally:
        # Release the handle even if one of the operations above fails
        db.close()