From 2a6be5506e74210ec87be11e440b52306c131305 Mon Sep 17 00:00:00 2001 From: Francesco Stablum Date: Thu, 25 Nov 2021 08:58:49 +0100 Subject: [PATCH] feat: activity datapoint vectorizer that makes use of field-specific DSPNs --- common/relspecs.py | 16 ++++++++++++++-- preprocess/dag.py | 14 ++------------ preprocess/vectorize_activity.py | 22 ++++++++++++++++++++++ 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/common/relspecs.py b/common/relspecs.py index 2c113c0..0d0289e 100644 --- a/common/relspecs.py +++ b/common/relspecs.py @@ -18,8 +18,7 @@ ) ) -from common import persistency -from common import utils +from common import persistency, utils, config from models import text_model @@ -87,6 +86,19 @@ def glue(self, tensor_list): # FIXME: maybe to some other module? ret = tensor_list return ret + def extract_from_activity_data(self,activity_data): + ret = {} + for k, v in activity_data.items(): + m = re.match(f'{self.name}_(.*)', k) + if m is not None: + rel_field = m.group(1) + if rel_field in self.fields_names: + # cap the amount of items to config.download_max_set_size + v = v[:config.download_max_set_size] + # logging.info(f"considering field {rel_field}") + ret[rel_field] = v + return ret + @property def scalers(self): return [curr.scaler for curr in self.fields] diff --git a/preprocess/dag.py b/preprocess/dag.py index 8808409..c7b0fd3 100644 --- a/preprocess/dag.py +++ b/preprocess/dag.py @@ -70,18 +70,8 @@ def parse(page, ti): for activity in data['response']['docs']: activity_id = activity['iati_identifier'] # logging.info(f"processing activity {activity_id}") - for k, v in activity.items(): - # logging.info(f"processing activity item {k}") - for rel in rels: - # logging.info(f"processing rel {rel.name}") - m = re.match(f'{rel.name}_(.*)', k) - if m is not None: - rel_field = m.group(1) - if rel_field in rel.fields_names: - # cap the amount of items to config.download_max_set_size - v = v[:config.download_max_set_size] - # logging.info(f"considering field {rel_field}") - rels_vals[rel.name][activity_id][rel_field] = v + for rel in rels: + rels_vals[rel.name][activity_id] = rel.extract_from_activity_data(activity) for rel, sets in rels_vals.items(): remove = [] diff --git a/preprocess/vectorize_activity.py b/preprocess/vectorize_activity.py index e69de29..f5ee55b 100644 --- a/preprocess/vectorize_activity.py +++ b/preprocess/vectorize_activity.py @@ -0,0 +1,22 @@ +from common import relspecs, utils +from models import models_storage +import numpy as np + +class Activity(utils.Collection): + def __init__(self, activity_data): + for rel in relspecs: + self[rel.name] = rel.extract_from_activity_data(activity_data) + + +class ActivityVectorizer(object): + def __init__(self): + self.model_storage = models_storage.DSPNAEModelsStorage() + + def vectorize_activity(self, activity): + vectorized_fields = [] + for rel in relspecs: + field_data = activity[rel.name] + vectorized_field = self.model_storage[rel.name].encoder(field_data) + vectorized_fields.append(vectorized_field) + ret = np.hstack(vectorized_fields) + return ret \ No newline at end of file