diff --git a/nade/nade.py b/nade/nade.py index 20cad27..5c7fa5f 100644 --- a/nade/nade.py +++ b/nade/nade.py @@ -125,7 +125,8 @@ def predict_emojis( ''' predict emotions based on emoji (stage 2) ''' - def predict(self, txts: List[str]) -> List[str]: + def predict(self, txts: List[str], dimensions: None | List[str]) -> List[str]: + dims_ = dimensions if dimensions not None else self.labels ft_op = self.predict_emojis(txts, sort_by_key=True, k=151) X, _ = zip(*ft_op) @@ -138,7 +139,7 @@ def predict(self, txts: List[str]) -> List[str]: ), ndigits=3 ) - for lbl in self.labels + for lbl in dims_ } return raw_reg