superset/daos/tag.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from operator import and_
from typing import Any, Optional
from flask import g
from sqlalchemy.exc import NoResultFound
from superset.commands.tag.exceptions import TagNotFoundError
from superset.commands.tag.utils import to_object_type
from superset.daos.base import BaseDAO
from superset.exceptions import MissingUserContextException
from superset.extensions import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.tags.models import (
get_tag,
ObjectType,
Tag,
TaggedObject,
TagType,
user_favorite_tag_table,
)
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
class TagDAO(BaseDAO[Tag]):
# base_filter = TagAccessFilter
@staticmethod
def create_custom_tagged_objects(
object_type: ObjectType, object_id: int, tag_names: list[str]
) -> None:
tagged_objects = []
# striping and de-dupping
clean_tag_names: set[str] = {tag.strip() for tag in tag_names}
for name in clean_tag_names:
type_ = TagType.custom
tag = TagDAO.get_by_name(name, type_)
tagged_objects.append(
TaggedObject(object_id=object_id, object_type=object_type, tag=tag)
)
# Check if the association already exists
existing_tagged_object = (
db.session.query(TaggedObject)
.filter_by(object_id=object_id, object_type=object_type, tag=tag)
.first()
)
if not existing_tagged_object:
tagged_objects.append(
TaggedObject(object_id=object_id, object_type=object_type, tag=tag)
)
db.session.add_all(tagged_objects)
@staticmethod
def delete_tagged_object(
object_type: ObjectType, object_id: int, tag_name: str
) -> None:
"""
deletes a tagged object by the object_id, object_type, and tag_name
"""
tag = TagDAO.find_by_name(tag_name.strip())
if not tag:
raise NoResultFound(message=f"Tag with name {tag_name} does not exist.")
tagged_object = db.session.query(TaggedObject).filter(
TaggedObject.tag_id == tag.id,
TaggedObject.object_type == object_type,
TaggedObject.object_id == object_id,
)
if not tagged_object:
raise NoResultFound(
message=f'Tagged object with object_id: {object_id} \
object_type: {object_type} \
and tag name: "{tag_name}" could not be found'
)
db.session.delete(tagged_object.one())
@staticmethod
def delete_tags(tag_names: list[str]) -> None:
"""
deletes tags from a list of tag names
"""
tags_to_delete = []
for name in tag_names:
tag_name = name.strip()
if not TagDAO.find_by_name(tag_name):
raise NoResultFound(message=f"Tag with name {tag_name} does not exist.")
tags_to_delete.append(tag_name)
tag_objects = db.session.query(Tag).filter(Tag.name.in_(tags_to_delete))
for tag in tag_objects:
db.session.delete(tag)
@staticmethod
def get_by_name(name: str, type_: TagType = TagType.custom) -> Tag:
"""
returns a tag if one exists by that name, none otherwise.
important!: Creates a tag by that name if the tag is not found.
"""
tag = (
db.session.query(Tag)
.filter(Tag.name == name, Tag.type == type_.name)
.first()
)
if not tag:
tag = get_tag(name, db.session, type_)
return tag
@staticmethod
def find_by_name(name: str) -> Tag:
"""
returns a tag if one exists by that name, none otherwise.
Does NOT create a tag if the tag is not found.
"""
return db.session.query(Tag).filter(Tag.name == name).first()
@staticmethod
def find_tagged_object(
object_type: ObjectType, object_id: int, tag_id: int
) -> TaggedObject:
"""
returns a tagged object if one exists by that name, none otherwise.
"""
return (
db.session.query(TaggedObject)
.filter(
TaggedObject.tag_id == tag_id,
TaggedObject.object_id == object_id,
TaggedObject.object_type == object_type,
)
.first()
)
@staticmethod
def get_tagged_objects_by_tag_id(
tag_ids: Optional[list[int]], obj_types: Optional[list[str]] = None
) -> list[dict[str, Any]]:
tags = db.session.query(Tag).filter(Tag.id.in_(tag_ids)).all()
tag_names = [tag.name for tag in tags]
return TagDAO.get_tagged_objects_for_tags(tag_names, obj_types)
@staticmethod
def get_tagged_objects_for_tags(
tags: Optional[list[str]] = None, obj_types: Optional[list[str]] = None
) -> list[dict[str, Any]]:
"""
returns a list of tagged objects filtered by tag names and object types
if no filters applied returns all tagged objects
"""
results: list[dict[str, Any]] = []
# dashboards
if (not obj_types) or ("dashboard" in obj_types):
dashboards = (
db.session.query(Dashboard)
.join(
TaggedObject,
and_(
TaggedObject.object_id == Dashboard.id,
TaggedObject.object_type == ObjectType.dashboard,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.filter(not tags or Tag.name.in_(tags))
)
results.extend(
{
"id": obj.id,
"type": ObjectType.dashboard.name,
"name": obj.dashboard_title,
"url": obj.url,
"changed_on": obj.changed_on,
"created_by": obj.created_by_fk,
"creator": obj.creator(),
"tags": obj.tags,
"owners": obj.owners,
}
for obj in dashboards
)
# charts
if (not obj_types) or ("chart" in obj_types):
charts = (
db.session.query(Slice)
.join(
TaggedObject,
and_(
TaggedObject.object_id == Slice.id,
TaggedObject.object_type == ObjectType.chart,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.filter(not tags or Tag.name.in_(tags))
)
results.extend(
{
"id": obj.id,
"type": ObjectType.chart.name,
"name": obj.slice_name,
"url": obj.url,
"changed_on": obj.changed_on,
"created_by": obj.created_by_fk,
"creator": obj.creator(),
"tags": obj.tags,
"owners": obj.owners,
}
for obj in charts
)
# saved queries
if (not obj_types) or ("query" in obj_types):
saved_queries = (
db.session.query(SavedQuery)
.join(
TaggedObject,
and_(
TaggedObject.object_id == SavedQuery.id,
TaggedObject.object_type == ObjectType.query,
),
)
.join(Tag, TaggedObject.tag_id == Tag.id)
.filter(not tags or Tag.name.in_(tags))
)
results.extend(
{
"id": obj.id,
"type": ObjectType.query.name,
"name": obj.label,
"url": obj.url(),
"changed_on": obj.changed_on,
"created_by": obj.created_by_fk,
"creator": obj.creator(),
"tags": obj.tags,
"owners": [obj.creator()],
}
for obj in saved_queries
)
return results
@staticmethod
def favorite_tag_by_id_for_current_user( # pylint: disable=invalid-name
tag_id: int,
) -> None:
"""
Marks a specific tag as a favorite for the current user.
:param tag_id: The id of the tag that is to be marked as favorite
"""
tag = TagDAO.find_by_id(tag_id)
user = g.user
if not user:
raise MissingUserContextException(message="User doesn't exist")
if not tag:
raise TagNotFoundError()
tag.users_favorited.append(user)
@staticmethod
def remove_user_favorite_tag(tag_id: int) -> None:
"""
Removes a tag from the current user's favorite tags.
:param tag_id: The id of the tag that is to be removed from the favorite tags
"""
tag = TagDAO.find_by_id(tag_id)
user = g.user
if not user:
raise MissingUserContextException(message="User doesn't exist")
if not tag:
raise TagNotFoundError()
tag.users_favorited.remove(user)
@staticmethod
def favorited_ids(tags: list[Tag]) -> list[int]:
"""
Returns the IDs of tags that the current user has favorited.
This function takes in a list of Tag objects, extracts their IDs, and checks
which of these IDs exist in the user_favorite_tag_table for the current user.
The function returns a list of these favorited tag IDs.
Args:
tags (list[Tag]): A list of Tag objects.
Returns:
list[Any]: A list of IDs corresponding to the tags that are favorited by
the current user.
Example:
favorited_ids([tag1, tag2, tag3])
Output: [tag_id1, tag_id3] # if the current user has favorited tag1 and tag3
"""
ids = [tag.id for tag in tags]
return [
star.tag_id
for star in db.session.query(user_favorite_tag_table.c.tag_id)
.filter(
user_favorite_tag_table.c.tag_id.in_(ids),
user_favorite_tag_table.c.user_id == get_user_id(),
)
.all()
]
@staticmethod
def create_tag_relationship(
objects_to_tag: list[tuple[ObjectType, int]],
tag: Tag,
bulk_create: bool = False,
) -> None:
"""
Creates a tag relationship between the given objects and the specified tag.
This function iterates over a list of objects, each specified by a type
and an id, and creates a TaggedObject for each one, associating it with
the provided tag. All created TaggedObjects are collected in a list.
Args:
objects_to_tag (List[Tuple[ObjectType, int]]): A list of tuples, each
containing an ObjectType and an id, representing the objects to be tagged.
tag (Tag): The tag to be associated with the specified objects.
Returns:
None.
"""
tagged_objects = []
if not tag:
raise TagNotFoundError()
current_tagged_objects = {
(obj.object_type, obj.object_id) for obj in tag.objects
}
updated_tagged_objects = {
(to_object_type(obj[0]), obj[1]) for obj in objects_to_tag
}
tagged_objects_to_delete = (
current_tagged_objects
if not objects_to_tag
else current_tagged_objects - updated_tagged_objects
)
for object_type, object_id in updated_tagged_objects:
# create rows for new objects, and skip tags that already exist
if (object_type, object_id) not in current_tagged_objects:
tagged_objects.append(
TaggedObject(object_id=object_id, object_type=object_type, tag=tag)
)
if not bulk_create:
# delete relationships that aren't retained from single tag create
for object_type, object_id in tagged_objects_to_delete:
# delete objects that were removed
TagDAO.delete_tagged_object(
object_type, # type: ignore
object_id,
tag.name,
)
db.session.add_all(tagged_objects)