# scrapi/processing/cassandra.py
from __future__ import absolute_import
import six
import json
import logging
from uuid import uuid4
from datetime import datetime
from dateutil.parser import parse
from cassandra.cqlengine import connection
from cassandra.cqlengine import management
from cassandra.cluster import NoHostAvailable
from cassandra.cqlengine import columns, models
from scrapi import events
from scrapi import settings
from scrapi.util import try_n_times
from scrapi.util import copy_to_unicode
from scrapi.linter import RawDocument, NormalizedDocument
from scrapi.processing import DocumentTuple
from scrapi.processing.base import BaseHarvesterResponse, BaseProcessor, BaseDatabaseManager
logger = logging.getLogger(__name__)
logging.getLogger('cqlengine.cql').setLevel(logging.WARN)
class DatabaseManager(BaseDatabaseManager):
    '''
    Owns the cqlengine connection for one Cassandra keyspace and keeps the
    tables of every registered model in sync with it.

    Models are registered at import time with the ``registered_model``
    class decorator; each instance then binds those models to its own
    keyspace.
    '''
    # Models collected by the ``registered_model`` decorator. Shared by the
    # whole class; every instance takes its own copy in ``__init__``.
    _models = set()

    def __init__(self, uri=None, keyspace=None):
        '''
        :param uri: Cassandra contact point(s); falls back to
            ``settings.CASSANDRA_URI`` when falsy.
        :param keyspace: keyspace name; falls back to
            ``settings.CASSANDRA_KEYSPACE`` when falsy.
        '''
        self._setup = False
        self.uri = uri or settings.CASSANDRA_URI
        self.keyspace = keyspace or settings.CASSANDRA_KEYSPACE
        # Snapshot the class-level registrations before rebinding, so that
        # register_model() never adds to the very set being iterated.
        registered = list(self._models)
        self._models = set()
        for model in registered:
            self.register_model(model)

    def setup(self, force=False, sync=True):
        '''
        Connect to Cassandra and, when ``sync`` is true, create the keyspace
        and sync the table of every registered model.

        :param force: run again even if a previous call already succeeded.
        :param sync: create the keyspace and sync tables after connecting.
        :returns: ``True`` on success (or when already set up), ``False``
            when no Cassandra host could be reached.
        '''
        if self._setup and not force:
            return True

        try:
            connection.setup(self.uri, self.keyspace)
            if sync:
                management.create_keyspace(self.keyspace, replication_factor=1, strategy_class='SimpleStrategy')
                for model in self._models:
                    model.__keyspace__ = self.keyspace
                    management.sync_table(model)
        except NoHostAvailable:
            logger.error('Could not connect to Cassandra, expect errors.')
            return False

        # Note: return values are for test skipping
        self._setup = True
        return True

    def tear_down(self):
        '''
        Shut down the module-level cqlengine cluster and session, if any,
        and mark this manager as not set up.

        Only logs a warning (and still proceeds) when called before a
        successful ``setup``.
        '''
        if not self._setup:
            logger.warning('Attempting to tear down a database that was never setup')

        if connection.cluster is not None:
            connection.cluster.shutdown()
        if connection.session is not None:
            connection.session.shutdown()
        self._setup = False

    def clear(self, force=False):
        '''
        Drop this manager's keyspace, then reconnect and recreate it.

        :param force: must be truthy; guards against accidental calls.
        :returns: the result of ``setup()``.
        :raises AssertionError: if ``force`` is falsy, or if the keyspace is
            the one configured in ``settings.CASSANDRA_KEYSPACE``.
        '''
        # Raised explicitly rather than via ``assert`` so the safety checks
        # on this destructive operation survive ``python -O``.
        if not force:
            raise AssertionError('clear_keyspace must be called with force')
        if self.keyspace == settings.CASSANDRA_KEYSPACE:
            raise AssertionError('Cannot erase the keyspace in settings')

        management.delete_keyspace(self.keyspace)
        self.tear_down()
        return self.setup()

    def register_model(self, model):
        '''
        Bind ``model`` to this manager's keyspace and remember it; if the
        manager is already set up, sync the model's table immediately.

        :returns: ``model``, unchanged apart from ``__keyspace__``.
        '''
        self._models.add(model)
        model.__keyspace__ = self.keyspace
        if self._setup:
            management.sync_table(model)
        return model

    def celery_setup(self, *args, **kwargs):
        '''
        Tear down and re-establish the connection. Accepts and ignores any
        arguments.
        '''
        # NOTE(review): presumably wired to a celery worker signal so each
        # worker process gets its own connection - confirm against caller.
        self.tear_down()
        self.setup()

    @classmethod
    def registered_model(cls, model):
        '''
        Class decorator: record ``model`` in the class-level registry and
        return it unchanged.
        '''
        cls._models.add(model)
        return model
