Skip to content

Commit

Permalink
Merge pull request #23 from bazingagin/exp
Browse files Browse the repository at this point in the history
fix custom dataset
  • Loading branch information
bazingagin authored Jul 24, 2023
2 parents 5360c8e + 666287d commit 3c21994
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ to calculate accuracy. Otherwise, the accuracy will be calculated automatically

### Use Custom Dataset

You can use your own custom dataset by passing whatever string to `--dataset`, and also remember to pass the data directory that contains `train.txt` and `test.txt` to `--data_dir`.
You can use your own custom dataset by passing `custom` to `--dataset`; pass the data directory that contains `train.txt` and `test.txt` to `--data_dir`; pass the class number to the `--class_num`.

Both `train.txt` and `test.txt` are expected to have the format `{label}\t{text}` per line.

You can change the delimiter according to you dataset by changing `delimiter` in `load_custom_dataset()` in `data.py`.

3 changes: 3 additions & 0 deletions main_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def non_neurl_knn_exp_given_dis(dis_matrix, k, test_label, train_label):
parser.add_argument("--distance_fn", default=None)
parser.add_argument("--score", action="store_true", default=False)
parser.add_argument("--k", default=2, type=int)
parser.add_argument("--class_num", default=5, type=int)
args = parser.parse_args()
# create output dir
if not os.path.exists(args.output_dir):
Expand Down Expand Up @@ -145,6 +146,7 @@ def non_neurl_knn_exp_given_dis(dis_matrix, k, test_label, train_label):
"swahili": 6,
"filipino": 5,
"kirnews": 14,
"custom": args.class_num
}
# load dataset
data_dir = os.path.join(args.data_dir, args.dataset)
Expand All @@ -158,6 +160,7 @@ def non_neurl_knn_exp_given_dis(dis_matrix, k, test_label, train_label):
"swahili",
"filipino",
"kirnews",
"custom",
]:
dataset_pair = eval(args.dataset)(root=args.data_dir)
else:
Expand Down

0 comments on commit 3c21994

Please sign in to comment.