vj4/handler/training.py
import asyncio
from json import decoder
from bson import objectid
from vj4 import app
from vj4 import error
from vj4 import constant
from vj4.model import builtin
from vj4.model import document
from vj4.model import domain
from vj4.model import user
from vj4.model.adaptor import problem
from vj4.model.adaptor import training
from vj4.handler import base
from vj4.util import json
from vj4.util import pagination
from vj4.util import misc
def _parse_dag_json(dag):
try:
dag = json.decode(dag)
except decoder.JSONDecodeError:
raise error.ValidationError('dag') from None
if not isinstance(dag, list):
raise error.ValidationError('dag')
new_dag = []
try:
for node in dag:
if any(k not in node for k in ['_id', 'require_nids', 'pids']):
raise error.ValidationError('dag')
new_node = {'_id': int(node['_id']),
'title': str(node.get('title', '')),
'require_nids': misc.dedupe(map(int, node['require_nids'])),
'pids': misc.dedupe(map(document.convert_doc_id, node['pids']))}
new_dag.append(new_node)
except ValueError:
raise error.ValidationError('dag') from None
return new_dag
class TrainingMixin(object):
def get_pids(self, tdoc):
pids = set()
for node in tdoc['dag']:
pids.update(node['pids'])
return list(pids)
def is_done(self, node, done_nids, done_pids):
return set(done_nids) >= set(node['require_nids']) \
and set(done_pids) >= set(node['pids'])
def is_progress(self, node, done_nids, done_pids, prog_pids):
return set(done_nids) >= set(node['require_nids']) \
and not set(done_pids) >= set(node['pids']) \
and ((set(done_pids) | set(prog_pids)) & set(node['pids']))
def is_open(self, node, done_nids, done_pids, prog_pids):
return set(done_nids) >= set(node['require_nids']) \
and not set(done_pids) >= set(node['pids']) \
and not ((set(done_pids) | set(prog_pids)) & set(node['pids']))
def is_invalid(self, node, done_nids):
return not set(done_nids) >= set(node['require_nids'])
@app.route('/training', 'training_main')
class TrainingMainHandler(base.Handler, TrainingMixin):
TRAININGS_PER_PAGE = 20
@base.require_perm(builtin.PERM_VIEW_TRAINING)
@base.get_argument
@base.sanitize
async def get(self, *, sort: str='', page: int=1):
if sort:
qs = 'sort={0}'.format(sort)
else:
qs = ''
tdocs, tpcount, _ = await pagination.paginate(training.get_multi(self.domain_id) \
.sort('doc_id', 1),
page, self.TRAININGS_PER_PAGE)
tids = set(tdoc['doc_id'] for tdoc in tdocs)
tsdict = dict()
tdict = dict()
if self.has_priv(builtin.PRIV_USER_PROFILE):
enrolled_tids = set()
async for tsdoc in training.get_multi_status(domain_id=self.domain_id,
uid=self.user['_id'],
**{'$or': [{'doc_id': {'$in': list(tids)}},
{'enroll': 1}]}):
tsdict[tsdoc['doc_id']] = tsdoc
enrolled_tids.add(tsdoc['doc_id'])
enrolled_tids -= tids
if enrolled_tids:
tdict = await training.get_dict(self.domain_id, enrolled_tids)
for tdoc in tdocs:
tdict[tdoc['doc_id']] = tdoc
self.render('training_main.html', tdocs=tdocs, page=page, tpcount=tpcount, qs=qs,
tsdict=tsdict, tdict=tdict)
@app.route('/training/{tid:\w{24}}', 'training_detail')
class TrainingDetailHandler(base.OperationHandler, TrainingMixin):
@base.require_perm(builtin.PERM_VIEW_TRAINING)
@base.route_argument
@base.sanitize
async def get(self, tid: objectid.ObjectId):
tdoc = await training.get(self.domain_id, tid)
pids = self.get_pids(tdoc)
# TODO(twd2): check status, eg. test, hidden problem, ...
if not self.has_perm(builtin.PERM_VIEW_PROBLEM_HIDDEN):
f = {'hidden': False}
else:
f = {}
owner_udoc, owner_dudoc, pdict = await asyncio.gather(
user.get_by_uid(tdoc['owner_uid']),
domain.get_user(domain_id=self.domain_id, uid=tdoc['owner_uid']),
problem.get_dict(self.domain_id, pids, **f))
psdict = await problem.get_dict_status(self.domain_id,
self.user['_id'], pdict.keys())
done_pids = set()
prog_pids = set()
for pid, psdoc in psdict.items():
if 'status' in psdoc:
if psdoc['status'] == constant.record.STATUS_ACCEPTED:
done_pids.add(pid)
else:
prog_pids.add(pid)
nsdict = {}
ndict = {}
done_nids = set()
for node in tdoc['dag']:
ndict[node['_id']] = node
total_count = len(node['pids'])
done_count = len(set(node['pids']) & set(done_pids))
nsdoc = {'progress': int(100 * done_count / total_count) if total_count else 100,
'is_done': self.is_done(node, done_nids, done_pids),
'is_progress': self.is_progress(node, done_nids, done_pids, prog_pids),
'is_open': self.is_open(node, done_nids, done_pids, prog_pids),
'is_invalid': self.is_invalid(node, done_nids)}
if nsdoc['is_done']:
done_nids.add(node['_id'])
nsdict[node['_id']] = nsdoc
tsdoc = await training.set_status(self.domain_id, tdoc['doc_id'], self.user['_id'],
done_nids=list(done_nids), done_pids=list(done_pids),
done=len(done_nids) == len(tdoc['dag']))
path_components = self.build_path(
(self.translate('training_main'), self.reverse_url('training_main')),
(tdoc['title'], None))
self.render('training_detail.html', tdoc=tdoc, tsdoc=tsdoc, pids=pids,
pdict=pdict, psdict=psdict, ndict=ndict, nsdict=nsdict,
owner_udoc=owner_udoc, owner_dudoc=owner_dudoc,
page_title=tdoc['title'], path_components=path_components)
@base.require_priv(builtin.PRIV_USER_PROFILE)
@base.require_perm(builtin.PERM_VIEW_TRAINING)
@base.route_argument
@base.require_csrf_token
@base.sanitize
async def post_enroll(self, *, tid: objectid.ObjectId):
tdoc = await training.get(self.domain_id, tid)
await training.enroll(self.domain_id, tdoc['doc_id'], self.user['_id'])
self.json_or_redirect(self.url)
@app.route('/training/create', 'training_create')
class TrainingCreateHandler(base.Handler, TrainingMixin):
@base.require_priv(builtin.PRIV_USER_PROFILE)
@base.require_perm(builtin.PERM_CREATE_TRAINING)
async def get(self):
self.render('training_edit.html')
@base.require_priv(builtin.PRIV_USER_PROFILE)
@base.require_perm(builtin.PERM_CREATE_TRAINING)
@base.post_argument
@base.require_csrf_token
@base.sanitize
async def post(self, *, title: str, content: str, dag: str, desc: str):
dag = _parse_dag_json(dag)
pids = self.get_pids({'dag': dag})
if not pids:
# empty plan
raise error.ValidationError('dag')
pdocs = await problem.get_multi(domain_id=self.domain_id, doc_id={'$in': pids},
fields={'doc_id': 1, 'hidden': 1}) \
.sort('doc_id', 1) \
.to_list()
exist_pids = [pdoc['doc_id'] for pdoc in pdocs]
if len(pids) != len(exist_pids):
for pid in pids:
if pid not in exist_pids:
raise error.ProblemNotFoundError(self.domain_id, pid)
for pdoc in pdocs:
if pdoc.get('hidden', False):
self.check_perm(builtin.PERM_VIEW_PROBLEM_HIDDEN)
tid = await training.add(self.domain_id, title, content, self.user['_id'],
dag=dag, desc=desc)
self.json_or_redirect(self.reverse_url('training_detail', tid=tid))
@app.route('/training/{tid}/edit', 'training_edit')
class TrainingEditHandler(base.Handler, TrainingMixin):
@base.require_priv(builtin.PRIV_USER_PROFILE)
@base.route_argument
@base.sanitize
async def get(self, *, tid: objectid.ObjectId):
tdoc = await training.get(self.domain_id, tid)
if not self.own(tdoc, builtin.PERM_EDIT_TRAINING_SELF):
self.check_perm(builtin.PERM_EDIT_TRAINING)
dag = json.encode_pretty(tdoc['dag'])
path_components = self.build_path(
(self.translate('training_main'), self.reverse_url('training_main')),
(tdoc['title'], self.reverse_url('training_detail', tid=tdoc['doc_id'])),
(self.translate('training_edit'), None))
self.render('training_edit.html', tdoc=tdoc, dag=dag,
page_title=tdoc['title'], path_components=path_components)
@base.require_priv(builtin.PRIV_USER_PROFILE)
@base.route_argument
@base.post_argument
@base.require_csrf_token
@base.sanitize
async def post(self, *, tid: objectid.ObjectId, title: str, content: str, dag: str, desc: str):
tdoc = await training.get(self.domain_id, tid)
if not self.own(tdoc, builtin.PERM_EDIT_TRAINING_SELF):
self.check_perm(builtin.PERM_EDIT_TRAINING)
dag = _parse_dag_json(dag)
pids = self.get_pids({'dag': dag})
if not pids:
# empty plan
raise error.ValidationError('dag')
pdocs = await problem.get_multi(domain_id=self.domain_id, doc_id={'$in': pids},
fields={'doc_id': 1, 'hidden': 1}) \
.sort('doc_id', 1) \
.to_list()
exist_pids = [pdoc['doc_id'] for pdoc in pdocs]
if len(pids) != len(exist_pids):
for pid in pids:
if pid not in exist_pids:
raise error.ProblemNotFoundError(self.domain_id, pid)
for pdoc in pdocs:
if pdoc.get('hidden', False):
self.check_perm(builtin.PERM_VIEW_PROBLEM_HIDDEN)
await training.edit(self.domain_id, tdoc['doc_id'], title=title, content=content,
dag=dag, desc=desc)
self.json_or_redirect(self.reverse_url('training_detail', tid=tid))