class CassandraProcessor(BaseProcessor):
    '''
    Cassandra processor for scrapi. Handles versioning and storing documents in Cassandra
    '''
    NAME = 'cassandra'
    # Lazily created by the ``manager`` property.
    _manager = None

    @property
    def manager(self):
        '''The DatabaseManager for this processor, created on first access.'''
        self._manager = self._manager or DatabaseManager()
        return self._manager

    @property
    def HarvesterResponseModel(self):
        '''The model class used to store harvester responses.'''
        return HarvesterResponse

    @events.logged(events.PROCESSING, 'normalized.cassandra')
    def process_normalized(self, raw_doc, normalized):
        '''
        Store the normalized form of a document under the ``source`` and
        ``docID`` of its raw document.

        Structured fields are JSON-encoded into text columns; the provider
        timestamp is parsed and stored without timezone info.

        :raises events.Skip: (from ``send_to_database``) when the stored
            document would not change.
        '''
        self.send_to_database(
            source=copy_to_unicode(raw_doc['source']),
            docID=copy_to_unicode(raw_doc['docID']),
            contributors=copy_to_unicode(json.dumps(normalized['contributors'])),
            description=copy_to_unicode(normalized.get('description')),
            uris=copy_to_unicode(json.dumps(normalized['uris'])),
            providerUpdatedDateTime=parse(normalized['providerUpdatedDateTime']).replace(tzinfo=None),
            freeToRead=copy_to_unicode(json.dumps(normalized.get('freeToRead', {}))),
            # NOTE(review): reads the 'language' key while the column is
            # named 'languages' - confirm against the normalized schema.
            languages=normalized.get('language'),
            licenses=copy_to_unicode(json.dumps(normalized.get('licenseRef', []))),
            publisher=copy_to_unicode(json.dumps(normalized.get('publisher', {}))),
            sponsorships=copy_to_unicode(json.dumps(normalized.get('sponsorship', []))),
            title=copy_to_unicode(normalized['title']),
            # The {} default belongs to .get(); handing it to json.dumps as a
            # second positional argument is a TypeError on Python 3.
            version=copy_to_unicode(json.dumps(normalized.get('version', {}))),
            otherProperties=copy_to_unicode(json.dumps(normalized.get('otherProperties', {}))),
            shareProperties=copy_to_unicode(json.dumps(normalized['shareProperties']))
        ).save()

    @events.logged(events.PROCESSING, 'raw.cassandra')
    def process_raw(self, raw_doc):
        '''
        Store the raw form of a document; the body is coerced to text and
        saved as UTF-8 bytes.

        :raises events.Skip: (from ``send_to_database``) when the stored
            document would not change.
        '''
        self.send_to_database(
            source=copy_to_unicode(raw_doc['source']),
            docID=copy_to_unicode(raw_doc['docID']),
            doc=six.text_type(raw_doc['doc']).encode('utf-8'),
            filetype=copy_to_unicode(raw_doc['filetype']),
            timestamps=raw_doc.get('timestamps', {})
        ).save()

    def send_to_database(self, docID, source, **kwargs):
        '''
        Create the document identified by (``source``, ``docID``), or update
        it while archiving its current state as a new version.

        :param kwargs: column values to write; an optional ``versions`` list
            is appended to the document's existing version keys.
        :returns: the created or updated document model.
        :raises events.Skip: when the document exists and ``self.different``
            reports no difference from the incoming values.
        '''
        documents = DocumentModel.objects(docID=docID, source=source)
        if documents:
            document = documents[0]
            if self.different(dict(document), dict(docID=docID, source=source, **kwargs)):
                # Create new version, get UUID of new version, update
                versions = document.versions + kwargs.pop('versions', [])
                # Snapshot the current state before it is overwritten.
                version = VersionModel(key=uuid4(), **dict(document))
                version.save()
                versions.append(version.key)
                return document.update(versions=versions, **kwargs)
            else:
                raise events.Skip("No changes detected for document with ID {0} and source {1}.".format(docID, source))
        else:
            # create document
            return DocumentModel.create(docID=docID, source=source, **kwargs)

    def documents(self, *sources):
        '''
        Lazily yield a ``DocumentTuple(raw, normalized)`` for every stored
        document, optionally restricted to the given ``sources``.

        Documents are fetched in pages of up to 100, each fetch being
        attempted through ``try_n_times(5, ...)``.
        '''
        q = DocumentModel.objects.timeout(500).allow_filtering().all().limit(100)
        querysets = (q.filter(source=source) for source in sources) if sources else [q]
        for query in querysets:
            page = try_n_times(5, list, query)
            while len(page) > 0:
                for doc in page:
                    # NOTE(review): every document is re-saved while being
                    # read; presumably a deliberate touch/migration - confirm.
                    doc.save()
                    yield DocumentTuple(self.to_raw(doc), self.to_normalized(doc))
                page = try_n_times(5, self.next_page, query, page)

    def next_page(self, query, page):
        '''
        Return the rows of ``query`` whose docID is greater than that of the
        last row in ``page``. ``page`` must be non-empty.
        '''
        return list(query.filter(docID__gt=page[-1].docID))

    def to_raw(self, doc):
        '''Build an unvalidated ``RawDocument`` from the raw columns of ``doc``.'''
        return RawDocument({
            'doc': doc.doc,
            'docID': doc.docID,
            'source': doc.source,
            'filetype': doc.filetype,
            'timestamps': doc.timestamps
        }, validate=False, clean=False)

    def to_normalized(self, doc):
        '''
        Build an unvalidated ``NormalizedDocument`` from the normalized
        columns of ``doc``.

        :returns: the document, or ``None`` when ``doc`` has no
            ``providerUpdatedDateTime``.
        '''
        # make the new dict actually contain real items
        normed = {}
        do_not_include = ['docID', 'doc', 'filetype', 'timestamps', 'source', 'versions', 'key']
        for key, value in dict(doc).items():
            if value and key not in do_not_include:
                # JSON-encoded columns are decoded; anything json.loads
                # rejects (lists, datetimes, non-JSON text) is kept as-is.
                # NOTE(review): plain-text columns that happen to be valid
                # JSON (e.g. a title of "2015") are decoded too - confirm
                # that this is acceptable.
                try:
                    normed[key] = json.loads(value)
                except (ValueError, TypeError):
                    normed[key] = value

        # NOTE(review): 'versions' is listed in do_not_include above, so this
        # branch cannot currently be taken.
        if normed.get('versions'):
            normed['versions'] = list(map(str, normed['versions']))

        # No datetime means the document wasn't normalized (probably wasn't on the approved list)
        # TODO - fix odd circular import that makes us import this here
        from scrapi.base.helpers import datetime_formatter
        if normed.get('providerUpdatedDateTime'):
            normed['providerUpdatedDateTime'] = datetime_formatter(normed['providerUpdatedDateTime'].isoformat())
        else:
            return None

        return NormalizedDocument(normed, validate=False, clean=False)

    def get(self, source, docID):
        '''
        Fetch one document.

        :returns: ``DocumentTuple(raw, normalized)``, or ``None`` when no
            document matches.
        '''
        documents = DocumentModel.objects(source=source, docID=docID)
        try:
            doc = documents[0]
        except IndexError:
            return None

        raw = self.to_raw(doc)
        normalized = self.to_normalized(doc)
        return DocumentTuple(raw, normalized)

    def delete(self, source, docID):
        '''Delete the document matching ``source`` and ``docID``.'''
        document = DocumentModel.objects(source=source, docID=docID)
        document.timeout(5).delete()

    def create(self, attributes):
        '''Create and save a document from a dict of column values.'''
        DocumentModel.create(**attributes).save()

    def get_versions(self, source, docID):
        '''
        Yield a ``DocumentTuple`` for each archived version of a document, in
        the order listed in its ``versions`` column, followed by the current
        document. Version keys with no matching row are skipped; nothing is
        yielded when the document itself does not exist.
        '''
        try:
            doc = DocumentModel.get(source=source, docID=docID)
        except DocumentModel.DoesNotExist:
            return

        for uuid in doc.versions:
            try:
                version = VersionModel.get(key=uuid)
            except VersionModel.DoesNotExist:
                continue
            yield DocumentTuple(self.to_raw(version), self.to_normalized(version))

        yield DocumentTuple(self.to_raw(doc), self.to_normalized(doc))
