from .pandas_backend import pd
import numpy as np
from typing import List, Dict, Callable
"""
RecNN is designed to work with your data flow.

Set kwargs at the beginning of your prepare_dataset function.
Treat the kwargs you set as immutable: build_data_pipeline discards any kwargs a step returns.

args_mut are mutable arguments; you can access the following:
    base: data.EnvBase, df: DataFrame, users: List[int],
    user_dict: Dict[int, Dict[str, np.ndarray]]

Access args_mut and modify them in functions defined by you.
It is best to chain functions with build_data_pipeline.

recnn.data.prepare_dataset is the function used by default in Env.__init__.
But sometimes you want something extra, so truncate_dataset is predefined as well.
This function truncates the number of items to the specified one.
In the reinforce example I modify it to look like::

    def prepare_dataset(args_mut, kwargs):
        kwargs.set('reduce_items_to', num_items) # set kwargs for your functions here!
        pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]
        recnn.data.build_data_pipeline(pipeline, kwargs, args_mut)

    # embeddings:
    env = recnn.data.env.FrameEnv('..',
                                  '...', frame_size, batch_size,
                                  embed_batch=embed_batch, prepare_dataset=prepare_dataset)
"""
def try_progress_apply(dataframe, function):
    """Apply ``function`` over ``dataframe``, with a progress bar when possible.

    Uses ``dataframe.progress_apply`` if the object has that attribute
    (presumably installed by tqdm's pandas integration, ``tqdm.pandas()``),
    otherwise falls back to plain ``dataframe.apply``.

    The attribute is looked up *before* calling anything, so an
    AttributeError raised by ``function`` itself propagates unchanged instead
    of silently triggering a second full pass through ``.apply`` (which would
    repeat any side effects of ``function``).

    :param dataframe: any object with an ``apply`` method (Series, DataFrame,
        GroupBy, ...).
    :param function: callable passed through to ``apply``/``progress_apply``.
    :returns: whatever the chosen ``apply`` variant returns.
    """
    progress_apply = getattr(dataframe, "progress_apply", None)
    if progress_apply is None:
        return dataframe.apply(function)
    return progress_apply(function)
