# dragonfire/deepconv/corpus/opensubsdata.py
import xml.etree.ElementTree as ET
import datetime
import os
import sys
import json
import re
import pprint
from gzip import GzipFile
from tqdm import tqdm
"""
Load the OpenSubtitles dialog corpus.
"""
class OpensubsData:
    """
    Loader that turns OpenSubtitles XML files into two-line conversations.

    Every file found under the given directory whose name ends in 'gz' is
    parsed, and consecutive subtitle sentences that follow each other
    within one second are paired. The result is available through
    getConversations() as a list of dicts of the form
    {"lines": [{"text": first}, {"text": second}]}.
    """

    def __init__(self, dirName):
        """
        Args:
            dirName (string): directory where to load the corpus
        """
        # Hack this to filter on subset of Opensubtitles
        # dirName = "%s/en/Action" % dirName
        print("Loading OpenSubtitles conversations in %s." % dirName)
        self.conversations = []
        # Matches XML/HTML comments and tags so getLine() can strip them.
        self.tag_re = re.compile(r'(<!--.*?-->|<[^>]*>)')
        self.conversations = self.loadConversations(dirName)

    def loadConversations(self, dirName):
        """
        Walk dirName and extract conversations from every '*gz' file.

        Files that raise ValueError while being processed (for example a
        timestamp that does not match the expected format) are skipped
        with a message. Any other exception is reported and re-raised.

        Args:
            dirName (str): folder to load
        Return:
            list(dict): the extracted pairs, each shaped as
                {"lines": [{"text": str}, {"text": str}]}
        """
        conversations = []
        dirList = self.filesInDir(dirName)
        for filepath in tqdm(dirList, "OpenSubtitles data files"):
            if not filepath.endswith('gz'):
                continue
            try:
                doc = self.getXML(filepath)
                conversations.extend(self.genList(doc))
            except ValueError:
                tqdm.write("Skipping file %s with errors." % filepath)
            except BaseException:
                # Same reach as a bare except: report, then propagate
                # the original exception unchanged.
                print("Unexpected error:", sys.exc_info()[0])
                raise
        return conversations

    def getConversations(self):
        """Return the list of conversations loaded by the constructor."""
        return self.conversations

    def genList(self, tree):
        """
        Build conversation pairs from one parsed subtitle document.

        The document root is expected to contain sentence elements whose
        children are either <time id=... value=...> markers or elements
        carrying text. A time marker whose id ends in 'S' opens a
        sentence; any other time marker closes it.

        Args:
            tree (ElementTree): parsed subtitle document
        Return:
            list(dict): one {"lines": [line, line]} entry for every two
                consecutive non-empty sentences where the second starts
                at most one second after the first ends
        Raises:
            ValueError: a time value does not match HH:MM:SS
            KeyError: a time element lacks an 'id' or 'value' attribute
        """
        root = tree.getroot()
        timeFormat = '%H:%M:%S'
        maxDelta = datetime.timedelta(seconds=1)
        startTime = datetime.datetime.min
        words = []
        sentList = []  # (text, start time, end time) per sentence

        for child in root:
            for elem in child:
                if elem.tag != 'time':
                    # Elements without text (elem.text is None) carry
                    # nothing to append.
                    if elem.text is not None:
                        words.append(elem.text)
                    continue
                isStart = elem.attrib['id'][-1] == 'S'
                # The last four characters of the value are dropped
                # (presumably a ",mmm" millisecond suffix).
                stamp = datetime.datetime.strptime(
                    elem.attrib['value'][:-4], timeFormat)
                if isStart:
                    startTime = stamp
                else:
                    sentList.append((" ".join(words).strip(), startTime, stamp))
                    words = []

        conversations = []
        for cur, nxt in zip(sentList, sentList[1:]):
            if nxt[1] - cur[2] > maxDelta:
                continue
            # Test the sentence text: the tuples themselves are always
            # truthy, so testing them would never reject empty sentences.
            if not (cur[0] and nxt[0]):
                continue
            pair = {"lines": [self.getLine(cur[0]), self.getLine(nxt[0])]}
            if self.filter(pair):
                conversations.append(pair)
        return conversations

    def getLine(self, sentence):
        """
        Wrap a raw sentence into a line dict.

        Markup matched by self.tag_re is removed, backslash-escaped
        apostrophes are unescaped, and the result is stripped and
        lower-cased.

        Args:
            sentence (str): raw sentence text
        Return:
            dict: {"text": cleaned sentence}
        """
        text = self.tag_re.sub('', sentence)
        text = text.replace("\\'", "'").strip().lower()
        return {"text": text}

    def filter(self, lines):
        """
        Decide whether a conversation pair is kept.

        Args:
            lines (dict): candidate pair, {"lines": [line, line]}
        Return:
            bool: True to keep the pair (currently always True)
        """
        # Use the following to customize filtering of QA pairs
        #
        # startwords = ("what", "how", "when", "why", "where", "do", "did", "is", "are", "can", "could", "would", "will")
        # question = lines["lines"][0]["text"]
        # if not question.endswith('?'):
        #     return False
        # if not question.split(' ')[0] in startwords:
        #     return False
        #
        return True

    def getXML(self, filepath):
        """
        Parse an XML file, transparently decompressing '.gz' files.

        Args:
            filepath (str): path of the file to parse
        Return:
            ElementTree: the parsed document
        """
        fext = os.path.splitext(filepath)[1]
        if fext == '.gz':
            # ET.parse reads the whole stream before returning, so the
            # handle can be closed as soon as parsing is done.
            with GzipFile(filename=filepath) as gz:
                return ET.parse(gz)
        return ET.parse(filepath)

    def filesInDir(self, dirname):
        """
        List every file below dirname, recursively.

        Args:
            dirname (str): root directory to walk
        Return:
            list(str): file paths, each joined with its directory path
        """
        return [
            os.path.join(dirpath, filename)
            for dirpath, _dirs, files in os.walk(dirname)
            for filename in files
        ]