@DatabaseManager.registered_model
class DocumentModel(models.Model):
    '''
    Defines the schema for a metadata document in cassandra

    The schema contains denormalized raw document, denormalized
    normalized (so sorry for the terminology clash) document, and
    a list of version IDs that refer to previous versions of this
    metadata.

    Rows are partitioned by ``source`` and clustered by ``docID`` in
    ascending order.
    '''
    __table_name__ = 'documents_source_partitioned'

    # Raw
    source = columns.Text(primary_key=True, partition_key=True)
    docID = columns.Text(primary_key=True, index=True, clustering_order='ASC')

    doc = columns.Bytes()  # UTF-8 encoded body (see CassandraProcessor.process_raw)
    filetype = columns.Text()
    timestamps = columns.Map(columns.Text, columns.Text)

    # Normalized
    # Columns marked TODO are Text columns that CassandraProcessor.process_normalized
    # fills with JSON-encoded strings.
    uris = columns.Text()  # JSON-encoded as well
    title = columns.Text()
    contributors = columns.Text()  # TODO
    providerUpdatedDateTime = columns.DateTime()  # stored without tzinfo
    description = columns.Text()
    freeToRead = columns.Text()  # TODO
    languages = columns.List(columns.Text())
    licenses = columns.Text()  # TODO
    publisher = columns.Text()  # TODO
    subjects = columns.List(columns.Text())
    tags = columns.List(columns.Text())
    sponsorships = columns.Text()  # TODO
    version = columns.Text()  # TODO
    otherProperties = columns.Text()  # TODO
    shareProperties = columns.Text()  # TODO

    # Additional metadata
    # Keys of the VersionModel rows holding earlier states of this document.
    versions = columns.List(columns.UUID)
