Skip to content

Commit

Permalink
allow inference for specific dimensions only
Browse files Browse the repository at this point in the history
  • Loading branch information
inkrement authored Jul 22, 2023
1 parent 262eabb commit 66e3847
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions nade/nade.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def predict_emojis(
'''
predict emotions based on emoji (stage 2)
'''
def predict(self, txts: List[str]) -> List[str]:
def predict(self, txts: List[str], dimensions: None | List[str]) -> List[str]:
dims_ = dimensions if dimensions not None else self.labels
ft_op = self.predict_emojis(txts, sort_by_key=True, k=151)
X, _ = zip(*ft_op)

Expand All @@ -138,7 +139,7 @@ def predict(self, txts: List[str]) -> List[str]:
),
ndigits=3
)
for lbl in self.labels
for lbl in dims_
}

return raw_reg
Expand Down

0 comments on commit 66e3847

Please sign in to comment.