Skip to content

Commit

Permalink
Merge pull request #19 from seermedical/brendan-add-labels-batched
Browse files Browse the repository at this point in the history
Add batched version of add_labels
  • Loading branch information
bjdoyle authored Jul 26, 2019
2 parents dba2ccf + bd6a364 commit b1254e4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
40 changes: 40 additions & 0 deletions seerpy/seerpy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2017 Seer Medical Pty Ltd, Inc. or its affiliates. All Rights Reserved.

import math
import time

from gql import gql, Client as GQLClient
Expand Down Expand Up @@ -176,6 +177,45 @@ def del_label_group(self, group_id):
query_string = graphql.get_remove_label_group_mutation_string(group_id)
return self.execute_query(query_string)

def add_labels_batched(self, label_group_id, labels, batch_size=500):
"""Add labels to label group in batches
Parameters
----------
label_group_id : string
Seer label group ID
labels: list of:
note: string
label note
startTime : float
label start time in epoch time
duration : float
duration of event in milliseconds
timezone : float
local UTC timezone (eg. Melbourne = 11.0)
tagIds: [String!]
list of tag ids
confidence: float
Confidence given to label between 0 and 1
batch_size: int
number of labels to add in a batch. Optional, defaults to 500.
Returns
-------
None
Notes
-----
"""
number_of_batches = math.ceil(len(labels) / batch_size)
for i in range(number_of_batches):
start = i * batch_size
end = start + batch_size
self.add_labels(label_group_id, labels[start:end])

def add_labels(self, group_id, labels):
"""Add labels to label group
Expand Down
6 changes: 4 additions & 2 deletions seerpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_channel_data(all_data, segment_urls, # pylint:disable=too-many-argument
"""
if threads is None:
if os.name != 'nt':
if os.name == 'nt':
threads = 1
else:
threads = 5
Expand Down Expand Up @@ -191,7 +191,9 @@ def get_channel_data(all_data, segment_urls, # pylint:disable=too-many-argument
data_list = [download_function(data_q_item) for data_q_item in data_q]

if data_list:
data = pd.concat(data_list)
# sort=False to silence deprecation warning. This comes into play when we are processing
# segments across multiple channel groups which have different channels.
data = pd.concat(data_list, sort=False)
data = data.loc[(data['time'] >= from_time) & (data['time'] < to_time)]
data = data.sort_values(['id', 'channelGroups.id', 'time'], axis=0,
ascending=True, na_position='last')
Expand Down

0 comments on commit b1254e4

Please sign in to comment.