@DatabaseManager.registered_model
class DocumentModelOld(models.Model):
    '''
    Defines the schema for a metadata document in cassandra

    The schema contains denormalized raw document, denormalized
    normalized (so sorry for the terminology clash) document, and
    a list of version IDs that refer to previous versions of this
    metadata.

    Same columns as DocumentModel, but stored in the ``documents`` table
    with ``docID`` declared as the first primary key and ``source``
    clustered in descending order.
    '''
    # NOTE(review): presumably the pre-migration layout kept so the old table
    # stays in sync - confirm before removing.
    __table_name__ = 'documents'

    # Raw
    docID = columns.Text(primary_key=True)
    source = columns.Text(primary_key=True, clustering_order="DESC")

    doc = columns.Bytes()
    filetype = columns.Text()
    timestamps = columns.Map(columns.Text, columns.Text)

    # Normalized
    uris = columns.Text()
    title = columns.Text()
    contributors = columns.Text()  # TODO
    providerUpdatedDateTime = columns.DateTime()
    description = columns.Text()
    freeToRead = columns.Text()  # TODO
    languages = columns.List(columns.Text())
    licenses = columns.Text()  # TODO
    publisher = columns.Text()  # TODO
    subjects = columns.List(columns.Text())
    tags = columns.List(columns.Text())
    sponsorships = columns.Text()  # TODO
    version = columns.Text()  # TODO
    otherProperties = columns.Text()  # TODO
    shareProperties = columns.Text()  # TODO

    # Additional metadata
    versions = columns.List(columns.UUID)
@DatabaseManager.registered_model
class VersionModel(models.Model):
    '''
    Defines the schema for a version of a metadata document in Cassandra

    See the DocumentModel class. This schema is very similar, except it is
    keyed on a UUID that is generated by us, rather than its own metadata
    '''
    __table_name__ = 'versions'

    # Generated with uuid4() in CassandraProcessor.send_to_database.
    key = columns.UUID(primary_key=True, required=True)

    # Raw
    doc = columns.Bytes()
    docID = columns.Text()
    filetype = columns.Text()
    source = columns.Text()
    timestamps = columns.Map(columns.Text, columns.Text)

    # Normalized
    uris = columns.Text()
    title = columns.Text()
    contributors = columns.Text()  # TODO
    providerUpdatedDateTime = columns.DateTime()
    description = columns.Text()
    freeToRead = columns.Text()  # TODO
    languages = columns.List(columns.Text())
    licenses = columns.Text()  # TODO
    publisher = columns.Text()  # TODO
    subjects = columns.List(columns.Text())
    tags = columns.List(columns.Text())
    sponsorships = columns.Text()  # TODO
    version = columns.Text()  # TODO
    otherProperties = columns.Text()  # TODO
    shareProperties = columns.Text()  # TODO

    # Additional metadata
    # Copied from the document at the time this version was archived.
    versions = columns.List(columns.UUID)
@DatabaseManager.registered_model
class HarvesterResponse(models.Model, BaseHarvesterResponse):
    '''
    Defines the schema for a stored harvester HTTP response, keyed on the
    request method and URL.
    '''
    __table_name__ = 'responses'

    method = columns.Text(primary_key=True)
    url = columns.Text(primary_key=True, required=True)

    # Raw request data
    ok = columns.Boolean()
    content = columns.Bytes()
    encoding = columns.Text()
    headers_str = columns.Text()
    status_code = columns.Integer()
    # The default is the callable datetime.now, which yields a naive local time.
    time_made = columns.DateTime(default=datetime.now)