# Plain args shared by every pipeline step. build_data_pipeline ignores any
# kwargs a step returns, so steps should not rely on mutating them.
class DataFuncKwargs:
    """Named plain arguments for the data-preparation pipeline functions.

    Values live in the ``kwargs`` dict, are read with :meth:`get` and
    added/overwritten with :meth:`set`.
    """

    # Usage hint appended to the "missing kwarg" error; '{}' is replaced by
    # the name of the kwarg that was requested.
    _MISSING_EXAMPLE = """
    # example on how to use kwargs:
    def prepare_dataset(args_mut, kwargs):
        kwargs.set('{}', your_value) # set kwargs for your functions here!
        pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]
        recnn.data.build_data_pipeline(pipeline, kwargs, args_mut)
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def keys(self):
        """Return the names of all stored kwargs."""
        return self.kwargs.keys()

    def get(self, name: str):
        """Return the kwarg called ``name``.

        :raises AttributeError: if no kwarg with that name has been set; the
            message contains an example showing how to set it.
        """
        try:
            return self.kwargs[name]
        except KeyError:
            # Previously the example was formatted with itself instead of the
            # missing name, so the hint never showed which kwarg to set.
            raise AttributeError(
                "No kwarg with name {} found!\n{}".format(
                    name, self._MISSING_EXAMPLE.format(name)
                )
            ) from None

    def set(self, name: str, value):
        """Store ``value`` under ``name``, overwriting any previous value."""
        self.kwargs[name] = value
# Returned from every pipeline step; these arguments are meant to be mutated.
class DataFuncArgsMut:
    """Mutable state threaded through (and returned by) each pipeline step.

    :param df: the ratings dataframe.
    :param base: environment base object; the pipeline functions read and
        replace its ``key_to_id``, ``id_to_key`` and ``embeddings``.
    :param users: user ids.
    :param user_dict: maps a user id to its ``{"items": ..., "ratings": ...}``
        arrays (filled in by ``prepare_dataset``).
    """

    def __init__(
        self, df, base, users: List[int], user_dict: Dict[int, Dict[str, np.ndarray]]
    ):
        self.base, self.users, self.user_dict, self.df = base, users, user_dict, df
def prepare_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):
    """
    Basic prepare dataset function. Automatically makes index linear, in ml20 movie indices look like:
    [1, 34, 123, 2000], recnn makes it look like [0,1,2,3] for you.

    Steps:
      * rescale ``args_mut.df["rating"]`` from [0, 5] to [-5, 5] (in place);
      * remap ``movieId`` through ``args_mut.base.key_to_id`` (in place);
      * keep users with strictly more than ``frame_size`` ratings, most
        active first, in ``args_mut.users``;
      * store each user's items and ratings, ordered by timestamp, in
        ``args_mut.user_dict[user_id] = {"items": ..., "ratings": ...}``.

    :param args_mut: mutable pipeline state; ``df`` is modified in place.
    :param kwargs: must provide ``frame_size``.
    :returns: ``(args_mut, kwargs)``
    """
    # get args
    frame_size = kwargs.get("frame_size")
    key_to_id = args_mut.base.key_to_id
    df = args_mut.df

    # rating range mapped from [0, 5] to [-5, 5]. Vectorised: the same
    # arithmetic as the old per-row lambda, without a Python call per rating.
    df["rating"] = 2 * (df["rating"] - 2.5)
    # id's tend to be inconsistent and sparse so they are remapped here.
    # NOTE(review): movieIds missing from key_to_id become None here; confirm
    # every movie in df has an embedding.
    df["movieId"] = try_progress_apply(df["movieId"], key_to_id.get)

    # users with more than frame_size interactions, most active first
    users = df[["userId", "movieId"]].groupby(["userId"]).size()
    users = users[users > frame_size].sort_values(ascending=False).index

    if pd.get_type() == "modin":
        df = df._to_pandas()  # pandas groupby is sync and doesnt affect performance
    ratings = (
        df.sort_values(by="timestamp")
        .set_index("userId")
        .drop("timestamp", axis=1)
        .groupby("userId")
    )

    # Groupby user: every group is indexed by its userId, so index[0] is the user
    user_dict = {}

    def app(x):
        userid = x.index[0]
        user_dict[userid] = {
            "items": x["movieId"].values,
            "ratings": x["rating"].values,
        }

    try_progress_apply(ratings, app)

    args_mut.user_dict = user_dict
    args_mut.users = users

    return args_mut, kwargs
def truncate_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs):
    """
    Truncate #items to reduce_items_to provided in kwargs

    Keeps the ``reduce_items_to`` most frequently rated movies, drops all
    ratings of the other movies from ``args_mut.df`` and rebuilds
    ``base.embeddings``, ``base.key_to_id`` and ``base.id_to_key`` so the
    surviving items get contiguous ids ``0 .. n-1`` in their original id order.

    :param args_mut: mutable pipeline state; ``df`` is replaced and ``base``
        is updated.
    :param kwargs: must provide ``reduce_items_to``.
    :returns: ``(args_mut, kwargs)``
    """
    # here are adjusted n items to keep
    num_items = kwargs.get("reduce_items_to")
    df = args_mut.df
    base = args_mut.base

    # Movies ordered from least to most rated; the last num_items are kept.
    # Split positionally on the index: the old counts[:-n] / counts[-n:]
    # slices kept *everything* for n == 0 and relied on ambiguous Series
    # integer slicing.
    counts = df["movieId"].value_counts().sort_values()
    split = max(len(counts) - num_items, 0)
    to_remove = counts.index[:split]
    to_keep = set(counts.index[split:])

    # NOTE(review): dropping by label assumes df has a unique index.
    args_mut.df = df.drop(df[df["movieId"].isin(to_remove)].index)

    # Boolean mask over the embedding rows (a float mask is not a valid
    # boolean index, and it must match the embeddings' length, not the number
    # of rated movies). Old ids are visited in ascending order so the new
    # compact ids line up with the rows the mask selects.
    to_keep_mask = np.zeros(len(base.embeddings), dtype=bool)
    key_to_id_new = {}
    id_to_key_new = {}
    for key, old_id in sorted(base.key_to_id.items(), key=lambda kv: kv[1]):
        if key in to_keep:
            new_id = len(key_to_id_new)
            to_keep_mask[old_id] = True
            key_to_id_new[key] = new_id
            # inverse of key_to_id_new (previously keyed by the old position)
            id_to_key_new[new_id] = key

    base.embeddings = base.embeddings[to_keep_mask]
    base.key_to_id = key_to_id_new
    base.id_to_key = id_to_key_new

    print(
        "action space is reduced to {} - {} = {}".format(
            len(counts), len(to_remove), len(counts) - len(to_remove)
        )
    )

    return args_mut, kwargs
def build_data_pipeline(
    chain: List[Callable], kwargs: DataFuncKwargs, args_mut: DataFuncArgsMut
):
    """Higher order function: run ``args_mut`` through every step of ``chain``.

    Each step is called as ``step(args_mut, kwargs)`` and must return an
    ``(args_mut, kwargs)`` pair. The returned ``args_mut`` feeds the next
    step; the returned kwargs are discarded, so every step receives the same
    ``kwargs`` object that was passed in here.

    :param chain: sequence of pipeline callables, applied in order.
    :param kwargs: plain arguments shared by all steps.
    :param args_mut: initial mutable state.
    :returns: ``(final args_mut, kwargs)``
    """
    state = args_mut
    for step in chain:
        state, _discarded_kwargs = step(state, kwargs)
    return state, kwargs