
Return self in InferInput setters for chain-initialization #378

Merged: 1 commit, Aug 7, 2023

Conversation

GuanLuo (Contributor) commented Aug 7, 2023

UX improvement.
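The "chain-initialization" this PR enables can be illustrated with a standalone sketch. This is a simplified stand-in for the real `tritonclient` class: `set_data` here is a hypothetical placeholder (the actual client uses `set_data_from_numpy`), but the pattern, setters returning `self`, is exactly what the change adds.

```python
# Minimal sketch of the pattern this PR introduces: each setter returns
# `self`, so an InferInput can be built and configured in one expression.
# (Simplified stand-in for the real tritonclient class.)
class InferInput:
    def __init__(self, name, shape, datatype):
        self.name = name
        self.shape = list(shape)
        self.datatype = datatype
        self._data = None

    def set_shape(self, shape):
        self.shape = list(shape)
        return self  # returning self is what enables chaining

    def set_data(self, data):  # hypothetical setter for illustration
        self._data = data
        return self


# Before the change, each setter call needed its own statement; now the
# input can be initialized in a single chained expression:
inp = InferInput("INPUT0", [1, 16], "FP32").set_shape([2, 16]).set_data([0.0] * 32)
```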

rmccorm4 (Contributor) left a comment

Nice addition.

I've been thinking about changing our constructor from something like this:

class InferInput:
    def __init__(self, name, shape, datatype):

to something along the lines of this:

class InferInput:
    def __init__(self, name, shape=None, datatype=None, tensor=None):
        self.name = name
        if shape:
            self.shape = shape
        if datatype:
            self.datatype = datatype

        # Improve existing numpy support to just take an array to initialize input:
        if isinstance(tensor, np.ndarray):
            if not shape:
                self.shape = tensor.shape
            if not datatype:
                self.datatype = utils.np_to_triton_dtype(tensor.dtype)
            self.set_data_from_numpy(tensor)

        # Future torch tensor support - might be better to have an explicit
        # InferInput.FromTorch() function here that does `import torch`
        # dynamically so we don't require torch
        if isinstance(tensor, torch.Tensor):
            ...

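The deferred `import torch` idea in the comment above could be sketched roughly as follows. Everything here is hypothetical illustration, not the actual client API: `from_torch` and the dtype handling are made up for the sketch; the real client would convert dtypes via `utils.np_to_triton_dtype`.

```python
import importlib.util


class InferInputSketch:
    """Illustration only: defers `import torch` until from_torch() is called,
    so the client library itself never requires torch at import time."""

    def __init__(self, name, shape, datatype):
        self.name = name
        self.shape = list(shape)
        self.datatype = datatype

    @classmethod
    def from_torch(cls, name, tensor):
        # Check availability first so a missing torch produces a clear error
        # instead of an ImportError deep inside the client.
        if importlib.util.find_spec("torch") is None:
            raise ImportError("from_torch() requires torch to be installed")
        import torch  # deferred: the module only loads on this code path

        # Hypothetical conversion; a real implementation would map the
        # torch dtype to a Triton datatype string.
        return cls(name, tuple(tensor.shape), str(tensor.dtype))
```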
or use some wrapper, since the above either/or requirement (shape + datatype, or tensor) isn't very clear:

class InferInputWrapper:
    def __init__(self, name, tensor):
        self.name = name
        # Improve existing numpy support to just take an array to initialize input:
        if isinstance(tensor, np.ndarray):
            self.shape = tensor.shape
            self.datatype = utils.np_to_triton_dtype(tensor.dtype)
            self.set_data_from_numpy(tensor)

        ...

Then usage for the non-shared-memory case can simply become something like this:

data = np.zeros((1, 16), dtype=np.float32)
inputs = [InferInput("INPUT0", tensor=data)]

GuanLuo merged commit ae92fd3 into main on Aug 7, 2023
3 checks passed
GuanLuo deleted the gluo-input branch on August 7, 2023 at 19:37