diff --git a/.github/images/SqueezeBits_orange_H.png b/.github/images/SqueezeBits_orange_H.png
deleted file mode 100644
index ab5163b..0000000
Binary files a/.github/images/SqueezeBits_orange_H.png and /dev/null differ
diff --git a/.github/images/owlite_logo.png b/.github/images/owlite_logo.png
deleted file mode 100644
index 7527291..0000000
Binary files a/.github/images/owlite_logo.png and /dev/null differ
diff --git a/.github/workflows/workflow-deploy.yml b/.github/workflows/workflow-deploy.yml
new file mode 100644
index 0000000..c30bf79
--- /dev/null
+++ b/.github/workflows/workflow-deploy.yml
@@ -0,0 +1,77 @@
+name: deploy target branch
+
+on:
+ workflow_dispatch:
+ inputs:
+ version:
+ description: "The version of the release"
+ required: true
+ release-title:
+ description: "The title of the release"
+ required: true
+ release-content:
+ description: "The summary of the release, a.k.a release note"
+ required: true
+ target-branch:
+ description: "The target branch to deploy"
+ default: "master"
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout to branch master
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ inputs.target-branch }}
+
+ - name: Setup python environment
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+
+ - name: Install build tool
+ run: pip install build
+
+ - name: Build
+ id: build
+ run: python -m build
+
+ - name: Get build results
+ id: build-result
+ run: |
+ echo "whl=$(ls dist/ | grep whl)" >> $GITHUB_OUTPUT
+ echo "tar=$(ls dist/ | grep tar)" >> $GITHUB_OUTPUT
+
+ - name: Create release
+ id: release
+ uses: softprops/action-gh-release@v1
+ with:
+ name: ${{ inputs.release-title }}
+ tag_name: ${{ format('v{0}', inputs.version) }}
+ body: ${{ inputs.release-content }}
+ prerelease: false
+ draft: false
+ files: |
+ ${{ format('dist/{0}', steps.build-result.outputs.whl) }}
+ ${{ format('dist/{0}', steps.build-result.outputs.tar) }}
+
+ - name: Trigger action in index repository
+ run: |
+ curl -L \
+ -X POST \
+ -H "Accept: application/vnd.github+json" \
+ -H "Authorization: Bearer ${{ secrets.PYPI_ACTION_DISPATCH }}" \
+ -H "X-GitHub-Api-Version: 2022-11-28" \
+ ${{ secrets.PYPI_ACTION_URL }} \
+ -d '{"ref":"master"}'
+
+ - name: Trigger action in owlite-doc repository
+ run: |
+ curl -L \
+ -X POST \
+ -H "Accept: application/vnd.github+json" \
+ -H "Authorization: Bearer ${{ secrets.TOKEN }}" \
+ -H "X-GitHub-Api-Version: 2022-11-28" \
+ ${{ secrets.DOC_ACTION_URL }} \
+ -d '{"ref":"main","inputs":{"source":"${{ inputs.target-branch }}","target":"main","msg":"${{ inputs.release-title }} [${{ format('v{0}', inputs.version) }}]"}}'
diff --git a/.gitignore b/.gitignore
index 0fcde50..cbfc4a1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,10 +8,9 @@
*.csv
*.onnx
*.bin
+*.engine
__pycache__/
-__owlite
-condaenv.*.requirements.txt
.DS_Store
.vscode
owlite.egg-info
-examples/
+build
diff --git a/CREDITS b/CREDITS
index 458cfa5..146d9e4 100644
--- a/CREDITS
+++ b/CREDITS
@@ -1176,4 +1176,240 @@ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+================================================================================
+
+pydantic
+https://github.com/pydantic/pydantic
+--------------------------------------------------------------------------------
+
+The MIT License (MIT)
+
+Copyright (c) 2017 to present Pydantic Services Inc. and individual contributors.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+================================================================================
+
+lazy_imports
+https://github.com/telekom/lazy-imports
+--------------------------------------------------------------------------------
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index db2f4c9..920bdaa 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-![](https://raw.githubusercontent.com/SqueezeBits/owlite/master/.github/images/owlite_logo.png)
+![OwLite logo](https://github.com/SqueezeBits/owlite/assets/64083281/abaa3ad9-0c86-4a9c-9b8d-f54ed6d9524b)
@@ -55,14 +55,14 @@ To install this package, please use Python 3.10.
Using pip (Recommended)
```bash
-pip install --extra-index-url https://pypi.ngc.nvidia.com git+https://github.com/SqueezeBits/owlite
+pip install owlite --extra-index-url https://pypi.squeezebits.com/
```
From source
```bash
git clone https://github.com/SqueezeBits/owlite.git
cd owlite
-pip install --extra-index-url https://pypi.ngc.nvidia.com .
+pip install .
```
## Getting Started
@@ -77,4 +77,4 @@ Please contact [owlite-admin@squeezebits.com](mailto:owlite-admin@squeezebits.co
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/owlite/__init__.py b/owlite/__init__.py
deleted file mode 100644
index dfe2edf..0000000
--- a/owlite/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from owlite_core.constants import OWLITE_VERSION as __version__ # noqa: N811
-from owlite_core.logger import log
-
-from . import api, backend, calib, nn
-from .backend import fx, onnx
-from .calibrators import (
- CalibrationContext,
- calibrate,
- prepare_for_calibration,
- update_fake_quantizers,
-)
-from .compress import compress
-from .enums import PTQCalibrationType, QATBackwardType
-from .options import (
- Channel,
- CompressionOptions,
- DynamicAxisOptions,
- DynamicInputOptions,
- FakeQuantizerOptions,
- GraphQuantizationOptions,
- NodeQuantizationOptions,
- ONNXExportOptions,
-)
-from .owlite import init
diff --git a/owlite/backend/onnx/onnx_op.py b/owlite/backend/onnx/onnx_op.py
deleted file mode 100644
index 8c1fff7..0000000
--- a/owlite/backend/onnx/onnx_op.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import re
-from collections.abc import Sequence
-from dataclasses import dataclass
-from typing import Optional, Union
-
-import numpy as np
-from onnx.defs import OpSchema
-
-from .onnx_op_schemas import get_core_operator_schemas
-
-NumericValue = Union[int, float, bool, np.ndarray]
-
-
-@dataclass
-class FormalONNXParameter:
- """Structure wrapping properties defined in the ONNX op schema required for ONNX transformations"""
-
- name: str
- is_optional: bool
- is_variadic: bool
- is_homogeneous: bool
- is_differentiable: bool
- type_constraints: Sequence[np.dtype]
-
-
-class ONNXOp:
- """Class representing each ONNX op allowing convenient access to its schema properties"""
-
- schemas: dict[str, OpSchema] = get_core_operator_schemas()
-
- def __init__(self, name: str) -> None:
- self.name = name
-
- def __repr__(self) -> str:
- return f"{self.name}"
-
- def __str__(self) -> str:
- return self.__repr__()
-
- @property
- def is_valid(self) -> bool:
- """Checks if the op exists in schemas"""
- return self.name in ONNXOp.schemas
-
- @property
- def schema(self) -> OpSchema:
- """The full schema structure of the op
-
- Returns:
- list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]: the full schema structure
- """
- return ONNXOp.schemas[self.name]
-
- @property
- def type_constraints(self) -> dict[str, list[str]]:
- """The dictionary that maps type parameter string to its allowed type strings
-
- Returns:
- dict[str, list[str]]: _description_
- """
- return {
- type_constraint.type_param_str: type_constraint.allowed_type_strs
- for type_constraint in self.schema.type_constraints
- }
-
- def i(self, idx: int = 0) -> "FormalONNXParameter":
- """The formal ONNX paramter of the input at given index.
-
- Args:
- idx (int, optional): the input index. Defaults to 0.
-
- Returns:
- FormalONNXParameter: the formal ONNX paramter of the input.
- """
- return self._get_formal_parameter(self.schema.inputs, idx)
-
- def o(self, idx: int = 0) -> "FormalONNXParameter":
- """The formal ONNX paramter of the output at given index.
-
- Args:
- idx (int, optional): the output index. Defaults to 0.
-
- Returns:
- FormalONNXParameter: the formal ONNX paramter of the output.
- """
- return self._get_formal_parameter(self.schema.outputs, idx)
-
- def _get_formal_parameter(self, params: list, idx: int = 0) -> FormalONNXParameter:
- is_last_parameter_variadic = params[-1].option == OpSchema.FormalParameterOption.Variadic
- if not (-len(params) <= idx < len(params) or is_last_parameter_variadic):
- raise IndexError(f"{self.name}: index out of range: {idx}")
- if is_last_parameter_variadic:
- param_idx = min(idx, len(params) - 1)
- offset = idx - param_idx
- param = params[param_idx]
- param_name = f"{param.name}_{offset}"
- else:
- param = params[idx]
- param_name = param.name
- return FormalONNXParameter(
- name=param_name,
- is_optional=OpSchema.FormalParameterOption.Optional == param.option,
- is_variadic=OpSchema.FormalParameterOption.Variadic == param.option,
- is_homogeneous=param.is_homogeneous,
- is_differentiable=OpSchema.DifferentiationCategory.Differentiable == param.differentiation_category,
- type_constraints=convert_to_np_dtypes(self.type_constraints.get(param.type_str, param.type_str)),
- )
-
-
-def convert_to_np_dtypes(wrapped_type_strs: list[str]) -> list[np.dtype]:
- """Converts type strings from an op schema to numpy data type
-
- Args:
- wrapped_type_strs (list[str]): the op schema type string
-
- Returns:
- list[np.dtype]: the converted numpy data type.
- """
- return [
- dtype for dtype in map(try_convert_to_np_dtype, map(unwrap_type_str, wrapped_type_strs)) if dtype is not None
- ]
-
-
-def unwrap_type_str(type_str: str) -> str:
- """Unwraps a type string from an op schema if possible
-
- Args:
- type_str (str): an op schema type string
-
- Returns:
- str: the string containing only type name if the unwrapping was successful, the input type_str itself otherwise.
- """
- match = re.search(r"tensor\((.*?)\)", type_str)
- if match:
- # Return the extracted string
- return match.group(1)
- # Return the input itself if no match is found
- return type_str
-
-
-def try_convert_to_np_dtype(type_str: str) -> Optional[np.dtype]:
- """Converts the type name in string into numpy data type if possible.
-
- Args:
- type_str (str): a string containing type name
-
- Returns:
- Optional[np.dtype]: a numpy.dtype instance if the conversion was successful, None otherwise.
- """
- if type_str == "float":
- type_str = "float32"
- try:
- return np.dtype(type_str)
- except TypeError:
- pass
- return None
diff --git a/owlite/backend/onnx/onnx_op_schemas.py b/owlite/backend/onnx/onnx_op_schemas.py
deleted file mode 100644
index 594a7c1..0000000
--- a/owlite/backend/onnx/onnx_op_schemas.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from collections import defaultdict
-
-from onnx import defs
-from onnx.defs import ONNX_ML_DOMAIN, OpSchema
-
-
-def get_full_operator_schemas() -> list[tuple[str, list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]]]:
- """parse full operator schemas
-
- Returns:
- list[tuple[str, list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]]]: nested structure containing all
- available op schemas
- """
- # domain -> support level -> name -> [schema]
- index: dict[str, dict[int, dict[str, list[OpSchema]]]] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
- for schema in defs.get_all_schemas_with_history():
- index[schema.domain][int(schema.support_level)][schema.name].append(schema)
-
- # Preprocess the Operator Schemas
- # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
- operator_schemas: list[tuple[str, list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]]] = []
- existing_ops: set[str] = set()
- for domain, _supportmap in sorted(index.items()):
- if domain == ONNX_ML_DOMAIN:
- continue
-
- processed_supportmap = []
- for _support, _namemap in sorted(_supportmap.items()):
- processed_namemap = []
- for n, unsorted_versions in sorted(_namemap.items()):
- versions = sorted(unsorted_versions, key=lambda s: s.since_version)
- schema = versions[-1]
- if schema.name in existing_ops:
- continue
- existing_ops.add(schema.name)
- processed_namemap.append((n, schema, versions))
- processed_supportmap.append((_support, processed_namemap))
- operator_schemas.append((domain, processed_supportmap))
- return operator_schemas
-
-
-def get_core_operator_schemas() -> dict[str, OpSchema]:
- """restructured operator schemas for only core operators
-
- Returns:
- dict[str, list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]]: the dictionary with key-value pairs
- where each op name is a key in string whose value is the nest structure containing various properties
- of the ONNX op.
- """
- triples = dict(get_full_operator_schemas())[""][0][1]
- return {x[0]: x[1] for x in triples}
diff --git a/owlite/nn/functions/__init__.py b/owlite/nn/functions/__init__.py
deleted file mode 100644
index b2a3b74..0000000
--- a/owlite/nn/functions/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .clq import clq_function
-from .clq_plus import clq_plus_function
-from .fake_quantize import FakeQuantizeSignature, fake_quantize
-from .ste import ste_function
diff --git a/owlite/nn/functions/fake_quantize.py b/owlite/nn/functions/fake_quantize.py
deleted file mode 100644
index e3e9405..0000000
--- a/owlite/nn/functions/fake_quantize.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from typing import Callable, Optional
-
-import torch
-from torch import Tensor
-
-
-def fake_quantize(
- inputs: Tensor,
- step_size: Tensor,
- zero_point: Tensor,
- quant_min: int,
- quant_max: int,
- axis: Optional[int] = None,
-) -> torch.Tensor:
- """Same as `torch.fake_quantize_per_channel_affine` if `per_channel` is `True`, otherwise
- `torch.fake_quantize_per_tensor_affine`
-
- Args:
- inputs (torch.Tensor): A tensor to quantize.
- step_size (torch.Tensor): A float tensor which is quantization scales.
- zero_point (torch.Tensor): A float tensor, quantization zero_point.
- quant_min (int): The lower bound of the quantized domain.
- quant_max (int): The upper bound of the quantized domain.
- axis (int, optional): Channel axis. Only used when `per_channel` is `True`. Defaults to 0.
-
- Returns:
- torch.Tensor: fake-quantized tensor
- """
- if axis is not None:
- return torch.fake_quantize_per_channel_affine(
- inputs,
- step_size,
- zero_point,
- axis,
- quant_min,
- quant_max,
- )
-
- return torch.fake_quantize_per_tensor_affine(
- inputs,
- step_size,
- # `torch.fake_quantize_per_tensor_affine` expects `zero_point` to be either int32 or int64
- # (See https://pytorch.org/docs/stable/generated/torch.fake_quantize_per_tensor_affine.html)
- # while `torch.fake_quantize_per_channel_affine` doesn't
- zero_point,
- quant_min=quant_min,
- quant_max=quant_max,
- )
-
-
-FakeQuantizeSignature = Callable[
- [
- Tensor, # inputs
- Tensor, # step_size
- Tensor, # zp
- float, # grad_scale
- int, # quant_min
- int, # quant_max
- Optional[int], # axis
- bool, # compensate_zp
- ],
- Tensor,
-]
diff --git a/owlite/owlite.py b/owlite/owlite.py
deleted file mode 100644
index e25e69d..0000000
--- a/owlite/owlite.py
+++ /dev/null
@@ -1,410 +0,0 @@
-import json
-import os
-import re
-from dataclasses import dataclass, field
-from typing import Any, Optional, Union
-
-import torch
-from packaging.version import Version
-from torch.fx.graph_module import GraphModule
-from torch.nn.parallel import DataParallel, DistributedDataParallel
-
-from owlite_core.constants import OWLITE_REPORT_URL
-from owlite_core.github_utils import get_latest_version_from_github
-from owlite_core.logger import log
-from owlite_core.owlite_settings import OWLITE_SETTINGS
-
-from . import __version__
-from .api import Baseline, Experiment, Project
-from .backend.fx.trace import symbolic_trace
-from .backend.onnx.signature import DynamicSignature, update_dynamic_signature
-from .compress import compress
-from .options import DynamicAxisOptions, DynamicInputOptions, ONNXExportOptions
-
-
-@dataclass
-class OwLite:
- """Class handling OwLite project, baseline, and experiment configurations.
-
- The OwLite class manages project, baseline, and experiment configurations within the OwLite system.
- It allows users to create or load projects, set baselines, create or duplicate experiments, convert models,
- and benchmark models against the specified configurations.
- """
-
- target: Union[Baseline, Experiment]
- module_args: Optional[tuple[Any, ...]] = field(default=None)
- module_kwargs: Optional[dict[str, Any]] = field(default=None)
-
- def convert(
- self,
- model: torch.nn.Module,
- *args: Any,
- **kwargs: Any,
- ) -> GraphModule:
- """Converts input model to compressed model.
-
- Args:
- model (torch.nn.Module): Model to compress.
-
- Returns:
- GraphModule: Compressed graph module.
-
- Raises:
- HTTPError: When request for compression configuration was not successful.
- """
-
- log.info("Model conversion initiated")
- try:
- model = symbolic_trace(model, *args, **kwargs)
- except Exception as e: # pylint: disable=broad-exception-caught
- log.error(
- "Failed to convert the model. This means that\n"
- "i) your model might have some codes that cannot be handled by `torch.compile`; or\n"
- "ii) the inputs provided for the model are incompatible with your model's 'forward' method.\n"
- "Check the full error message below and make changes accordingly. "
- f"Should the problem persist, please report the issue at {OWLITE_REPORT_URL} for further assistance"
- ) # UX
- raise e
-
- self.module_args = args
- self.module_kwargs = kwargs
-
- if isinstance(self.target, Experiment) and self.target.has_config:
- model = compress(model, self.target.config)
- log.info("Applied compression configuration") # UX
-
- return model
-
- def export(
- self,
- model: GraphModule,
- onnx_export_options: Optional[ONNXExportOptions] = None,
- dynamic_axis_options: Optional[Union[DynamicAxisOptions, dict[str, dict[str, int]]]] = None,
- ) -> None:
- """Exports and uploads given model.
-
- Args:
- model (GraphModule): Model to export.
- onnx_export_options (Optional[ONNXExportOptions], optional): Options for ONNX export. Defaults to None.
- dynamic_axis_options (Optional[DynamicAxisOptions], optional):
-
- By default the exported model will have the shapes of all input tensors set to
- exactly match those given when calling convert. To specify axes of tensors as
- dynamic (i.e. known only at run-time), set `dynamic_axis_options` to a dict with schema:
-
- * KEY (str): an input name.
-
- * VALUE (dict[int, dict[str, int]]): a single item dictionary whose key is dynamic dimension of input
- and value is a dynamic range setting dictionary containing min, opt, max, test dimension size
- settings.
-
- For example::
-
- import owlite
-
- owl = owlite.init( ... )
-
- class SumModule(torch.nn.Module):
- def forward(self, x):
- return torch.sum(x, dim=1)
-
- model = owl.convert( ... )
-
- ...
-
- # set first(0-th) dimension of input x
- owl.export(
- model,
- dynamic_axis_options={
- "x": {"axis": 0},
- },
- )
-
- # or equivalently,
- owl.export(
- model,
- dynamic_axis_options=owlite.DynamicAxisOptions(
- {"x": owlite.DynamicAxisOption(axis=0)}
- ),
- )
-
- Raises:
- TypeError: When the `model` is an instance of `torch.nn.DataParallel` or `torch.nn.DistributedDataParallel`.
- RuntimeError: When `dynamic_axes` is set for baseline export.
- ValueError: When invalid `dynamic_axes` is given.
- """
- if isinstance(model, (DataParallel, DistributedDataParallel)):
- model_type = f"torch.nn.parallel.{type(model).__name__}"
- log.error(
- f"{model_type} is not supported. Please use the attribute module "
- f"by unwrapping the model from {model_type}. Try owl.export(model.module)"
- ) # UX
- raise TypeError(f"{model_type} is not supported by export")
- if not isinstance(model, GraphModule):
- model_type = f"{type(model).__module__}.{type(model).__name__}"
- raise TypeError(f"Expected GraphModule, but got model of type {model_type}")
-
- if isinstance(dynamic_axis_options, dict):
- dynamic_axis_options = DynamicAxisOptions(dynamic_axis_options)
- keys_repr = ", ".join(f"'{key}'" for key in dynamic_axis_options.keys())
- log.info(f"`dynamic_axis_options` provided for the following inputs: {keys_repr}") # UX
-
- if isinstance(self.target, Baseline):
- if dynamic_axis_options is not None:
- log.warning(
- "The `dynamic_axis_options` provided for baseline will be ignored. "
- "To export baseline model with dynamic input, "
- "please create an experiment without compression configuration "
- "and export it with `dynamic_axis_options`"
- ) # UX
- proto = self.target.export(
- model, self.module_args, self.module_kwargs, onnx_export_options=onnx_export_options
- )
- self.target.upload(proto, model)
- else:
- proto = self.target.export(
- model,
- self.module_args,
- self.module_kwargs,
- dynamic_axis_options=dynamic_axis_options,
- onnx_export_options=onnx_export_options,
- )
- self.target.upload(
- proto,
- dynamic_axis_options=dynamic_axis_options,
- )
-
- def benchmark(
- self,
- dynamic_input_options: Optional[Union[DynamicInputOptions, dict[str, dict[str, int]]]] = None,
- ) -> None:
- """Benchmarks given model.
-
- Args:
- dynamic_input_options (Optional[Union[DynamicInputOptions, dict[str, dict[str, int]]]]):
-
- By default the exported model will have the shapes of all input tensors set to
- exactly match those given when calling convert. To specify axes of tensors as
- dynamic (i.e. known only at run-time), set `dynamic_axes` to a dict with schema:
-
- * KEY (str): an input name.
-
- * VALUE (dict[str, int]): a single item who is a dynamic range setting dictionary
- containing min, opt, max, test dimension size settings.
-
- For example::
-
- import owlite
-
- owl = owlite.init( ... )
-
- class SumModule(torch.nn.Module):
- def forward(self, x):
- return torch.sum(x, dim=1)
-
- model = owl.convert( ... )
-
- ...
-
- # set input x to be dynamic within the range of 1 ~ 8
- # optimize for 4 and benchmark for 5
- owl.benchmark(
- model,
- dynamic_input_options={
- "x": {
- "min": 1,
- "opt": 4,
- "max": 8,
- "test": 5,
- },
- },
- )
-
- # or equivalently,
- owl.benchmark(
- model,
- dynamic_input_options=owlite.DynamicInputOptions(
- {"x": owlite.DynamicSizeOptions(min=1, opt=4, max=8, test=5)}
- ),
- )
-
- Raises:
- TypeError: When the `model` is an instance of `torch.nn.DataParallel` or `torch.nn.DistributedDataParallel`.
- RuntimeError: When `dynamic_axes` is set for baseline benchmark.
- ValueError: When invalid `dynamic_axes` is given.
- """
-
- if isinstance(self.target, Experiment) and isinstance(self.target.input_signature, DynamicSignature):
- if dynamic_input_options is None:
- log.error(
- "The `dynamic_input_options` for the experiment has `dynamic_input_options`. "
- "Try `owl.benchmark(dynamic_input_options={...})`"
- ) # UX
- raise RuntimeError("Dynamic options failed")
- if isinstance(dynamic_input_options, dict):
- dynamic_input_options = DynamicInputOptions(dynamic_input_options)
- self.target.input_signature = update_dynamic_signature(self.target.input_signature, dynamic_input_options)
-
- self.target.orchestrate_trt_benchmark()
-
- def log(self, **kwargs: Any) -> None:
- """Logs the model's metrics.
-
- Notes:
- Log metrics with OwLite like below
-
- ...
-
- owl = owlite.init(...)
-
- ...
-
- owl.log(accuracy=0.72, loss=1.2)
-
- Raises:
- TypeError: When data is not JSON serializable or not allowed logging.
- """
- if not all(isinstance(value, (int, str, float)) for value in kwargs.values()):
- log.error("Invalied value given to `owl.log`. The value for logging must be `int`, `str`, `float`") # UX
- raise TypeError("Invalid value")
- try:
- self.target.log(json.dumps(kwargs))
- except TypeError as e:
- log.error("Invalid value given to `owl.log`. The metrics for logging must be JSON-serializable") # UX
- raise e
-
-
-# pylint: disable-next=too-many-branches
-def init(
- project: str,
- baseline: str,
- experiment: Optional[str] = None,
- duplicate_from: Optional[str] = None,
- description: Optional[str] = None,
-) -> OwLite:
- """Sets project, baseline and experiment information in DB to proper state and creates `OwLite` instance.
-
- Args:
- project (str): OwLite project name.
- baseline (str): OwLite baseline name.
- experiment (Optional[str], optional): OwLite experiment name. Defaults to None.
- duplicate_from (Optional[str], optional): OwLite source experiment name. Defaults to None.
- description (Optional[str], optional): OwLite project description. Defaults to None.
-
- Raises:
- RuntimeError: When deprecated or not authenticated.
- ValueError: When invalid experiment name or baseline name is given.
-
- Returns:
- OwLite: Created `OwLite` instance.
- """
- owlite_latest_version = Version(get_latest_version_from_github())
-
- current_version = Version(__version__)
- if current_version.major < owlite_latest_version.major:
- log.error(
- f"Your current version ({current_version}) is not supported. "
- "Please update the package to the latest version with the following command: "
- "pip install git+https://github.com/SqueezeBits/owlite --upgrade "
- "--extra-index-url https://pypi.ngc.nvidia.com"
- ) # UX
- raise RuntimeError("Version is not supported")
- if current_version < owlite_latest_version:
- log.warning(
- "A new version of OwLite is available. "
- "To ensure the best usage, please update the package to the latest version with the following command: "
- "pip install git+https://github.com/SqueezeBits/owlite --upgrade "
- "--extra-index-url https://pypi.ngc.nvidia.com"
- ) # UX
-
- if OWLITE_SETTINGS.tokens is None:
- log.error("Please log in using 'owlite login'. Account not found on this device") # UX
- raise RuntimeError("OwLite token not found")
-
- if OWLITE_SETTINGS.connected_device is None:
- log.warning(
- "Connected device not found. "
- "You will be automatically connected to the default NEST device as you are subscribed to the free plan. "
- "Please connect to a specific device using 'owlite device connect --name (name)' if needed"
- ) # UX
-
- else:
- log.info(f"Connected device: {OWLITE_SETTINGS.connected_device.name}") # UX
-
- validate_names(project=project, baseline=baseline, experiment=experiment, duplicate_from=duplicate_from)
- if description and len(description) > 140:
- log.error(
- "The project description should consist of at most 140 characters. "
- "Note that the description is not required for loading an existing project"
- ) # UX
- raise ValueError("Description length exceeds limit")
-
- if experiment == baseline:
- log.error(
- f"Experiment name '{baseline}' is reserved for the baseline. Please try with a different experiment name"
- ) # UX
- raise ValueError("Invalid experiment name")
-
- proj: Project = Project.load_or_create(project, description=description)
-
- target: Union[Baseline, Experiment]
- if experiment is None:
- if duplicate_from:
- log.warning(
- f"`duplicate_from='{duplicate_from}'` will be ignored as no value for `experiment` was provided"
- ) # UX
- target = Baseline.create(proj, baseline)
- else:
- existing_baseline = Baseline.load(proj, baseline)
- if existing_baseline is None:
- log.error(
- f"No such baseline: {baseline}. "
- f"Please check if the baseline name for the experiment '{experiment}' is correct"
- ) # UX
- raise ValueError("Invalid baseline name")
- if duplicate_from is None:
- target = Experiment.load_or_create(existing_baseline, experiment)
- else:
- existing_experiment = Experiment.load(existing_baseline, duplicate_from)
- if existing_experiment is None:
- log.error(
- f"The experiment '{duplicate_from}' to duplicate from is not found. "
- "Please check if the project name provided for `duplicate_from` argument is correct"
- ) # UX
- raise ValueError("Invalid experiment name")
- target = existing_experiment.clone(experiment)
-
- if os.path.exists(target.home):
- log.warning(
- f"Existing local directory found at {target.home}. Continuing this code will overwrite the data"
- ) # UX
- else:
- os.makedirs(target.home, exist_ok=True)
- log.info(f"Experiment data will be saved in {target.home}") # UX
-
- return OwLite(target)
-
-
-def validate_names(**kwargs: Any) -> None:
- """Validate a list of names.
-
- Args:
- **kwargs: A dictionary where keys are identifiers and values are names to validate.
-
- Raises:
- ValueError: If any name is invalid.
- """
- invalid_keys = []
- regex = r"^[a-zA-Z0-9()\-_@:*&]+$"
- for key, name in kwargs.items():
- if name is None:
- continue
- if not re.fullmatch(regex, name):
- invalid_keys.append(key)
- if len(invalid_keys) > 0:
- invalid_items = ", ".join(f"{key}={kwargs[key]}" for key in invalid_keys)
- log.error(
- f"The following names do not meet the requirement: {invalid_items}. "
- "A valid name must consist of alphanumeric characters or special characters chosen from ()-_@:*&"
- ) # UX
- raise ValueError("Invalid name")
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..6834ce8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,52 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name="owlite"
+dynamic = ["version"]
+description = "OwLite - No-Code AI compression Toolkit"
+dependencies = [
+ "torch>=2.0,<2.3",
+ "onnxruntime",
+ "onnxsim",
+ "onnx_graphsurgeon@https://developer.download.nvidia.com/compute/redist/onnx-graphsurgeon/onnx_graphsurgeon-0.3.27-py2.py3-none-any.whl",
+ "colored",
+ "yacs",
+ "tabulate",
+ "requests",
+ "tqdm",
+ "pydantic",
+ "lazy_imports",
+]
+authors = [
+ {name = "SqueezeBits.inc", email = "owlite@squeezebits.com"}
+]
+maintainers = [
+ {name = "SqueezeBits.inc", email = "owlite@squeezebits.com"}
+]
+requires-python = "~=3.10"
+keywords=["torch", "onnx", "graph", "quantization"]
+classifiers=[
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: GNU Affero General Public License v3",
+ "Programming Language :: Python :: 3.10 :: Only",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+]
+
+[project.urls]
+Repository = "https://github.com/SqueezeBits/owlite"
+Documentation = "https://squeezebits.gitbook.io/owlite/quick/readme"
+
+
+[project.scripts]
+owlite = "owlite.owlite_core.cli.owlite_cli:main"
+
+[tool.setuptools.dynamic]
+version = {attr = "owlite.owlite_core.constants.OWLITE_VERSION"}
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 4016ed5..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,10 +0,0 @@
--e .
-black
-pre-commit
-pylint
-ruff
-mypy
-pytest-xdist
-torchvision
-transformers~=4.35.2
-types-requests
\ No newline at end of file
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 2de4655..0000000
--- a/setup.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# pylint: disable=all
-
-from setuptools import find_packages, setup
-
-from owlite_core.constants import OWLITE_GIT_REPO_URL, OWLITE_VERSION
-
-
-def requirements() -> list[str]:
- return [
- "torch>=2.0,<2.2",
- "onnxruntime",
- "onnxsim",
- "onnx_graphsurgeon",
- "colored",
- "yacs",
- "tabulate",
- "requests",
- "tqdm",
- "pydantic",
- ]
-
-
-setup(
- name="owlite",
- version=OWLITE_VERSION,
- description="OwLite - No-Code AI compression Toolkit",
- url=OWLITE_GIT_REPO_URL,
- author="SqueezeBits Inc.",
- author_email="owlite@squeezebits.com",
- install_requires=requirements(),
- packages=find_packages(exclude=("test", "scripts")),
- python_requires="~=3.10.0",
- classifiers=[
- "Intended Audience :: Developers",
- "Intended Audience :: Education",
- "Intended Audience :: Science/Research",
- "License :: OSI Approved :: Apache Software License",
- "Programming Language :: Python :: 3.10 :: Only",
- "Topic :: Scientific/Engineering",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Software Development",
- "Topic :: Software Development :: Libraries",
- "Topic :: Software Development :: Libraries :: Python Modules",
- ],
- keywords=["torch", "onnx", "graph", "quantization"],
- entry_points={
- "console_scripts": ["owlite=owlite_core.cli.owlite_cli:main"],
- },
-)
diff --git a/src/owlite/__init__.py b/src/owlite/__init__.py
new file mode 100644
index 0000000..c0265e0
--- /dev/null
+++ b/src/owlite/__init__.py
@@ -0,0 +1,60 @@
+import sys
+from typing import TYPE_CHECKING
+
+from lazy_imports import LazyImporter
+
+from .owlite_core.constants import OWLITE_VERSION as __version__ # noqa: N811
+
+_import_structure = {
+ "backend": [
+ "fx",
+ "onnx",
+ ],
+ "calibrators": [
+ "CalibrationContext",
+ "calibrate",
+ ],
+ "compression": ["compress"],
+ "enums": [
+ "PTQCalibrationType",
+ "QATBackwardType",
+ ],
+ "options": [
+ "Channel",
+ "CompressionOptions",
+ "DynamicAxisOptions",
+ "DynamicInputOptions",
+ "FakeQuantizerOptions",
+ "GraphQuantizationOptions",
+ "NodeQuantizationOptions",
+ "ONNXExportOptions",
+ ],
+ "owlite": ["init"],
+}
+
+if TYPE_CHECKING:
+ from .backend import fx, onnx
+ from .calibrators import (
+ CalibrationContext,
+ calibrate,
+ )
+ from .compression import compress
+ from .enums import PTQCalibrationType, QATBackwardType
+ from .options import (
+ Channel,
+ CompressionOptions,
+ DynamicAxisOptions,
+ DynamicInputOptions,
+ FakeQuantizerOptions,
+ GraphQuantizationOptions,
+ NodeQuantizationOptions,
+ ONNXExportOptions,
+ )
+ from .owlite import init
+else:
+ sys.modules[__name__] = LazyImporter(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ extra_objects={"__version__": __version__},
+ )
diff --git a/owlite/api/__init__.py b/src/owlite/api/__init__.py
similarity index 100%
rename from owlite/api/__init__.py
rename to src/owlite/api/__init__.py
diff --git a/owlite/api/baseline.py b/src/owlite/api/baseline.py
similarity index 96%
rename from owlite/api/baseline.py
rename to src/owlite/api/baseline.py
index 3c0b4a6..4b30c67 100644
--- a/owlite/api/baseline.py
+++ b/src/owlite/api/baseline.py
@@ -10,11 +10,10 @@
from torch.fx.graph_module import GraphModule
from typing_extensions import Self
-from owlite_core.api_base import DOVE_API_BASE, MAIN_API_BASE
-from owlite_core.logger import log
-
from ..backend.fx import serialize
-from ..backend.onnx.signature import DynamicSignature, Signature
+from ..backend.signature import DynamicSignature, Signature
+from ..owlite_core.api_base import DOVE_API_BASE, MAIN_API_BASE
+from ..owlite_core.logger import log
from .benchmarkable import Benchmarkable
from .project import Project
diff --git a/owlite/api/benchmarkable.py b/src/owlite/api/benchmarkable.py
similarity index 97%
rename from owlite/api/benchmarkable.py
rename to src/owlite/api/benchmarkable.py
index 8eef689..40fa8d3 100644
--- a/owlite/api/benchmarkable.py
+++ b/src/owlite/api/benchmarkable.py
@@ -12,17 +12,16 @@
import requests
from torch.fx.graph_module import GraphModule
-from owlite_core.api_base import MAIN_API_BASE, APIBase
-from owlite_core.api_enums import BenchmarkStatus, PricePlan
-from owlite_core.cli.api.login import whoami
-from owlite_core.cli.device import connect_free_device
-from owlite_core.constants import OWLITE_API_DEFAULT_TIMEOUT, OWLITE_REPORT_URL
-from owlite_core.logger import log
-from owlite_core.owlite_settings import OWLITE_SETTINGS
-
from ..backend.onnx.export import export
-from ..backend.onnx.signature import DynamicSignature, Signature
+from ..backend.signature import DynamicSignature, Signature
from ..options import DynamicAxisOptions, ONNXExportOptions
+from ..owlite_core.api_base import MAIN_API_BASE, APIBase
+from ..owlite_core.api_enums import BenchmarkStatus, PricePlan
+from ..owlite_core.cli.api.login import whoami
+from ..owlite_core.cli.device import connect_free_device
+from ..owlite_core.constants import OWLITE_API_DEFAULT_TIMEOUT, OWLITE_REPORT_URL
+from ..owlite_core.logger import log
+from ..owlite_core.owlite_settings import OWLITE_SETTINGS
from .utils import download_file_from_url, upload_file_to_url
DEVICE_API_BASE: APIBase = APIBase(
diff --git a/owlite/api/experiment.py b/src/owlite/api/experiment.py
similarity index 92%
rename from owlite/api/experiment.py
rename to src/owlite/api/experiment.py
index f164aab..b6c7f18 100644
--- a/owlite/api/experiment.py
+++ b/src/owlite/api/experiment.py
@@ -10,13 +10,11 @@
from torch.fx.graph_module import GraphModule
from typing_extensions import Self
-from owlite_core.api_base import DOVE_API_BASE, MAIN_API_BASE
-from owlite_core.constants import FX_CONFIGURATION_FORMAT_VERSION
-from owlite_core.logger import log
-
-from .. import __version__
-from ..backend.onnx.signature import DynamicSignature, Signature
+from ..backend.signature import DynamicSignature, Signature
from ..options import CompressionOptions, DynamicAxisOptions
+from ..owlite_core.api_base import DOVE_API_BASE, MAIN_API_BASE
+from ..owlite_core.constants import FX_CONFIGURATION_FORMAT_VERSION, OWLITE_VERSION
+from ..owlite_core.logger import log
from .baseline import Baseline
from .benchmarkable import Benchmarkable
from .project import Project
@@ -55,17 +53,16 @@ def config(self) -> CompressionOptions:
try:
resp = DOVE_API_BASE.post("/compile", json=self.payload(format_version=FX_CONFIGURATION_FORMAT_VERSION))
except requests.exceptions.HTTPError as e:
- if e.response.status_code == 403: # type: ignore
+ if e.response and e.response.status_code == 403:
log.error(
"Config settings exceed the limit. The free plan supports max. 1 OpType, 10 Layer settings. "
"Please check your config again"
) # UX
- elif e.response.status_code == 426: # type: ignore
+ elif e.response and e.response.status_code == 426:
log.error(
- f"Your current version ({Version(__version__)}) is not supported. "
+ f"Your current version ({Version(OWLITE_VERSION)}) is not supported. "
"Please update the package to the latest version with the following command: "
- "pip install git+https://github.com/SqueezeBits/owlite --upgrade "
- "--extra-index-url https://pypi.ngc.nvidia.com"
+ "pip install owlite --extra-index-url https://pypi.squeezebits.com/ --upgrade "
) # UX
raise e
assert isinstance(resp, dict)
diff --git a/owlite/api/project.py b/src/owlite/api/project.py
similarity index 94%
rename from owlite/api/project.py
rename to src/owlite/api/project.py
index 8f22664..c7c613d 100644
--- a/owlite/api/project.py
+++ b/src/owlite/api/project.py
@@ -6,9 +6,9 @@
from requests.exceptions import HTTPError
from typing_extensions import Self
-from owlite_core.api_base import MAIN_API_BASE
-from owlite_core.constants import OWLITE_FRONT_BASE_URL, OWLITE_HOME
-from owlite_core.logger import log
+from ..owlite_core.api_base import MAIN_API_BASE
+from ..owlite_core.constants import OWLITE_FRONT_BASE_URL, OWLITE_HOME
+from ..owlite_core.logger import log
if TYPE_CHECKING:
from .baseline import Baseline
diff --git a/owlite/api/utils.py b/src/owlite/api/utils.py
similarity index 98%
rename from owlite/api/utils.py
rename to src/owlite/api/utils.py
index b814027..020b03d 100644
--- a/owlite/api/utils.py
+++ b/src/owlite/api/utils.py
@@ -4,7 +4,7 @@
from tqdm import tqdm
from tqdm.utils import CallbackIOWrapper
-from owlite_core.logger import log
+from ..owlite_core.logger import log
def upload_file_to_url(file_path: str, dst_url: str) -> None:
diff --git a/owlite/backend/__init__.py b/src/owlite/backend/__init__.py
similarity index 100%
rename from owlite/backend/__init__.py
rename to src/owlite/backend/__init__.py
diff --git a/owlite/backend/config.py b/src/owlite/backend/config.py
similarity index 79%
rename from owlite/backend/config.py
rename to src/owlite/backend/config.py
index 437d6ab..c3f3047 100644
--- a/owlite/backend/config.py
+++ b/src/owlite/backend/config.py
@@ -1,6 +1,9 @@
import os
-# This flag ensures module output(return value of forward) consistency between before and after trace.
+# Flag to disable automatic object monkey patching
+DISABLE_AUTO_PATCH = os.environ.get("OWLITE_DISABLE_AUTO_PATCH", "0") == "1"
+
+# Flag to enforce module output(return value of forward) consistency between before and after trace.
FORCE_OUTPUT_COMPATIBILITY = os.environ.get("OWLITE_FORCE_OUTPUT_COMPATIBILITY", "1") == "1"
# Maximum iteration limit for ONNX transformations.
diff --git a/owlite/backend/fx/__init__.py b/src/owlite/backend/fx/__init__.py
similarity index 100%
rename from owlite/backend/fx/__init__.py
rename to src/owlite/backend/fx/__init__.py
diff --git a/owlite/backend/fx/edge.py b/src/owlite/backend/fx/edge.py
similarity index 100%
rename from owlite/backend/fx/edge.py
rename to src/owlite/backend/fx/edge.py
diff --git a/owlite/backend/fx/node.py b/src/owlite/backend/fx/node.py
similarity index 100%
rename from owlite/backend/fx/node.py
rename to src/owlite/backend/fx/node.py
diff --git a/owlite/backend/fx/node_configurator.py b/src/owlite/backend/fx/node_configurator.py
similarity index 99%
rename from owlite/backend/fx/node_configurator.py
rename to src/owlite/backend/fx/node_configurator.py
index 7f1bd9f..0e9cdf6 100644
--- a/owlite/backend/fx/node_configurator.py
+++ b/src/owlite/backend/fx/node_configurator.py
@@ -4,11 +4,10 @@
import torch
from torch.fx.node import Node
-from owlite_core.logger import log
-
from ...nn import FakeQuantizer, QLinear
from ...nn.modules import UnaryNeuralQModuleMixin, promote_to_qmodule
from ...options.compression_option import NodeCompressionOptions
+from ...owlite_core.logger import log
from ..utils import nodestr
from .edge import AllInputNodes, Args, Kwargs
from .node import get_target_module, get_torch_target
diff --git a/owlite/backend/fx/serialize.py b/src/owlite/backend/fx/serialize.py
similarity index 96%
rename from owlite/backend/fx/serialize.py
rename to src/owlite/backend/fx/serialize.py
index c723665..5dfe7fc 100644
--- a/owlite/backend/fx/serialize.py
+++ b/src/owlite/backend/fx/serialize.py
@@ -3,8 +3,7 @@
from tabulate import tabulate
from torch.fx.graph_module import GraphModule
-from owlite_core.logger import log
-
+from ...owlite_core.logger import log
from ..utils import targetstr
from .node import get_target_module
diff --git a/owlite/backend/fx/target.py b/src/owlite/backend/fx/target.py
similarity index 98%
rename from owlite/backend/fx/target.py
rename to src/owlite/backend/fx/target.py
index deec553..6eb7739 100644
--- a/owlite/backend/fx/target.py
+++ b/src/owlite/backend/fx/target.py
@@ -129,4 +129,6 @@ def all_targets(op_name: str) -> list[FXTarget]:
*torch_targets("eye"),
*torch_targets("from_file"),
*torch_targets("from_numpy"),
+ *torch_targets("hamming_window"),
+ *torch_targets("hann_window"),
)
diff --git a/owlite/backend/fx/trace.py b/src/owlite/backend/fx/trace.py
similarity index 97%
rename from owlite/backend/fx/trace.py
rename to src/owlite/backend/fx/trace.py
index ba1beec..42a2194 100644
--- a/owlite/backend/fx/trace.py
+++ b/src/owlite/backend/fx/trace.py
@@ -14,11 +14,10 @@
from torch.fx.graph_module import GraphModule, _WrappedCall
from torch.nn.parallel import DataParallel, DistributedDataParallel
-from owlite_core.logger import log
-
from ...enums import OwLiteStatus, ParamStatus
+from ...owlite_core.logger import log
from ..config import FORCE_OUTPUT_COMPATIBILITY
-from ..onnx.signature import map_signature
+from ..signature import map_signature
from ..utils import (
get_most_common_device,
get_most_common_floating_point_type,
@@ -291,15 +290,16 @@ def symbolic_trace(model: torch.nn.Module, *args: Any, **kwargs: Any) -> GraphMo
"training or inference code to use the converted model."
)
- graph_module = apply_graph_module_transforms(graph_module)
+ original_params = {**inspect.signature(model.forward).parameters}
+ graph_module.meta["original_params"] = original_params
+
+ graph_module = apply_graph_module_transforms(graph_module, args, kwargs)
graph_module.train(training_status)
graph_module.meta["owlite_status"] = OwLiteStatus.NOT_COMPRESSED
- original_params = {**inspect.signature(model.forward).parameters}
graph_module_params = {**inspect.signature(graph_module.forward).parameters}
log.debug(f"original_params: {[*original_params.keys()]}")
log.debug(f"graph_module_params: {[*graph_module_params.keys()]}")
- graph_module.meta["original_params"] = original_params
if any(
param.kind in (inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD)
@@ -321,12 +321,7 @@ def symbolic_trace(model: torch.nn.Module, *args: Any, **kwargs: Any) -> GraphMo
name: (
ParamStatus.ALIVE
if name in graph_module_params
- else (
- # ParamStatus.KEPT if received_values[name] is not inspect._empty
- ParamStatus.KEPT
- if name in received_values
- else ParamStatus.PURGED
- )
+ else (ParamStatus.KEPT if name in received_values else ParamStatus.PURGED)
)
for name in original_params
}
diff --git a/owlite/backend/fx/transforms.py b/src/owlite/backend/fx/transforms.py
similarity index 77%
rename from owlite/backend/fx/transforms.py
rename to src/owlite/backend/fx/transforms.py
index 3af0a8f..84d2633 100644
--- a/owlite/backend/fx/transforms.py
+++ b/src/owlite/backend/fx/transforms.py
@@ -1,24 +1,32 @@
-from typing import Callable, Optional, Union
+import inspect
+from typing import Any, Callable, Optional, Union
import torch
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
+from torch.nn.modules.conv import _ConvNd
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval
-from owlite_core.logger import log
-
from ...nn import FakePerTensorQuantizer, FakeQuantizer
from ...nn import modules as qnn
from ...nn.modules.qmodule_mixins import UnaryNeuralQModuleMixin
from ...options import Channel
-from ..utils import get_most_common_device
+from ...owlite_core.logger import log
+from ..signature import Signature
+from ..utils import get_most_common_device, nodestr, normalize_parameter_name
from .node import find_placeholders, get_target_module
-GraphModuleTransform = Callable[[GraphModule], GraphModule]
+GraphModuleTransform = Union[
+ Callable[[GraphModule], GraphModule], Callable[[GraphModule, tuple[Any, ...], dict[str, Any]], GraphModule]
+]
GRAPH_MODULE_TRANSFORMS: dict[str, GraphModuleTransform] = {}
-def apply_graph_module_transforms(graph_module: GraphModule) -> GraphModule:
+def apply_graph_module_transforms(
+ graph_module: GraphModule,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+) -> GraphModule:
"""Applies all registered graph module transforms
Args:
@@ -29,7 +37,11 @@ def apply_graph_module_transforms(graph_module: GraphModule) -> GraphModule:
"""
for name, transform in GRAPH_MODULE_TRANSFORMS.items():
log.debug(f"Applying graph module transform: {name}")
- graph_module = transform(graph_module)
+ graph_module = (
+ transform(graph_module) # type: ignore[call-arg]
+ if len(inspect.signature(transform).parameters) == 1
+ else transform(graph_module, args, kwargs) # type: ignore[call-arg]
+ )
graph_module.recompile()
return graph_module
@@ -65,16 +77,55 @@ def fix_input_parameter_names(graph_module: GraphModule) -> GraphModule:
for node in find_placeholders(graph_module.graph):
if not isinstance(node.target, str):
continue
- if node.target.startswith("L_kwargs_") and node.target.endswith("_"):
- node.target = node.target[9:-1]
- elif node.target.startswith("L_") and node.target.endswith("_"):
- node.target = node.target[2:-1]
+ target = normalize_parameter_name(node.target)
+ if node.target != target:
+ log.debug(f"Renaming placeholder {node.target} -> {target}")
+ node.target = target
+ return graph_module
+
+
+@register_graph_module_transform
+def fix_forward_argument_ordering(
+ graph_module: GraphModule,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+) -> GraphModule:
+ """Reorder graph module input arguments to meet the ordering in original module.
+
+ Args:
+ graph_module (GraphModule): the input graph module
+
+ Returns:
+ GraphModule: graph module with inputs reordered
+ """
+ graph = graph_module.graph
+ names = [name for name, _ in Signature.from_module(graph_module, args, kwargs)]
+ log.debug(f"Names from signature: {names}")
+
+ placeholders = find_placeholders(graph)
+ log.debug(f"Original placeholders: {[nodestr(p) for p in placeholders]}")
+
+ def get_index(node: Node) -> int:
+ if isinstance((target := node.target), str) and target in names:
+ return names.index(target)
+ return len(names)
+
+ placeholders = [*sorted(placeholders, key=get_index, reverse=True)]
+ log.debug(f"Reverse-sorted placeholders: {[nodestr(p) for p in placeholders]}")
+
+ for placeholder in placeholders:
+ with graph.inserting_before():
+ reordered_placeholder = graph.placeholder(f"{placeholder.name}_reordered")
+ reordered_placeholder.target = placeholder.target
+ placeholder.replace_all_uses_with(reordered_placeholder)
+ graph.erase_node(placeholder)
+
return graph_module
@register_graph_module_transform
def fix_hard_coded_device(graph_module: GraphModule) -> GraphModule:
- """Fix hard coded devices to enanble data parallel.
+ """Fix hard coded devices to enable data parallel.
Args:
graph_module (GraphModule): the input graph module
@@ -117,7 +168,8 @@ def canonicalize_silu(graph_module: GraphModule) -> GraphModule:
"""
graph = graph_module.graph
for node in graph.nodes:
- assert isinstance(node, Node)
+ if not isinstance(node, Node):
+ continue
module = get_target_module(node)
if isinstance(module, torch.nn.SiLU) or (
node.op == "call_function" and node.target is torch.nn.functional.silu
@@ -131,6 +183,25 @@ def canonicalize_silu(graph_module: GraphModule) -> GraphModule:
return graph_module
+@register_graph_module_transform
+def canonicalize_hstack(graph_module: GraphModule) -> GraphModule:
+ """Replaces call_function(torch.hstack) by call_function(torch.cat)"""
+ graph = graph_module.graph
+ for node in graph.nodes:
+ if not (
+ isinstance(node, Node)
+ and node.op == "call_function"
+ and node.target is torch.hstack
+ and isinstance((tensors := node.args[0] if node.args else node.kwargs.get("tensors", None)), (list, tuple))
+ ):
+ continue
+ with graph.inserting_before(node):
+ cat_node = graph.call_function(torch.cat, args=(tensors,), kwargs={"dim": 1})
+ node.replace_all_uses_with(cat_node)
+ graph.lint()
+ return graph_module
+
+
def matches_module_pattern(pattern: tuple[type[torch.nn.Module], type[torch.nn.Module]], node: Node) -> bool:
"""Check current node matches with one of patterns.
@@ -188,10 +259,18 @@ def _rescale_step_size_with_batchnorm(
return quantizer
+BatchNorm = Union[torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]
+
+FusionFunction = Union[
+ Callable[[_ConvNd, BatchNorm], torch.nn.Module],
+ Callable[[torch.nn.Linear, BatchNorm], torch.nn.Module],
+]
+
+
def _fuse_by_patterns(
model: GraphModule,
patterns: list[tuple[type[torch.nn.Module], type[torch.nn.Module]]],
- fusion_func: Callable[[torch.nn.Module, torch.nn.Module], torch.nn.Module],
+ fusion_func: FusionFunction,
) -> None:
"""Fuses module/BN layers for inference purposes."""
new_graph = model.graph
@@ -206,14 +285,14 @@ def _fuse_by_patterns(
if len(parent.users) > 1: # Output of module is used by other nodes
continue
if not (
- (module := get_target_module(parent)) is not None
+ isinstance((module := get_target_module(parent)), (_ConvNd, torch.nn.Linear))
and isinstance(
(batchnorm := get_target_module(node)),
(torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d),
)
):
continue
- fused_module = fusion_func(module, batchnorm)
+ fused_module = fusion_func(module, batchnorm) # type: ignore[arg-type]
if isinstance(fused_module, UnaryNeuralQModuleMixin) and fused_module.weight_quantizer:
fused_module.weight_quantizer = _rescale_step_size_with_batchnorm(
fused_module.weight_quantizer, batchnorm
diff --git a/owlite/backend/fx/types.py b/src/owlite/backend/fx/types.py
similarity index 100%
rename from owlite/backend/fx/types.py
rename to src/owlite/backend/fx/types.py
diff --git a/owlite/backend/onnx/__init__.py b/src/owlite/backend/onnx/__init__.py
similarity index 100%
rename from owlite/backend/onnx/__init__.py
rename to src/owlite/backend/onnx/__init__.py
diff --git a/owlite/backend/onnx/dynamize.py b/src/owlite/backend/onnx/dynamize.py
similarity index 98%
rename from owlite/backend/onnx/dynamize.py
rename to src/owlite/backend/onnx/dynamize.py
index c9a7aa8..ee457ae 100644
--- a/owlite/backend/onnx/dynamize.py
+++ b/src/owlite/backend/onnx/dynamize.py
@@ -1,3 +1,5 @@
+# pylint: disable=C0116, R0914, R0912, R0915, R1702
+
from collections.abc import Sequence
from dataclasses import dataclass
from functools import reduce
@@ -9,11 +11,10 @@
from onnx.shape_inference import infer_shapes
from ...options import DynamicAxisOptions
+from ..signature import Signature
from ..utils import nodestr
-from .signature import Signature
-# pylint: disable-next=too-many-locals,too-many-branches, too-many-statements
def dynamize(onnx_proto: ModelProto, options: DynamicAxisOptions) -> ModelProto:
"""Dynamizes given ONNX proto with given dynamic dimension setting.
diff --git a/owlite/backend/onnx/export.py b/src/owlite/backend/onnx/export.py
similarity index 90%
rename from owlite/backend/onnx/export.py
rename to src/owlite/backend/onnx/export.py
index c2f2010..3c158a5 100644
--- a/owlite/backend/onnx/export.py
+++ b/src/owlite/backend/onnx/export.py
@@ -18,13 +18,13 @@
from onnxsim.onnxsim_cpp2py_export import simplify_path
from torch.fx.graph_module import GraphModule
-from owlite_core.logger import log
-
from ...enums import OwLiteStatus
from ...nn import FakeQuantizer
from ...nn.functions import clq_function
from ...options import DynamicAxisOptions
+from ...owlite_core.logger import log
from ..fx.transforms import clip_narrow_range_weights, fold_zp_to_bias, fuse_bn
+from ..signature import Signature
from ..utils import (
get_most_common_device,
get_most_common_floating_point_type,
@@ -33,8 +33,7 @@
)
from .dynamize import dynamize
from .export_with_external_data import export_with_external_data
-from .model_checking import compare
-from .signature import Signature
+from .model_checking import compare # type: ignore
from .transforms import apply_onnx_transforms
# Large models (e.g. SwinTransformer) requires
@@ -297,7 +296,7 @@ def export(
if opset_version is None:
opset_version = 17
- if input_names is None and isinstance(module, GraphModule):
+ if input_names is None:
input_names = get_default_input_names(module, args)
onnx_proto = export_function(
module,
@@ -438,34 +437,34 @@ def _optimize(
skip_fuse_bn: bool = False,
skipped_optimizers: Optional[list[str]] = None,
) -> ModelProto:
- if apply_transforms:
- onnx_proto = apply_onnx_transforms(onnx_proto)
- if not simplify:
- return onnx_proto
-
+ modified_proto = onnx_proto
try:
- log.debug("Running onnxsim.simplify")
- simple_proto, _ = onnxsim.simplify(
- onnx_proto,
- check_n=0,
- skip_fuse_bn=skip_fuse_bn,
- skipped_optimizers=skipped_optimizers,
- )
+ if simplify:
+ log.debug("Running onnxsim.simplify")
+ modified_proto, _ = onnxsim.simplify(
+ modified_proto,
+ check_n=0,
+ skip_fuse_bn=skip_fuse_bn,
+ skipped_optimizers=skipped_optimizers,
+ )
- ok = compare(simple_proto, onnx_proto, n_times=check_n)
+ # onnxsim produces anonymous nodes problematic for creating ONNXToFXMap
+ modified_proto = name_anonymous_nodes(modified_proto)
+ if len([*modified_proto.graph.node]) == 0:
+ log.warning("All nodes are constant-folded by onnxsim.")
+ except Exception as e:
+ log.warning(f"ONNX simplifier failed with error: {e}")
+
+ if apply_transforms:
+ modified_proto = apply_onnx_transforms(modified_proto)
+ if modified_proto is not onnx_proto:
+ log.debug("Checking modified model")
+ ok = compare(modified_proto, onnx_proto, n_times=check_n)
if not ok:
log.warning("The output has been changed after the optimization")
- # onnxsim produces anonymous nodes problematic for creating ONNXToFXMap
- simple_proto = name_anonymous_nodes(simple_proto)
- if len([*simple_proto.graph.node]) == 0:
- log.warning("All nodes are constant-folded by onnxsim.")
- return simple_proto
- except Exception as e:
- log.warning(f"ONNX simplifier failed with error: {e}")
-
- return onnx_proto
+ return modified_proto
def _optimize_path(
@@ -478,47 +477,53 @@ def _optimize_path(
) -> ModelProto:
with tempfile.TemporaryDirectory() as tempdir:
output_path = os.path.join(tempdir, "model.onnx")
- if apply_transforms:
- log.debug(f"Applying ONNX transforms with output path: {output_path}")
- onnx_proto = apply_onnx_transforms(onnx_proto, output_path)
-
- if not simplify:
- return onnx_proto
-
- simplified_output_path = os.path.join(tempdir, "simple_model.onnx")
+ onnx.save(
+ onnx_proto,
+ output_path,
+ save_as_external_data=True,
+ all_tensors_to_one_file=True,
+ location="model.bin",
+ )
+ modified_proto = onnx_proto
try:
- log.debug("Running onnxsim.simplify_path")
- if skip_fuse_bn:
- if skipped_optimizers is None:
- skipped_optimizers = []
- skipped_optimizers.append("fuse_bn_into_conv")
- ok = simplify_path(
- output_path,
- simplified_output_path,
- skipped_optimizers,
- True,
- True,
- 10 * (1 << 30), # 10 GB
- )
- if not ok:
- log.warning("ONNX simplifier failed")
- return onnx_proto
+ if simplify:
+ log.debug("Running onnxsim.simplify_path")
+ simplified_output_path = os.path.join(tempdir, "simple_model.onnx")
+ if skip_fuse_bn:
+ if skipped_optimizers is None:
+ skipped_optimizers = []
+ skipped_optimizers.append("fuse_bn_into_conv")
+ ok = simplify_path(
+ output_path,
+ simplified_output_path,
+ skipped_optimizers,
+ True,
+ True,
+ 10 * (1 << 30), # 10 GB
+ )
+ if not ok:
+ log.warning("ONNX simplifier failed")
+ return onnx_proto
+ log.debug(f"Loading simplified ONNX proto from {simplified_output_path}")
+ modified_proto = onnx.load(simplified_output_path)
+ # onnxsim produces anonymous nodes problematic for creating ONNXToFXMap
+ modified_proto = name_anonymous_nodes(modified_proto)
+ if len([*modified_proto.graph.node]) == 0:
+ log.warning("All nodes are constant-folded by onnxsim.")
+ except Exception as e:
+ log.warning(f"ONNX simplifier failed with error: {e}")
- log.debug("Checking simplified model")
+ if apply_transforms:
+ transformed_output_path = os.path.join(tempdir, "transformed_model.onnx")
+ log.debug(f"Applying ONNX transforms with output path: {transformed_output_path}")
+ modified_proto = apply_onnx_transforms(modified_proto, transformed_output_path)
+
+ if modified_proto is not onnx_proto:
+ log.debug("Checking modified model")
ok = compare(simplified_output_path, output_path, check_n, None, None, None)
if not ok:
log.warning("The output has been changed after the optimization")
-
- # onnxsim produces anonymous nodes problematic for creating ONNXToFXMap
- log.debug(f"Loading simplified ONNX proto from {simplified_output_path}")
- simple_proto = onnx.load(simplified_output_path)
- simple_proto = name_anonymous_nodes(simple_proto)
- if len([*simple_proto.graph.node]) == 0:
- log.warning("All nodes are constant-folded by onnxsim.")
- return simple_proto
- except Exception as e:
- log.warning(f"ONNX simplifier failed with error: {e}")
- return onnx_proto
+ return modified_proto
def name_anonymous_nodes(onnx_proto: ModelProto) -> ModelProto:
diff --git a/owlite/backend/onnx/export_with_external_data.py b/src/owlite/backend/onnx/export_with_external_data.py
similarity index 100%
rename from owlite/backend/onnx/export_with_external_data.py
rename to src/owlite/backend/onnx/export_with_external_data.py
diff --git a/owlite/backend/onnx/fold_constants.py b/src/owlite/backend/onnx/fold_constants.py
similarity index 89%
rename from owlite/backend/onnx/fold_constants.py
rename to src/owlite/backend/onnx/fold_constants.py
index c149596..c667fcc 100644
--- a/owlite/backend/onnx/fold_constants.py
+++ b/src/owlite/backend/onnx/fold_constants.py
@@ -1,5 +1,6 @@
# pylint: skip-file
# ruff: noqa: N806
+# type: ignore
import numpy as np
from onnx_graphsurgeon.ir.graph import Graph
from onnx_graphsurgeon.ir.tensor import Constant
@@ -80,9 +81,7 @@ def fold_constants(
PARTITIONING_MODES = [None, "basic", "recursive"]
if partitioning not in PARTITIONING_MODES:
- G_LOGGER.critical(
- f"Argument for parameter 'partitioning' must be one of: {PARTITIONING_MODES}"
- )
+ G_LOGGER.critical(f"Argument for parameter 'partitioning' must be one of: {PARTITIONING_MODES}")
ORT_PROVIDERS = ["CPUExecutionProvider"]
G_LOGGER.debug(f"Folding constants in {graph.name}")
@@ -102,16 +101,14 @@ def fold_constants(
if len(tensor.inputs) == 1:
node = tensor.inputs[0]
if node.op == "Constant":
- tensor.to_constant(
- node.attrs["value"]._values
- ) # Using ._values avoids copying
+ tensor.to_constant(node.attrs["value"]._values) # Using ._values avoids copying
tensor.inputs.clear()
# Pass 2: Run shape-tensor cast elision
def run_cast_elision(node):
import onnx
- # Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int)
+ # Search for Casts (from int -> float) -> intermediate operator (with float constants) -> Casts (back to int)
# This pattern is problematic for TensorRT since these operations may be performed on Shape Tensors, which
# are not allowed to be floating point type. Attempt to fold the pattern here
VALID_CAST_ELISION_OPS = [
@@ -144,8 +141,7 @@ def run_cast_elision(node):
inp_node
for inp_tensor in node.inputs
for inp_node in inp_tensor.inputs
- if inp_node.op == "Cast"
- and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT
+ if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT
]
# [SQZB] Ensure that Cast nodes are attached to all of the input nodes.
@@ -170,8 +166,7 @@ def run_cast_elision(node):
for out_tensor in node.outputs
for out_node in out_tensor.outputs
if out_node.op == "Cast"
- and out_node.attrs["to"]
- in [onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64]
+ and out_node.attrs["to"] in [onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64]
]
# No cast node found on outputs, return early
@@ -191,9 +186,7 @@ def run_cast_elision(node):
# `cast_node.inputs[0].outputs[0] == cast_node`.
for index, inp in enumerate(node.inputs):
if isinstance(inp, Constant):
- inp.values = inp.values.astype(
- onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[final_type]
- )
+ inp.values = inp.values.astype(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[final_type])
for cast in inp_casts:
if cast.outputs[0] == inp:
@@ -216,11 +209,7 @@ def run_cast_elision(node):
except Exception as err:
if not error_ok:
raise err
- G_LOGGER.warning(
- "'{:}' routine failed with: {:}".format(
- "Shape tensor cast elision", err
- )
- )
+ G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err))
# Note that most of the remaining passes operate on a clone of the original graph.
# Pass 3: Find all descendants of constant tensors
@@ -248,9 +237,7 @@ def all_tensors_const(tensors):
for attr in node.attrs.values():
if isinstance(attr, Graph):
foreign_tensors = attr._foreign_tensors().values()
- all_subgraph_foreign_tensors_const &= all_tensors_const(
- foreign_tensors
- )
+ all_subgraph_foreign_tensors_const &= all_tensors_const(foreign_tensors)
return all_subgraph_foreign_tensors_const and not should_exclude_node(node)
@@ -261,11 +248,7 @@ def all_tensors_const(tensors):
graph_constants.update({out.name: out for out in node.outputs})
return graph_constants
- graph_constants = {
- name: tensor
- for name, tensor in clone_tensors.items()
- if isinstance(tensor, Constant)
- }
+ graph_constants = {name: tensor for name, tensor in clone_tensors.items() if isinstance(tensor, Constant)}
graph_constants = update_foldable_outputs(graph_constants)
# Pass 4: Shape Folding
@@ -399,17 +382,13 @@ def fold_shape_slice(tensor):
shape_of = shape_fold_func(tensor)
if shape_of is not None:
- G_LOGGER.ultra_verbose(
- f"Folding shape tensor: {tensor.name} to: {shape_of}"
- )
+ G_LOGGER.ultra_verbose(f"Folding shape tensor: {tensor.name} to: {shape_of}")
graph_constants[tensor.name] = tensor.to_constant(shape_of)
graph_constants[tensor.name].inputs.clear()
except Exception as err:
if not error_ok:
raise err
- G_LOGGER.warning(
- f"'{shape_fold_func.__name__}' routine failed with:\n{err}"
- )
+ G_LOGGER.warning(f"'{shape_fold_func.__name__}' routine failed with:\n{err}")
else:
graph_constants = update_foldable_outputs(graph_constants)
@@ -448,9 +427,7 @@ def get_out_node_ids():
)
values = sess.run(names, {})
except Exception as err:
- G_LOGGER.warning(
- f"Inference failed for subgraph: {part.name}. Note: Error was:\n{err}"
- )
+ G_LOGGER.warning(f"Inference failed for subgraph: {part.name}. Note: Error was:\n{err}")
if partitioning == "recursive":
G_LOGGER.verbose("Attempting to recursively partition subgraph")
# Partition failed, peel off last node.
@@ -460,9 +437,7 @@ def get_out_node_ids():
out_node.outputs.clear()
out_node.inputs.clear()
else:
- G_LOGGER.info(
- "You may see better results if you set partitioning='recursive'"
- )
+ G_LOGGER.info("You may see better results if you set partitioning='recursive'")
if not error_ok:
raise err
@@ -480,9 +455,7 @@ def should_eval_foldable(tensor):
non_const = not isinstance(tensor, Constant)
is_graph_output = not tensor.outputs
- has_non_foldable_outputs = any(
- out.name not in graph_constants for out in tensor.outputs
- )
+ has_non_foldable_outputs = any(out.name not in graph_constants for out in tensor.outputs)
exceeds_size_threshold = (
tensor.shape is not None
and not misc.is_dynamic_shape(tensor.shape)
@@ -490,24 +463,14 @@ def should_eval_foldable(tensor):
and size_threshold is not None
) and (misc.volume(tensor.shape) * get_itemsize(tensor.dtype) > size_threshold)
- return (
- non_const
- and (is_graph_output or has_non_foldable_outputs)
- and not exceeds_size_threshold
- )
+ return non_const and (is_graph_output or has_non_foldable_outputs) and not exceeds_size_threshold
- graph_clone.outputs = [
- t for t in graph_constants.values() if should_eval_foldable(t)
- ]
+ graph_clone.outputs = [t for t in graph_constants.values() if should_eval_foldable(t)]
G_LOGGER.debug(f"Folding tensors: {graph_clone.outputs}")
graph_clone.cleanup(remove_unused_graph_inputs=True)
# Using ._values avoids a deep copy of the values.
- constant_values = {
- name: tensor._values
- for name, tensor in graph_constants.items()
- if isinstance(tensor, Constant)
- }
+ constant_values = {name: tensor._values for name, tensor in graph_constants.items() if isinstance(tensor, Constant)}
if graph_clone.outputs:
if partitioning:
constant_values.update(partition_and_infer(graph_clone))
@@ -562,15 +525,12 @@ def should_eval_foldable(tensor):
if large_tensors:
large_tensors_mib = {
- tensor_name: f"{value // (1 << 20)} MiB"
- for tensor_name, value in large_tensors.items()
+ tensor_name: f"{value // (1 << 20)} MiB" for tensor_name, value in large_tensors.items()
}
G_LOGGER.warning(
"It looks like this model contains foldable nodes that produce large outputs.\n"
"In order to avoid bloating the model, you may want to set a constant-folding size threshold.\n"
- "Note: Large tensors and their corresponding sizes were: {:}".format(
- large_tensors_mib
- ),
+ "Note: Large tensors and their corresponding sizes were: {:}".format(large_tensors_mib),
mode=LogMode.ONCE,
)
@@ -599,9 +559,7 @@ def fold_subgraphs():
if node.op == "If" and isinstance(node.inputs[0], Constant):
G_LOGGER.debug(f"Flattening conditional: {node}")
cond = get_scalar_value(node.inputs[0])
- subgraph = (
- node.attrs["then_branch"] if cond else node.attrs["else_branch"]
- )
+ subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"]
# Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors
for tensor in subgraph._local_tensors().values():
tensor.name += f"_subg_{index}_{subgraph.name}"
diff --git a/owlite/backend/onnx/model_checking.py b/src/owlite/backend/onnx/model_checking.py
similarity index 99%
rename from owlite/backend/onnx/model_checking.py
rename to src/owlite/backend/onnx/model_checking.py
index 05c5e6f..bc5906b 100644
--- a/owlite/backend/onnx/model_checking.py
+++ b/src/owlite/backend/onnx/model_checking.py
@@ -1,3 +1,4 @@
+# type: ignore
import os
from collections import OrderedDict
from typing import Optional, Union
@@ -7,8 +8,7 @@
import onnx.checker
import onnxruntime as rt
-from owlite_core.logger import log
-
+from ...owlite_core.logger import log
from ..utils import compare_nested_outputs
Tensors = dict[str, np.ndarray]
diff --git a/src/owlite/backend/onnx/onnx_op.py b/src/owlite/backend/onnx/onnx_op.py
new file mode 100644
index 0000000..19fd5e4
--- /dev/null
+++ b/src/owlite/backend/onnx/onnx_op.py
@@ -0,0 +1,30 @@
+from .op_schema import OpSchema, get_core_operator_schemas
+
+
+class ONNXOp:
+ """Class representing each ONNX op allowing convenient access to its schema properties"""
+
+ schemas: dict[str, OpSchema] = get_core_operator_schemas()
+
+ def __init__(self, name: str) -> None:
+ self.name = name
+
+ def __repr__(self) -> str:
+ return f"{self.name}"
+
+ def __str__(self) -> str:
+ return self.__repr__()
+
+ @property
+ def is_valid(self) -> bool:
+ """Checks if the op exists in schemas"""
+ return self.name in ONNXOp.schemas
+
+ @property
+ def schema(self) -> OpSchema:
+ """The full schema object of the op
+
+ Returns:
+ OpSchema: the op schema
+ """
+ return ONNXOp.schemas[self.name]
diff --git a/src/owlite/backend/onnx/op_schema.py b/src/owlite/backend/onnx/op_schema.py
new file mode 100644
index 0000000..34f39d8
--- /dev/null
+++ b/src/owlite/backend/onnx/op_schema.py
@@ -0,0 +1,458 @@
+# pylint: disable=invalid-name, line-too-long
+# ruff: noqa: E501
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from enum import Enum, auto
+from functools import cached_property
+from typing import Any, Optional
+
+import numpy as np
+import onnx
+from onnx import defs
+from typing_extensions import Self
+
+from ...options.options_dict import OptionsDict
+from ...options.options_mixin import OptionsMixin
+
+
+class FormalParameterOption(Enum):
+ """
+ A statically analyzable python class for
+ [OpSchema.FormalParameterOption](https://github.com/onnx/onnx/blob/v1.15.0/onnx/onnx_cpp2py_export/defs.pyi#L93-L96)
+ """
+
+ Single = 0
+ Optional = auto()
+ Variadic = auto()
+
+
+class DifferentiationCategory(Enum):
+ """
+ A statically analyzable python class for
+ [OpSchema.DifferentiationCategory](https://github.com/onnx/onnx/blob/v1.15.0/onnx/onnx_cpp2py_export/defs.pyi#L98-L101)
+ """
+
+ Unknown = 0
+ Differentiable = auto()
+ NonDifferentiable = auto()
+
+
+@dataclass
+class FormalParameter(OptionsMixin):
+ """
+ A statically analyzable python class for
+ [OpSchema.FormalParameter](https://github.com/onnx/onnx/blob/v1.15.0/onnx/onnx_cpp2py_export/defs.pyi#L103-L130)
+ """
+
+ name: str
+ type_str: str
+ description: str
+ option: FormalParameterOption
+ is_homogeneous: bool
+ min_arity: int
+ differentiation_category: DifferentiationCategory
+
+ @property
+ def is_single(self) -> bool:
+ """Equivalent to `self.option == FormalParameterOption.Single`"""
+ return self.option == FormalParameterOption.Single
+
+ @property
+ def is_optional(self) -> bool:
+ """Equivalent to `self.option == FormalParameterOption.Optional`"""
+ return self.option == FormalParameterOption.Optional
+
+ @property
+ def is_variadic(self) -> bool:
+ """Equivalent to `self.option == FormalParameterOption.Variadic`"""
+ return self.option == FormalParameterOption.Variadic
+
+ @classmethod
+ def from_defs(cls, parameter: defs.OpSchema.FormalParameter) -> Self:
+ """Instantiation from the original class in `onnx.defs`"""
+ return cls(
+ name=parameter.name,
+ type_str=parameter.type_str,
+ description=parameter.description,
+ option=FormalParameterOption(parameter.option.value),
+ is_homogeneous=parameter.is_homogeneous,
+ min_arity=parameter.min_arity,
+ differentiation_category=DifferentiationCategory(parameter.differentiation_category.value),
+ )
+
+
+@dataclass
+class CompiledFormalParameter(OptionsMixin):
+ """The formal parameter compiled for an input or output at a specific index"""
+
+ name: str
+ allowed_types: list[np.dtype]
+ description: str
+ option: FormalParameterOption
+ is_homogeneous: bool
+ min_arity: int
+ differentiation_category: DifferentiationCategory
+
+ @property
+ def is_single(self) -> bool:
+ """Equivalent to `self.option == FormalParameterOption.Single`"""
+ return self.option == FormalParameterOption.Single
+
+ @property
+ def is_optional(self) -> bool:
+ """Equivalent to `self.option == FormalParameterOption.Optional`"""
+ return self.option == FormalParameterOption.Optional
+
+ @property
+ def is_variadic(self) -> bool:
+ """Equivalent to `self.option == FormalParameterOption.Variadic`"""
+ return self.option == FormalParameterOption.Variadic
+
+
+@dataclass
+class TypeConstraintParam(OptionsMixin):
+ """
+ A statically analyzable python class for
+ [OpSchema.TypeConstraintParam](https://github.com/onnx/onnx/blob/v1.15.0/onnx/onnx_cpp2py_export/defs.pyi#L72-L91)
+ """
+
+ type_param_str: str
+ allowed_type_strs: list[str]
+ description: str
+
+ @cached_property
+ def allowed_types(self) -> list[np.dtype]:
+ """The allowed types converted into np.dtype instances"""
+ return convert_to_np_dtypes(self.allowed_type_strs)
+
+ @classmethod
+ def from_defs(cls, constraint: defs.OpSchema.TypeConstraintParam) -> Self:
+ """Instantiation from the original class in `onnx.defs`"""
+ return cls(
+ type_param_str=constraint.type_param_str,
+ allowed_type_strs=list(constraint.allowed_type_strs),
+ description=constraint.description,
+ )
+
+
+class TypeConstraintParamMap(OptionsDict[str, TypeConstraintParam]):
+ """
+ * Key (str): a type parameter string
+ * Value (TypeConstraintParam): the type constraint corresponding to the type parameter string
+ """
+
+
+class AttrType(Enum):
+ """
+ A statically analyzable python class for
+ [AttrType](https://github.com/onnx/onnx/blob/v1.15.0/onnx/onnx_cpp2py_export/defs.pyi#L132-L146)
+ """
+
+ NONE = 0
+ FLOAT = auto()
+ INT = auto()
+ STRING = auto()
+ TENSOR = auto()
+ GRAPH = auto()
+ SPARSE_TENSOR = auto()
+ TYPE_PROTO = auto()
+ FLOATS = auto()
+ INTS = auto()
+ STRINGS = auto()
+ TENSORS = auto()
+ GRAPHS = auto()
+ SPARSE_TENSORS = auto()
+ TYPE_PROTOS = auto()
+
+
+@dataclass
+class AttributeProto(OptionsMixin):
+ """A statically analyzable python class for `onnx.AttributeProto`"""
+
+ name: str
+ type: AttrType
+ value: Any
+
+ @classmethod # pylint: disable-next=too-many-statements, too-many-return-statements
+ def from_defs(cls, attr_proto: onnx.AttributeProto) -> Optional[Self]:
+ """Instantiation from the original class in `onnx.defs`"""
+ attr_type = AttrType(attr_proto.type)
+ match attr_type:
+ case AttrType.NONE:
+ return None
+ case AttrType.FLOAT:
+ return cls(attr_proto.name, attr_type, attr_proto.f)
+ case AttrType.INT:
+ return cls(attr_proto.name, attr_type, attr_proto.i)
+ case AttrType.STRING:
+ return cls(attr_proto.name, attr_type, attr_proto.s.decode("UTF-8"))
+ case AttrType.TENSOR:
+ return cls(attr_proto.name, attr_type, attr_proto.t)
+ case AttrType.GRAPH:
+ return cls(attr_proto.name, attr_type, attr_proto.g)
+ case AttrType.SPARSE_TENSOR:
+ return cls(attr_proto.name, attr_type, attr_proto.sparse_tensor)
+ case AttrType.TYPE_PROTO:
+ return cls(attr_proto.name, attr_type, attr_proto.tp)
+ case AttrType.FLOATS:
+ return cls(attr_proto.name, attr_type, attr_proto.floats)
+ case AttrType.INTS:
+ return cls(attr_proto.name, attr_type, attr_proto.ints)
+ case AttrType.STRINGS:
+ return cls(attr_proto.name, attr_type, attr_proto.strings)
+ case AttrType.TENSORS:
+ return cls(attr_proto.name, attr_type, attr_proto.tensors)
+ case AttrType.GRAPHS:
+ return cls(attr_proto.name, attr_type, attr_proto.graphs)
+ case AttrType.SPARSE_TENSORS:
+ return cls(attr_proto.name, attr_type, attr_proto.sparse_tensors)
+ case AttrType.TYPE_PROTOS:
+ return cls(attr_proto.name, attr_type, attr_proto.type_protos)
+
+
+@dataclass
+class Attribute(OptionsMixin):
+ """
+ A statically analyzable python class for
+ [OpSchema.Attribute](https://github.com/onnx/onnx/blob/v1.15.0/onnx/onnx_cpp2py_export/defs.pyi#L148-L174)
+ """
+
+ name: str
+ type: AttrType
+ description: str
+ default_value: Optional[AttributeProto]
+ required: bool
+
+ @classmethod
+ def from_defs(cls, attribute: defs.OpSchema.Attribute) -> Self:
+ """Instantiation from the original class in `onnx.defs`"""
+ return cls(
+ name=attribute.name,
+ type=AttrType(attribute.type),
+ description=attribute.description,
+ default_value=AttributeProto.from_defs(attribute.default_value),
+ required=attribute.required,
+ )
+
+
+class AttributeMap(OptionsDict[str, Attribute]):
+ """
+ * Key (str): the name of an attribute
+ * Value (Attribute): the attribute
+ """
+
+
+@dataclass # pylint: disable-next=too-many-instance-attributes
+class OpSchema(OptionsMixin):
+ """
+ A statically analyzable python class for
+ [OpSchema](https://github.com/onnx/onnx/blob/v1.15.0/onnx/onnx_cpp2py_export/defs.pyi#L10-L70)
+ """
+
+ name: str
+ domain: str
+ since_version: int
+ doc: str
+ type_constraints: TypeConstraintParamMap
+ inputs: list[FormalParameter]
+ outputs: list[FormalParameter]
+ attributes: AttributeMap
+ min_input: int
+ max_input: int
+ min_output: int
+ max_output: int
+
+ @classmethod
+ def from_defs(cls, schema: defs.OpSchema) -> Self:
+ """Instantiation from the original class in `onnx.defs`"""
+ return cls(
+ name=schema.name,
+ domain=schema.domain,
+ since_version=schema.since_version,
+ doc=schema.doc,
+ type_constraints=TypeConstraintParamMap(
+ {
+ type_constraint.type_param_str: TypeConstraintParam.from_defs(type_constraint)
+ for type_constraint in schema.type_constraints
+ }
+ ),
+ inputs=[FormalParameter.from_defs(p) for p in schema.inputs],
+ outputs=[FormalParameter.from_defs(p) for p in schema.outputs],
+ attributes=AttributeMap({name: Attribute.from_defs(attr) for name, attr in schema.attributes.items()}),
+ min_input=schema.min_input,
+ max_input=schema.max_input,
+ min_output=schema.min_output,
+ max_output=schema.max_output,
+ )
+
+ def i(self, idx: int = 0) -> CompiledFormalParameter:
+ """The formal parameter of the input at given index.
+
+ Args:
+ idx (int, optional): the input index. Defaults to 0.
+
+ Returns:
+ CompiledFormalParameter: the formal ONNX parameter of the input.
+ """
+ # Ideally, this would've been `return self.inputs[idx]`, but the reality is not so simple.
+ return _get_formal_parameter(idx, self.inputs, self.type_constraints)
+
+ def o(self, idx: int = 0) -> CompiledFormalParameter:
+ """The formal parameter of the output at given index.
+
+ Args:
+ idx (int, optional): the output index. Defaults to 0.
+
+ Returns:
+ CompiledFormalParameter: the formal ONNX parameter of the output.
+ """
+ # Ideally, this would've been `return self.outputs[idx]`, but the reality is not so simple.
+ return _get_formal_parameter(idx, self.outputs, self.type_constraints)
+
+
+def _get_formal_parameter(
+ idx: int,
+ params: list[FormalParameter],
+ type_constraints: TypeConstraintParamMap,
+) -> CompiledFormalParameter:
+ is_last_parameter_variadic = params[-1].is_variadic
+ if not (-len(params) <= idx < len(params) or is_last_parameter_variadic):
+ raise IndexError(f"input or output index out of range: {idx}")
+ if is_last_parameter_variadic:
+ param_idx = min(idx, len(params) - 1)
+ offset = idx - param_idx
+ param = params[param_idx]
+ name = f"{param.name}_{offset}"
+ else:
+ param = params[idx]
+ name = param.name
+ return CompiledFormalParameter(
+ name=name,
+ allowed_types=_get_type_contraints(param.type_str, type_constraints).allowed_types,
+ description=param.description,
+ option=param.option,
+ is_homogeneous=param.is_homogeneous,
+ min_arity=param.min_arity,
+ differentiation_category=param.differentiation_category,
+ )
+
+
+def _get_type_contraints(type_str: str, type_constraints: TypeConstraintParamMap) -> TypeConstraintParam:
+ if type_str in type_constraints:
+ return type_constraints[type_str]
+ return TypeConstraintParam(type_str, [type_str], description="")
+
+
+def convert_to_np_dtypes(wrapped_type_strs: list[str]) -> list[np.dtype]:
+ """Converts type strings from an op schema to numpy data type
+
+ Args:
+ wrapped_type_strs (list[str]): the op schema type string
+
+ Returns:
+ list[np.dtype]: the converted numpy data type.
+ """
+ return [
+ dtype
+ for type_str in wrapped_type_strs
+ if (dtype := try_convert_to_np_dtype(unwrap_type_str(type_str))) is not None
+ ]
+
+
+def unwrap_type_str(type_str: str) -> str:
+ """Unwraps a type string from an op schema if possible
+
+ Args:
+ type_str (str): an op schema type string
+
+ Returns:
+ str: the string containing only type name if the unwrapping was successful, the input type_str itself otherwise.
+ """
+ match = re.search(r"tensor\((.*?)\)", type_str)
+ if match:
+ # Return the extracted string
+ return match.group(1)
+ # Return the input itself if no match is found
+ return type_str
+
+
+def try_convert_to_np_dtype(type_str: str) -> Optional[np.dtype]:
+ """Converts the type name in string into numpy data type if possible.
+
+ Args:
+ type_str (str): a string containing type name
+
+ Returns:
+ Optional[np.dtype]: a numpy.dtype instance if the conversion was successful, None otherwise.
+ """
+ if type_str == "float":
+ type_str = "float32"
+ try:
+ return np.dtype(type_str)
+ except TypeError:
+ pass
+ return None
+
+
+def get_full_operator_schemas() -> (
+ list[tuple[str, list[tuple[int, list[tuple[str, defs.OpSchema, list[defs.OpSchema]]]]]]]
+):
+ """parse full operator schemas
+
+ Returns:
+ list[tuple[str, list[tuple[int, list[tuple[str, defs.OpSchema, list[defs.OpSchema]]]]]]]: nested structure containing all
+ available op schemas
+ """
+ # domain -> support level -> name -> [schema]
+ index: dict[str, dict[int, dict[str, list[defs.OpSchema]]]] = defaultdict(
+ lambda: defaultdict(lambda: defaultdict(list))
+ )
+ for schema in defs.get_all_schemas_with_history():
+ index[schema.domain][int(schema.support_level)][schema.name].append(schema)
+
+ # Preprocess the Operator Schemas
+ # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
+ operator_schemas: list[tuple[str, list[tuple[int, list[tuple[str, defs.OpSchema, list[defs.OpSchema]]]]]]] = []
+ existing_ops: set[str] = set()
+ for domain, _supportmap in sorted(index.items()):
+ if domain == defs.ONNX_ML_DOMAIN:
+ continue
+
+ processed_supportmap = []
+ for _support, _namemap in sorted(_supportmap.items()):
+ processed_namemap = []
+ for n, unsorted_versions in sorted(_namemap.items()):
+ versions = sorted(unsorted_versions, key=lambda s: s.since_version)
+ schema = versions[-1]
+ if schema.name in existing_ops:
+ continue
+ existing_ops.add(schema.name)
+ processed_namemap.append((n, schema, versions))
+ processed_supportmap.append((_support, processed_namemap))
+ operator_schemas.append((domain, processed_supportmap))
+ return operator_schemas
+
+
+def get_core_operator_schemas_defs() -> dict[str, defs.OpSchema]:
+ """restructured operator schemas for only core operators
+
+ Returns:
+ dict[str, list[tuple[int, list[tuple[str, defs.OpSchema, list[defs.OpSchema]]]]]]: the dictionary with key-value pairs
+ where each op name is a key in string whose value is the nest structure containing various properties
+ of the ONNX op.
+ """
+ triples = dict(get_full_operator_schemas())[""][0][1]
+ return {x[0]: x[1] for x in triples}
+
+
+def get_core_operator_schemas() -> dict[str, OpSchema]:
+ """restructured operator schemas for only core operators
+
+ Returns:
+ dict[str, list[tuple[int, list[tuple[str, defs.OpSchema, list[defs.OpSchema]]]]]]: the dictionary with key-value pairs
+ where each op name is a key in string whose value is the nest structure containing various properties
+ of the ONNX op.
+ """
+ triples = dict(get_full_operator_schemas())[""][0][1]
+ return {x[0]: OpSchema.from_defs(x[1]) for x in triples}
diff --git a/owlite/backend/onnx/transforms.py b/src/owlite/backend/onnx/transforms.py
similarity index 90%
rename from owlite/backend/onnx/transforms.py
rename to src/owlite/backend/onnx/transforms.py
index 2432181..e4d2ead 100644
--- a/owlite/backend/onnx/transforms.py
+++ b/src/owlite/backend/onnx/transforms.py
@@ -7,14 +7,12 @@
import onnx
import onnx_graphsurgeon as gs
from onnx import ModelProto, TensorProto
-from onnx_graphsurgeon.importers.onnx_importer import get_numpy_type
-
-from owlite_core.logger import log
+from ...owlite_core.logger import log
from ..config import ONNX_TRANSFORM_MAXIMUM_ITERATION
-from ..utils import is_floating_point, nodestr
+from ..utils import get_numpy_type, is_floating_point, nodestr
from .export_with_external_data import export_with_external_data
-from .fold_constants import fold_constants
+from .fold_constants import fold_constants # type: ignore
from .onnx_op import ONNXOp
OnnxTransform = Callable[[gs.Graph], gs.Graph]
@@ -129,7 +127,7 @@ def eliminate_nop_dropouts(graph: gs.Graph) -> gs.Graph:
@register_onnx_transform
-def eliminate_nop_cast(graph: gs.Graph) -> gs.Graph:
+def eliminate_nop_casts(graph: gs.Graph) -> gs.Graph:
"""Eliminates all Cast ops with no effect.
Args:
@@ -202,7 +200,7 @@ class FloatingPointSyncType(Enum):
cast_input_fp_tensors_of(node, dtype)
cast_output_fp_tensors_of(node, dtype)
- return eliminate_nop_cast(graph)
+ return eliminate_nop_casts(graph)
@register_onnx_transform
@@ -459,6 +457,48 @@ def _fold_add_or_sub(conv_node: gs.Node, add_or_sub_node: gs.Node) -> None:
return graph
+@register_onnx_transform
+def eliminate_nop_reformatting_sequences(graph: gs.Graph) -> gs.Graph:
+ """Eliminate meaningless reformatting sequences
+ (e.g. Flatten->Reshape with identical input and output shapes)
+
+ Args:
+ graph (gs.Graph): a ONNX graph.
+
+ Returns:
+ gs.Graph: the transformed ONNX graph
+ """
+ reformatting_ops = ("Reshape", "Flatten", "Squeeze", "Unsqueeze")
+ for node in graph.nodes:
+ if node.op not in reformatting_ops:
+ continue
+ reformatting_sequence = [node]
+ log.debug(f"Starting with {nodestr(node)}")
+ while True:
+ child = reformatting_sequence[-1]
+ if not ((parent := get_defining_node(child.inputs[0])) is not None and parent.op in reformatting_ops):
+ log.debug(f"Break: parent {nodestr(parent)} is not a reformatting node")
+ break
+
+ reformatting_sequence.append(parent)
+ log.debug(f"Appended {nodestr(parent)}")
+ if parent.inputs[0].shape == reformatting_sequence[0].outputs[0].shape:
+ log.debug("Early exit: parent input shape matched")
+ break
+ if not (
+ len(reformatting_sequence) > 1
+ and (
+ (the_output := reformatting_sequence[0].outputs[0]).shape
+ == (the_input := reformatting_sequence[-1].inputs[0]).shape
+ )
+ ):
+ log.debug("No foldable reformatting sequence found")
+ continue
+ log.debug(f"Found foldable reformatting sequence: {[*map(nodestr, reformatting_sequence)]}")
+ replace_all_uses(the_output, the_input)
+ return graph
+
+
# pylint: disable=missing-function-docstring
def remove_if_dropout_op_with_ratio_zero(node: gs.Node, graph: gs.Graph) -> None:
if node.op != "Dropout":
@@ -512,8 +552,7 @@ def cast_input_fp_tensors_of(node: gs.Node, dtype: np.dtype) -> None:
return
for idx, tensor in enumerate(node.inputs):
- type_constraints = onnx_op.i(idx).type_constraints
- if dtype not in type_constraints:
+ if dtype not in onnx_op.schema.i(idx).allowed_types:
continue
if not tensor.inputs and is_floating_point(tensor.dtype):
@@ -535,8 +574,7 @@ def cast_output_fp_tensors_of(node: gs.Node, dtype: np.dtype) -> None:
return
for idx, tensor in enumerate(node.outputs):
- type_constraints = onnx_op.o(idx).type_constraints
- if dtype not in type_constraints:
+ if dtype not in onnx_op.schema.o(idx).allowed_types:
continue
if not is_floating_point(tensor.dtype) or tensor.dtype == dtype:
@@ -569,9 +607,7 @@ def remove_if_has_unique_non_optional_input_and_unique_used_output(node: gs.Node
log.debug_warning(f"Unsupported ONNXOp: {node.op}")
return
- non_optional_inputs = [
- *(node.inputs[idx] for idx in filter(lambda idx: not onnx_op.i(idx).is_optional, range(len(node.inputs))))
- ]
+ non_optional_inputs = [node.inputs[idx] for idx in range(len(node.inputs)) if not onnx_op.schema.i(idx).is_optional]
if len(non_optional_inputs) != 1:
log.debug_warning(
@@ -629,3 +665,15 @@ def get_defining_op_type(tensor: Tensor) -> Optional[str]:
return None
return tensor.inputs[0].op
+
+
+def get_defining_node(tensor: Tensor) -> Optional[gs.Node]:
+ if tensor.inputs:
+ return tensor.inputs[0]
+ return None
+
+
+def replace_all_uses(existing_tensor: Tensor, new_tensor: Tensor) -> None:
+ for node in existing_tensor.outputs:
+ i = node.inputs.index(existing_tensor)
+ node.inputs[i] = new_tensor
diff --git a/owlite/backend/patches.py b/src/owlite/backend/patches.py
similarity index 89%
rename from owlite/backend/patches.py
rename to src/owlite/backend/patches.py
index 464d67f..b38ba06 100644
--- a/owlite/backend/patches.py
+++ b/src/owlite/backend/patches.py
@@ -6,9 +6,9 @@
import operator
from collections import OrderedDict
-from dataclasses import dataclass
+from enum import Enum
from packaging import version
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import TYPE_CHECKING, Callable, Optional, Union
if TYPE_CHECKING:
from torch.types import _dtype as DType
@@ -22,7 +22,8 @@
from torch import Tensor
from torch.fx.node import _side_effectful_functions
-from owlite_core.logger import log
+from ..owlite_core.logger import log
+from .config import DISABLE_AUTO_PATCH
try:
import diffusers
@@ -37,97 +38,166 @@
_side_effectful_functions.add(operator.setitem)
+require_explicit_disallow = "2.2.0" <= torch.__version__ < "2.3"
-@dataclass
-class PatchHistory:
- orig_fn: Callable
- patched_fn: Callable
+def force_dynamo_disallow_in_graph(obj):
+ """Forcefully do what `torch._dynamo.disallow_in_graph` is supposed to do by
+ setting the designated property `torchdynamo_force_dynamic` to `True`.
+ See [torch._dynamo.mutation_guard.is_dynamic_nn_module](https://github.com/pytorch/pytorch/blob/7bcf7da3a268b435777fe87c7794c382f444e86d/torch/_dynamo/mutation_guard.py#L84)
+
+ Args:
+ obj: an object to disallow in torch dynamo graph
+ """
+ torch._dynamo.allow_in_graph(obj)
+ torch._dynamo.disallow_in_graph(obj)
+ # if "2.1.0" <= torch.__version__ < "2.2.0":
+ obj.torchdynamo_force_dynamic = True
+
+
+class PatchStatus(Enum):
+ REGISTERED = 0
+ APPLIED = 1
+
+class Patch():
+ orig_fn: Callable # original function
+ patched_fn: Callable # patched version of the function
+ orig_fn_path: str # the full qualified name of the original function
+ patched_fn_path: str # the full qualified name of the patched function
+ status: PatchStatus # status of the patch
+ disallow_in_graph: bool # whether to allow the patched function in graph
+ hard_patch_target: Optional[str] # patch target to assign the patched function with assign operator(=)
+
+ def __init__(
+ self,
+ orig_fn: Callable,
+ patched_fn: Callable,
+ disallow_in_graph: bool,
+ hard_patch_target: Optional[str],
+ ) -> None:
+ self.orig_fn = orig_fn
+ self.patched_fn = patched_fn
+ self.orig_fn_module = inspect.getmodule(orig_fn)
+ self.orig_fn_path = f"{self.orig_fn_module.__name__}.{orig_fn.__name__}"
+ self.patched_fn_module = inspect.getmodule(patched_fn)
+ self.patched_fn_path = f"{self.patched_fn_module.__name__}.{patched_fn.__name__}"
+ self.status = PatchStatus.REGISTERED
+ self.hard_patch_target = hard_patch_target
+ self.disallow_in_graph = disallow_in_graph
+
+ if self.hard_patch_target:
+ log.debug_warning(f"Hard patch for {self.hard_patch_target} detected, "
+ "note that hard patch can result in unexpected outcome")
+
+ def apply(self) -> None:
+ if self.status == PatchStatus.APPLIED:
+ log.warning("This patch is already applied")
+ return
+
+ self.status = PatchStatus.APPLIED
+
+ if self.hard_patch_target:
+ exec(f"{self.hard_patch_target} = self.patched_fn")
+
+ setattr(self.orig_fn_module, self.orig_fn.__name__, self.patched_fn)
+
+ if self.disallow_in_graph:
+ torch._dynamo.allow_in_graph(self.patched_fn)
+ torch._dynamo.disallow_in_graph(self.patched_fn)
+
+ def rollback(self) -> None:
+ if self.status == PatchStatus.REGISTERED:
+ log.warning("This patch is not applied yet")
+ return
+
+ self.status = PatchStatus.REGISTERED
+
+ if self.hard_patch_target:
+ exec(f"{self.hard_patch_target} = self.orig_fn")
+
+ setattr(self.orig_fn_module, self.orig_fn.__name__, self.orig_fn)
+ # TODO: should rollback reset graph allow settings too?
class PatchManager:
- _history: list[PatchHistory] = []
+ patches: list[Patch] = []
@classmethod
- def patch(
- cls,
- orig_fn: Callable,
- patched_fn: Callable,
- ):
- for history in cls._history:
- if (
- orig_fn is history.orig_fn
- or orig_fn is history.patched_fn
- or patched_fn is history.orig_fn
- or patched_fn is history.patched_fn
- ):
- raise Exception(
- f"Patch conflict detected for {orig_fn.__module__}.{orig_fn.__name__}"
- )
+ def is_registered(cls, fn_or_fn_path: Union[Callable, str]):
+ return fn_or_fn_path in [
+ f for patch in cls.patches for f in [
+ patch.orig_fn, patch.patched_fn, patch.orig_fn_path, patch.patched_fn_path
+ ]
+ ]
+
+ @classmethod
+ def register_patch(
+ cls,
+ orig_fn: Callable,
+ patched_fn: Callable,
+ hard_patch_target: Optional[str] = None,
+ disallow_in_graph: bool = False,
+ ) -> None:
+ if cls.is_registered(orig_fn):
+ raise Exception(
+ f"Patch conflict detected for {orig_fn.__module__}.{orig_fn.__name__}"
+ )
+ if cls.is_registered(patched_fn):
+ raise Exception(
+ f"Patch conflict detected for {patched_fn.__module__}.{patched_fn.__name__}"
+ )
+
if patched_fn is not orig_fn:
patched_fn_module = inspect.getmodule(patched_fn)
patched_fn_name = patched_fn.__name__
setattr(inspect.getmodule(orig_fn), orig_fn.__name__, patched_fn)
- cls._history.append(PatchHistory(orig_fn, patched_fn))
+ cls.patches.append(Patch(orig_fn, patched_fn, disallow_in_graph, hard_patch_target))
log.debug(
- f"Patched {orig_fn.__module__}.{orig_fn.__name__} by {patched_fn_module.__name__}.{patched_fn_name}"
+ f"Registered patch: {orig_fn.__module__}.{orig_fn.__name__} -> "
+ f"{patched_fn_module.__name__}.{patched_fn_name}"
)
else:
log.warning(
- f"Ignoring vacuous patch for {orig_fn.__module__}{orig_fn.__name__}"
+ f"Ignoring vacuous patch for {orig_fn.__module__}.{orig_fn.__name__}"
)
@classmethod
- def rollback(cls, orig_fn_or_patched_fn: Callable):
- for history in cls._history:
- if (
- history.orig_fn is orig_fn_or_patched_fn
- or history.patched_fn is orig_fn_or_patched_fn
- ):
- setattr(
- history.patched_fn_module,
- history.patched_fn_name,
- history.patched_fn,
- )
- setattr(
- inspect.getmodule(history.orig_fn),
- history.orig_fn.__name__,
- history.orig_fn,
- )
- cls._history.remove(history)
- log.info(
- f"Rolled back the patched function {history.patched_fn_module.__name__}.{history.patched_fn_name} to {history.orig_fn.__module__}.{history.orig_fn.__name__}"
- )
- return
- log.warning(
- f"No patch registered for {orig_fn_or_patched_fn.__module__}.{orig_fn_or_patched_fn.__name__}"
- )
-
-
-def force_dynamo_disallow_in_graph(obj):
- """Forcefully do what `torch._dynamo.disallow_in_graph` is supposed to do by
- setting the designated property `torchdynamo_force_dynamic` to `True`.
- See [torch._dynamo.mutation_guard.is_dynamic_nn_module](https://github.com/pytorch/pytorch/blob/7bcf7da3a268b435777fe87c7794c382f444e86d/torch/_dynamo/mutation_guard.py#L84)
+ def deregister_patch(cls, fn_or_fn_path: Union[Callable, str]) -> None:
+ if not cls.is_registered(fn_or_fn_path):
+ fn_path = (
+ fn_or_fn_path if isinstance(fn_or_fn_path, str)
+ else '.'.join([str(inspect.getmodule(fn_or_fn_path)), fn_or_fn_path.__name__])
+ )
+ log.warning(f"No patch registered for {fn_path}")
+ return
+
+ patch_index = [
+ fn for patch in cls.patches for fn in [
+ patch.orig_fn, patch.patched_fn, patch.orig_fn_path, patch.patched_fn_path
+ ]
+ ].index(fn_or_fn_path) // 4
+ cls.patches.pop(patch_index)
- Args:
- obj: an object to disallow in torch dynamo graph
- """
- torch._dynamo.disallow_in_graph(obj)
- obj.torchdynamo_force_dynamic = True
+ @classmethod
+ def apply_patches(cls) -> None:
+ for patch in cls.patches:
+ patch.apply()
+
+ @classmethod
+ def rollback_patches(cls) -> None:
+ for patch in cls.patches:
+ patch.rollback()
-def patch(orig_fn: Callable):
+def register_patch(orig_fn: Callable, hard_patch_target: Optional[str] = None, disallow_in_graph: bool = False):
def wrap(patched_fn: Callable):
- PatchManager.patch(orig_fn, patched_fn)
+ PatchManager.register_patch(orig_fn, patched_fn, hard_patch_target, disallow_in_graph)
return patched_fn
-
+
return wrap
-def rollback(orig_fn_or_patched_fn: Callable):
- PatchManager.rollback(orig_fn_or_patched_fn)
-
-
# [SQZB] patch for DataParallel
+@register_patch(torch.fx.GraphModule._replicate_for_data_parallel, "torch.fx.GraphModule._replicate_for_data_parallel")
def patched_replicate_for_data_parallel(self):
new_gm = self.__copy__()
new_gm._is_replica = True
@@ -140,11 +210,9 @@ def patched_replicate_for_data_parallel(self):
return new_gm
-torch.fx.GraphModule._replicate_for_data_parallel = patched_replicate_for_data_parallel
-
-# See https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
# [SQZB] torch.nn.functional.scaled_dot_product_attention cannot be exported to ONNX
-@patch(torch.nn.functional.scaled_dot_product_attention)
+# See https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention
+@register_patch(torch.nn.functional.scaled_dot_product_attention, "torch.nn.functional.scaled_dot_product_attention", disallow_in_graph=require_explicit_disallow)
def slow_scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
) -> Tensor:
@@ -171,11 +239,10 @@ def slow_scaled_dot_product_attention(
attn_weight = F.dropout(attn_weight, dropout_p).to(value.dtype)
return attn_weight @ value
-torch.nn.functional.scaled_dot_product_attention = slow_scaled_dot_product_attention
# [SQZB] torch.nn.functional._mha_shape_check causes the error: "torch.* op returned non-Tensor bool"
# Made it a local function with no changes in its contents from torch==2.1.0 (same as torch==2.0.0)
-@patch(torch.nn.functional._mha_shape_check)
+@register_patch(torch.nn.functional._mha_shape_check, disallow_in_graph=require_explicit_disallow)
def patched_mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
# Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
@@ -226,7 +293,7 @@ def patched_mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
# [SQZB] torch.nn.functional._none_or_dtype causes the error: "torch.* op returned non-Tensor bool"
# Made it a local function with no changes in its contents from torch==2.1.0 (same as torch==2.0.0)
-@patch(torch.nn.functional._none_or_dtype)
+@register_patch(torch.nn.functional._none_or_dtype, disallow_in_graph=require_explicit_disallow)
def patched_none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
if input is None:
return None
@@ -235,7 +302,7 @@ def patched_none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
-@patch(torch.nn.functional._in_projection_packed)
+@register_patch(torch.nn.functional._in_projection_packed, disallow_in_graph=require_explicit_disallow)
def patched_in_projection_packed(
q: Tensor,
k: Tensor,
@@ -313,7 +380,7 @@ def patched_in_projection_packed(
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
-@patch(torch.nn.functional.multi_head_attention_forward)
+@register_patch(torch.nn.functional.multi_head_attention_forward, disallow_in_graph=require_explicit_disallow)
def patched_multi_head_attention_forward(
query: Tensor,
key: Tensor,
@@ -738,9 +805,8 @@ def patched_multi_head_attention_forward(
return attn_output, None
-torch._dynamo.disallow_in_graph(torch.nn.modules.MultiheadAttention)
-
-
+torch._dynamo.disallow_in_graph(torch.nn.MultiheadAttention)
+@register_patch(torch.nn.MultiheadAttention.forward, "torch.nn.MultiheadAttention.forward", disallow_in_graph=require_explicit_disallow)
def patched_nn_multihead_attention_forward(
self,
query: Tensor,
@@ -922,12 +988,9 @@ def patched_nn_multihead_attention_forward(
return attn_output, attn_output_weights
-torch.nn.modules.MultiheadAttention.forward = patched_nn_multihead_attention_forward
-
-
if "2.0.0" <= torch.__version__ < "2.1.0":
# [SQZB] unflatten cannot be exported to ONNX
- @patch(torch.unflatten)
+ @register_patch(torch.unflatten)
def patched_unflatten(self, dim, sizes):
if dim == -1:
return self.reshape(*self.shape[:-1], *sizes)
@@ -936,15 +999,24 @@ def patched_unflatten(self, dim, sizes):
return self.reshape(*self.shape[:dim], *sizes, *self.shape[dim + 1 :])
- Tensor.unflatten = patched_unflatten
+ # [SQZB] unflatten cannot be exported to ONNX
+ @register_patch(Tensor.unflatten, "Tensor.unflatten")
+ def patched_tensor_unflatten(self, dim, sizes):
+ if dim == -1:
+ return self.reshape(*self.shape[:-1], *sizes)
+ if dim == 0:
+ return self.reshape(*sizes, *self.shape[1:])
+ return self.reshape(*self.shape[:dim], *sizes, *self.shape[dim + 1 :])
+
torch._dynamo.disallow_in_graph(torch.nn.Transformer)
torch._dynamo.disallow_in_graph(torch.nn.modules.TransformerEncoder)
torch._dynamo.disallow_in_graph(torch.nn.modules.TransformerDecoder)
- torch._dynamo.disallow_in_graph(torch.nn.modules.TransformerEncoderLayer)
- torch._dynamo.disallow_in_graph(torch.nn.modules.TransformerDecoderLayer)
- def pathed_nn_transformer_encoder_layer_forward(
+
+ torch._dynamo.disallow_in_graph(torch.nn.modules.TransformerEncoderLayer)
+ @register_patch(torch.nn.modules.TransformerEncoderLayer.forward, "torch.nn.modules.TransformerEncoderLayer.forward")
+ def patched_nn_transformer_encoder_layer_forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
@@ -1105,12 +1177,8 @@ def pathed_nn_transformer_encoder_layer_forward(
return x
-
- torch.nn.modules.TransformerEncoderLayer.forward = (
- pathed_nn_transformer_encoder_layer_forward
- )
-
-
+ torch._dynamo.disallow_in_graph(torch.nn.modules.TransformerDecoderLayer)
+ @register_patch(torch.nn.modules.TransformerDecoderLayer.forward, "torch.nn.modules.TransformerDecoderLayer.forward")
def patched_nn_transformer_decoder_layer_forward(
self,
tgt: Tensor,
@@ -1215,11 +1283,7 @@ def patched_nn_transformer_decoder_layer_forward(
return x
- torch.nn.modules.TransformerDecoderLayer.forward = (
- patched_nn_transformer_decoder_layer_forward
- )
-
-if "2.1.0" <= torch.__version__ < "2.2.0":
+if "2.1.0" <= torch.__version__ < "2.3.0":
force_dynamo_disallow_in_graph(torch.nn.Transformer)
force_dynamo_disallow_in_graph(torch.nn.modules.TransformerEncoder)
force_dynamo_disallow_in_graph(torch.nn.modules.TransformerDecoder)
@@ -1230,11 +1294,11 @@ def patched_nn_transformer_decoder_layer_forward(
# [SQZB] torch.nn.functional._none_or_dtype causes the error: "torch.* op returned non-Tensor bool"
# Made it a local function with no changes in its contents from torch==2.1.0
- @patch(torch.nn.modules.transformer._detect_is_causal_mask)
+ @register_patch(torch.nn.modules.transformer._detect_is_causal_mask, disallow_in_graph=require_explicit_disallow)
def patched_detect_is_causal_mask(
- mask: Optional[Tensor],
- is_causal: Optional[bool] = None,
- size: Optional[int] = None,
+ mask: Optional[Tensor],
+ is_causal: Optional[bool] = None,
+ size: Optional[int] = None,
) -> bool:
"""Return whether the given attention mask is causal.
@@ -1273,7 +1337,7 @@ def patched_detect_is_causal_mask(
return make_causal
- @patch(torch.nn.modules.transformer._get_seq_len)
+ @register_patch(torch.nn.modules.transformer._get_seq_len, disallow_in_graph=require_explicit_disallow)
def patched_get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
# [SQZB] Accessing src.is_nested causes an error when calling the graph module.
# The error message: "target builtins.getattr has type str but a Callable is expected"
@@ -1301,7 +1365,7 @@ def patched_get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
torchvision is not None
and version.parse("0.15.0") <= version.parse(torchvision.__version__) < version.parse("0.16.0")
):
- @patch(torchvision.models.swin_transformer.shifted_window_attention)
+ @register_patch(torchvision.models.swin_transformer.shifted_window_attention, disallow_in_graph=require_explicit_disallow)
def patched_shifted_window_attention(
input: Tensor,
qkv_weight: Tensor,
@@ -1455,11 +1519,12 @@ def patched_shifted_window_attention(
x = x[:, :H, :W, :].contiguous()
return x
+
if diffusers is not None:
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.models.attention import BasicTransformerBlock
- @patch(AttnProcessor2_0)
+ @register_patch(AttnProcessor2_0, disallow_in_graph=require_explicit_disallow)
class PatchedAttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -1550,7 +1615,11 @@ def forward(
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
- torch._dynamo.disallow_in_graph(Attention)
- torch._dynamo.disallow_in_graph(BasicTransformerBlock)
+
+ force_dynamo_disallow_in_graph(Attention)
+ force_dynamo_disallow_in_graph(BasicTransformerBlock)
# fmt: on
+
+if not DISABLE_AUTO_PATCH:
+ PatchManager.apply_patches()
diff --git a/owlite/backend/onnx/signature.py b/src/owlite/backend/signature.py
similarity index 62%
rename from owlite/backend/onnx/signature.py
rename to src/owlite/backend/signature.py
index 9564850..cbe3d99 100644
--- a/owlite/backend/onnx/signature.py
+++ b/src/owlite/backend/signature.py
@@ -1,4 +1,5 @@
import inspect
+from copy import deepcopy
from typing import Any, Callable, Optional, Union
import onnx_graphsurgeon as gs
@@ -6,7 +7,9 @@
from onnx import ModelProto
from torch.fx.graph_module import GraphModule
-from ...options import DynamicAxisOptions, DynamicInputOptions
+from ..options import DynamicAxisOptions, DynamicInputOptions
+from ..owlite_core.logger import log
+from .utils import normalize_parameter_name
class DynamicSignature(list[tuple[str, tuple[Union[int, str, tuple[int, ...]], ...]]]):
@@ -81,20 +84,24 @@ def from_module(
Union["Signature", DynamicSignature]: A `Signature` object if `options` is `None`,
`DynamicSignature` object otherwise.
"""
- if isinstance(module, GraphModule) and isinstance(original_params := module.meta.get("original_params"), dict):
- modified_params = inspect.signature(module.forward).parameters
- signature_map = {
- k: v for k, v in map_signature(original_params, *args, **kwargs).items() if k in modified_params
- }
+ if isinstance(module, GraphModule):
+ signature_map = map_graph_module_signature(module, *args, **kwargs)
else:
signature_map = map_signature(module.forward, *args, **kwargs)
- signature = cls(
- (
- name,
- tuple(value.shape) if isinstance(value, torch.Tensor) else (),
- )
- for name, value in signature_map.items()
- )
+
+ signature = cls()
+ for name, value in signature_map.items():
+ if isinstance(value, torch.Tensor):
+ signature.append((name, value.shape))
+ continue
+ if isinstance(value, tuple):
+ signature.extend((f"{name}_{i}", t.shape) for i, t in enumerate(value) if isinstance(t, torch.Tensor))
+ continue
+ if isinstance(value, dict):
+ signature.extend((k, t.shape) for k, t in value.items() if isinstance(t, torch.Tensor))
+ continue
+ signature.append((name, ()))
+
if options is not None:
return dynamize_signature(signature, options)
return signature
@@ -171,13 +178,14 @@ def map_signature(
params = (
dict(inspect.signature(func_or_its_params).parameters.items())
if callable(func_or_its_params)
- else func_or_its_params
+ else deepcopy(func_or_its_params)
)
+ names = list(params)
var_pos: Optional[tuple[int, str]] = None
var_key: Optional[tuple[int, str]] = None
- # mapped: dict[str, Any] = {name: inspect._empty for name in params}
mapped: dict[str, Any] = {}
+
for i, (name, param) in enumerate(params.items()):
if param.kind == inspect._ParameterKind.VAR_POSITIONAL:
var_pos = (i, name)
@@ -186,25 +194,82 @@ def map_signature(
var_key = (i, name)
mapped[name] = {}
- for name, val in kwargs.items():
- if name in params:
- mapped[name] = val
- params.pop(name)
- continue
- if var_key is not None:
- var_key_name = var_key[1]
- mapped[var_key_name][name] = val
- params.pop(var_key_name)
+ for i, val in enumerate(args):
+ if var_pos is not None and i >= var_pos[0]:
+ mapped[var_pos[1]] += (val,)
continue
+ mapped[names[i]] = val
- names = list(params)
- for i, val in enumerate(args):
- if i < len(names):
- name = names[i]
- mapped[name] = val
+ for name, val in kwargs.items():
+ if var_key is not None and name not in names:
+ mapped[var_key[1]][name] = val
continue
- if var_pos is not None and i >= var_pos[0]:
- var_pos_name = var_pos[1]
- mapped[var_pos_name] += (val,)
+ mapped[name] = val
return mapped
+
+
+def map_graph_module_signature(
+ module: GraphModule,
+ *args: Any,
+ **kwargs: Any,
+) -> dict[str, Any]:
+ """Maps the args and kwargs to the parameters of the forward method of a graph module
+ generated by `owlite.fx.symbolic_trace`. If the graph module doesn't have meta data 'original_params',
+ automatically falls back to `map_signature`.
+
+ Args:
+ module (GraphModule): a graph module generated by `owlite.fx.symbolic_trace`
+ args (Any): Positional arguments.
+ kwargs (Any): Keyword arguments.
+
+ Returns:
+ dict[str, Any]: the mapped signatures
+ """
+ if (original_params := module.meta.get("original_params", None)) is None:
+ log.debug_warning("This graph module has no meta data 'original_params'")
+ return map_signature(module.forward, *args, **kwargs)
+ modified_params = {normalize_parameter_name(k): v for k, v in inspect.signature(module.forward).parameters.items()}
+ mapped_signature = map_signature(original_params, *args, **kwargs)
+ signature_map: dict[str, Any] = {}
+ for p, (k, v) in enumerate(mapped_signature.items()):
+ if k in modified_params:
+ signature_map[k] = v
+ continue
+ if isinstance(v, tuple):
+ # variadic positional arguments are flattened by torch.compile
+ # e.g. `def forward(self, *args, x)` -> `def forward(self, args_0, args_1, x)`
+ # or `def forward(self, args_0_, args_1_, x)`
+ # when two arguments are provided by the user, depending on torch version and the host OS
+ success = True
+ for i, x in enumerate(v):
+ for name in (f"{k}_{i}", f"{k}_{i}_"):
+ if name in modified_params:
+ signature_map[name] = x
+ break
+ else:
+ success = False
+ log.debug_warning(
+ f"Failed to map {i}-th variadic positional argument {k} of the graph module's forward method"
+ )
+ if success:
+ continue
+ if isinstance(v, dict):
+ # variadic keyword arguments are flattened by torch.compile
+ # e.g. `def forward(self, x, **kwargs)` -> `def forward(self, x, y, z)`
+ # when the model was called as `output = model(a, y=b, z=c)`
+ for name, x in v.items():
+ if name in modified_params:
+ signature_map[name] = x
+ else:
+ log.debug_warning(
+ f"Failed to map the variadic positional argument {k} with key {name} "
+ "of the graph module's forward method"
+ )
+ continue
+ if any(arg_name in modified_params for arg_name in (f"args_{p}", f"args_{p}_")):
+ # Rarely, arguments can be squashed as a variadic position argument `args`
+ signature_map[k] = v
+ continue
+ log.debug_warning(f"Failed to map signature of {p}-th parameter {k} of the graph module's forward method")
+ return signature_map
diff --git a/owlite/backend/utils.py b/src/owlite/backend/utils.py
similarity index 90%
rename from owlite/backend/utils.py
rename to src/owlite/backend/utils.py
index 5f0f6b6..2ebdad1 100644
--- a/owlite/backend/utils.py
+++ b/src/owlite/backend/utils.py
@@ -10,12 +10,30 @@
import torch
from onnx import ModelProto, TensorProto
from onnx import NodeProto as ONNXNode
-from onnx_graphsurgeon.importers.onnx_importer import get_numpy_type
+
+try:
+ from onnx_graphsurgeon.importers.onnx_importer import get_numpy_type
+except ImportError:
+ import onnx.mapping
+
+ # pylint: disable-next=missing-function-docstring
+ def get_numpy_type(onnx_type: Union[int, "TensorProto.DataType", np.dtype]) -> Optional[np.dtype]:
+ if not isinstance(onnx_type, int):
+ # Already a NumPy type
+ return onnx_type
+
+ # For some reason, TENSOR_TYPE_TO_NP_TYPE maps `bfloat16` to `float32`.
+ # This obviously breaks things, so we need to treat this as a special case.
+ if onnx_type != onnx.TensorProto.BFLOAT16 and onnx_type in onnx.mapping.TENSOR_TYPE_TO_NP_TYPE:
+ return onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_type]
+ return None
+
+
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node as FXNode
from torch.fx.node import Target as FXTarget
-from owlite_core.logger import log
+from ..owlite_core.logger import log
AnyNode = Union[FXNode, ONNXNode, gs.Node]
@@ -106,11 +124,11 @@ def targetstr(target: FXTarget) -> str:
return f"{target}"
-def typestr(tensor: torch.Tensor) -> str:
- """Generates the MLIR-like string representation of the type of a torch.Tensor instance.
+def typestr(tensor: Union[torch.Tensor, np.ndarray]) -> str:
+ """Generates the MLIR-like string representation of the type of a torch.Tensor or np.ndarray instance.
Args:
- tensor (torch.Tensor): a tensor
+ tensor (Union[torch.Tensor, np.ndarray]): a tensor or ndarray
Returns:
str: the string representation of the type of the tensor
@@ -371,10 +389,10 @@ def convert_to_fp_ndarray(x: Union[Number, np.ndarray, torch.Tensor]) -> np.ndar
elif isinstance(x, Number):
x = np.array(x)
- if not is_floating_point(x.dtype):
- x = x.astype(np.float32)
+ if not is_floating_point(x.dtype): # type: ignore
+ x = x.astype(np.float32) # type: ignore
- return x
+ return x # type: ignore
def is_onnx_proto_data_external(onnx_proto: ModelProto) -> bool:
@@ -388,3 +406,19 @@ def is_onnx_proto_data_external(onnx_proto: ModelProto) -> bool:
"""
external_count = [i.data_location for i in onnx_proto.graph.initializer].count(TensorProto.EXTERNAL)
return 2 * external_count >= len(onnx_proto.graph.initializer)
+
+
+def normalize_parameter_name(name: str) -> str:
+ """Normalizes the forward method's parameter names that can be possibly renamed by `torch.compile`
+
+ Args:
+ name (str): a possibly modified name
+
+ Returns:
+ str: the original name
+ """
+ if name.startswith("L_kwargs_") and name.endswith("_"):
+ return name[9:-1]
+ if name.startswith("L_") and name.endswith("_"):
+ return name[2:-1]
+ return name
diff --git a/owlite/calib/__init__.py b/src/owlite/calib/__init__.py
similarity index 100%
rename from owlite/calib/__init__.py
rename to src/owlite/calib/__init__.py
index 49727a9..7b06061 100644
--- a/owlite/calib/__init__.py
+++ b/src/owlite/calib/__init__.py
@@ -1,5 +1,5 @@
from .absmax_calibrator import AbsmaxCalibrator
+from .entropy_calibrator import EntropyCalibrator
from .minmax_calibrator import MinmaxCalibrator
from .mse_calibrator import MSECalibrator
from .percentile_calibrator import PercentileCalibrator
-from .entropy_calibrator import EntropyCalibrator
diff --git a/owlite/calib/absmax_calibrator.py b/src/owlite/calib/absmax_calibrator.py
similarity index 90%
rename from owlite/calib/absmax_calibrator.py
rename to src/owlite/calib/absmax_calibrator.py
index 711bbed..d9e1de4 100644
--- a/owlite/calib/absmax_calibrator.py
+++ b/src/owlite/calib/absmax_calibrator.py
@@ -3,8 +3,7 @@
import torch
from torch.utils.hooks import RemovableHandle
-from owlite_core.logger import log
-
+from ..owlite_core.logger import log
from .calibrator import Calibrator
if TYPE_CHECKING:
@@ -12,19 +11,20 @@
class AbsmaxCalibrator(Calibrator):
- """Absmax calibrator.
+ r"""Absmax Calibrator Class.
The AbsMaxCalibration calibrator stores the **maximum absolute value** encountered in the passed data,
utilizing this value as the quantization range. When the original data is represented by $$X$$, the `step_size`
is calculated as:
+
$$
- \text{step\\_size} = \frac{\\max_{x \\in X}(
- |x |)}{ \text{quant\\_max} - \text{quant\\_min}}
+ \text{step\_size}=\frac{\max_{x \\in X}(|x|)}{\text{quant\_max}-\text{quant\_min}}
$$
+
This approach eliminates clipping errors, but potentially introduces significant rounding errors.
Attributes:
- absmax (Optional[torch.Tensor]): absolute maximum value of data passing through the quantizer.
+ absmax (`torch.Tensor`, `optional`): absolute maximum value of data passing through the quantizer.
"""
def __init__(self, quantizer: "FakeQuantizer"):
@@ -32,7 +32,8 @@ def __init__(self, quantizer: "FakeQuantizer"):
self.absmax: Optional[torch.Tensor] = None
def prepare(self) -> RemovableHandle:
- # define forward hook function
+ """Defines forward hook function."""
+
def absmax_forward_hook_func(module: "FakeQuantizer", inputs: tuple[Any, ...], output: Any) -> Optional[Any]:
"""forward hook function to get absmax value"""
@@ -72,9 +73,9 @@ def absmax_forward_hook_func(module: "FakeQuantizer", inputs: tuple[Any, ...], o
return self.hook_handler
def update(self) -> None:
+ """Updates step_size using "`absmax`"."""
assert self.absmax is not None
assert self.quantizer.step_size.data.shape == self.absmax.shape
- # update step_size using "abs max"
if not self.check_calib_ready():
raise RuntimeError("Not all conditions for calibration were not met.")
assert isinstance(self.hook_handler, RemovableHandle)
diff --git a/owlite/calib/calibrator.py b/src/owlite/calib/calibrator.py
similarity index 74%
rename from owlite/calib/calibrator.py
rename to src/owlite/calib/calibrator.py
index 1d6370a..a8e0d3d 100644
--- a/owlite/calib/calibrator.py
+++ b/src/owlite/calib/calibrator.py
@@ -3,7 +3,7 @@
from torch.utils.hooks import RemovableHandle
-from owlite_core.logger import log
+from ..owlite_core.logger import log
if TYPE_CHECKING:
from ..nn import FakeQuantizer
@@ -15,9 +15,13 @@ class Calibrator(ABC):
Uses the forward hook to collect the data needed for calibration and update the quantizer's
step_size and zero_point.
+ In **OwLite**, calibrator classes collect the necessary data for calibration based on the data passing through
+ the `FakeQuantizer`. This process enables the determination of the `FakeQuantizer`'s `step_size` and `zero_point`.
+ Currently, **OwLite** only supports symmetric quantization, so `zero_point` is fixed to 0.
+
Attributes:
- hook_handler (Optional[torch.utils.hooks.RemovableHandle]): A hook handler.
- quantizer (FakeQuantizer): The `FakeQuantizer` to which the calibration will be applied.
+ hook_handler (`torch.utils.hooks.RemovableHandle`, `optional`): A hook handler.
+ quantizer (`FakeQuantizer`): The `FakeQuantizer` to which the calibration will be applied.
"""
def __init__(self, quantizer: "FakeQuantizer"):
diff --git a/owlite/calib/entropy_calibrator.py b/src/owlite/calib/entropy_calibrator.py
similarity index 93%
rename from owlite/calib/entropy_calibrator.py
rename to src/owlite/calib/entropy_calibrator.py
index 371e955..ff55dd6 100644
--- a/owlite/calib/entropy_calibrator.py
+++ b/src/owlite/calib/entropy_calibrator.py
@@ -10,17 +10,19 @@ class EntropyCalibrator(HistogramCalibrator):
r"""Entropy Calibrator Class
The EntropyCalibrator compares the distribution of original data and
the distribution of quantized data using KL divergence.
- When the original data $X$ and the quantized data $X_{quant}$ is given,
- the $step\_size$ is calculated as follow:
+ When the original data $$ X $$ and the quantized data $$ X_{quant} $$ is given,
+ the $$step\_size$$ is calculated as follow:
+
$$
step\_size = \underset {step\_size}{\operatorname{argmax}} \,
KL \left( X || X_{quant} \right)
$$
+
This approach minimizes the divergence between two distributions.
"""
def update(self) -> None:
- # update step_size using "entropy"
+ """Updates step_size using "`entropy`"."""
super().update()
for chn, _ in enumerate(self.histc_bins):
diff --git a/owlite/calib/histogram_calibrator.py b/src/owlite/calib/histogram_calibrator.py
similarity index 93%
rename from owlite/calib/histogram_calibrator.py
rename to src/owlite/calib/histogram_calibrator.py
index a4d5bcf..82fa0c0 100644
--- a/owlite/calib/histogram_calibrator.py
+++ b/src/owlite/calib/histogram_calibrator.py
@@ -4,8 +4,7 @@
import torch
from torch.utils.hooks import RemovableHandle
-from owlite_core.logger import log
-
+from ..owlite_core.logger import log
from .calibrator import Calibrator
if TYPE_CHECKING:
@@ -13,12 +12,12 @@
class HistogramCalibrator(Calibrator, ABC):
- """Histogram calibrator.
+ """Histogram Calibrator Class.
Attributes:
- histogram(list[torch.Tensor]): list of histogram counts. Each element defaults to [0, ..., 0], len = 2048.
- bin_edges(list[torch.Tensor]): histogram edges. Each element defaults to [0, ..., 0], len = 2048.
- histc_bins(list[torch.Tensor]): number of histogram bins. Each element defaults to 2048.
+ histogram(`list[torch.Tensor]`): list of histogram counts. Each element defaults to [0, ..., 0], len = 2048.
+ bin_edges(`list[torch.Tensor]`): histogram edges. Each element defaults to [0, ..., 0], len = 2048.
+ histc_bins(`list[torch.Tensor]`): number of histogram bins. Each element defaults to 2048.
"""
def __init__(self, quantizer: "FakeQuantizer"):
@@ -29,7 +28,8 @@ def __init__(self, quantizer: "FakeQuantizer"):
self.histc_bins: list[torch.Tensor] = []
def prepare(self) -> RemovableHandle:
- # define forward hook function
+ """Defines forward hook function."""
+
def histogram_forward_hook_func(module: "FakeQuantizer", inputs: tuple[Any, ...], output: Any) -> Optional[Any]:
"""Forward hook function to get histogram value"""
diff --git a/owlite/calib/minmax_calibrator.py b/src/owlite/calib/minmax_calibrator.py
similarity index 93%
rename from owlite/calib/minmax_calibrator.py
rename to src/owlite/calib/minmax_calibrator.py
index 9fe7432..3007da4 100644
--- a/owlite/calib/minmax_calibrator.py
+++ b/src/owlite/calib/minmax_calibrator.py
@@ -3,8 +3,7 @@
import torch
from torch.utils.hooks import RemovableHandle
-from owlite_core.logger import log
-
+from ..owlite_core.logger import log
from .calibrator import Calibrator
if TYPE_CHECKING:
@@ -12,14 +11,14 @@
class MinmaxCalibrator(Calibrator):
- """Minmax calibrator
+ """Minmax Calibrator Class.
Minmax calibration to set step_size and zero_point using min-max during calibration
- with asymmetric quantization
+ with asymmetric quantization.
Attributes:
- max_value (Optional[torch.Tensor]): maximum value of data passing through the quantizer.
- min_value (Optional[torch.Tensor]): minimum value of data passing through the quantizer.
+ max_value (`torch.Tensor`, `optional`): maximum value of data passing through the quantizer.
+ min_value (`torch.Tensor`, `optional`): minimum value of data passing through the quantizer.
"""
def __init__(self, quantizer: "FakeQuantizer"):
diff --git a/owlite/calib/mse_calibrator.py b/src/owlite/calib/mse_calibrator.py
similarity index 69%
rename from owlite/calib/mse_calibrator.py
rename to src/owlite/calib/mse_calibrator.py
index a8265c5..1429e21 100644
--- a/owlite/calib/mse_calibrator.py
+++ b/src/owlite/calib/mse_calibrator.py
@@ -4,20 +4,23 @@
class MSECalibrator(HistogramCalibrator):
- r"""MSE Calibrator Class
- The MSECalibrator solves the $step\_size$ that minimizeds the mean squared error (MSE)
+ r"""MSE Calibrator Class.
+
+ The MSECalibrator solves the $$ step\_size $$ that minimizeds the mean squared error (MSE)
between the original data and its quantized representation.
- When the original data $X$ and the quantized representation $X_{quant}$ is given,
- the optimal $step\_size$ is calculated as follow:
+ When the original data $$ X $$ and the quantized representation $$ X_{quant} $$ is given,
+ the optimal $$ step\_size $$ is calculated as follow:
+
$$
step\_size = \underset {step\_size}{\operatorname{argmax}} \,
{\left\| X - X_{quant} \right\|}^2_2
$$
+
This approach minimizes the mean squared error between two data.
"""
def update(self) -> None:
- # update step_size using "mse"
+ """Updates step_size using "`mse`"."""
super().update()
for chn, _ in enumerate(self.histc_bins):
@@ -30,18 +33,13 @@ def update(self) -> None:
min_mse = np.inf
last_argmin = stop
-
+
for max_idx in range(512, stop + 1, 24):
in_distribution = np.arange(0.5, max_idx)[valid[:max_idx]]
in_distribution_error = (
- in_distribution
- - (nbins * in_distribution / (max_idx-0.5)).round()
- * (max_idx-0.5)
- / nbins
- ) ** 2
- out_distribution_error = (
- np.arange(1, stop-max_idx+1)[valid[max_idx:]]
+ in_distribution - (nbins * in_distribution / (max_idx - 0.5)).round() * (max_idx - 0.5) / nbins
) ** 2
+ out_distribution_error = (np.arange(1, stop - max_idx + 1)[valid[max_idx:]]) ** 2
mse = (np.append(in_distribution_error, out_distribution_error) * valid_bins).sum()
diff --git a/owlite/calib/percentile_calibrator.py b/src/owlite/calib/percentile_calibrator.py
similarity index 73%
rename from owlite/calib/percentile_calibrator.py
rename to src/owlite/calib/percentile_calibrator.py
index de08af5..6ac0285 100644
--- a/owlite/calib/percentile_calibrator.py
+++ b/src/owlite/calib/percentile_calibrator.py
@@ -10,11 +10,16 @@
class PercentileCalibrator(HistogramCalibrator):
- """Percentile Calibrator Class
+ """Percentile Calibrator Class.
+
+ This calibrator also utilizes the data's histogram. However, instead of minimizing an error metric, it employs a
+ heuristic approach based on a pre-specified percentile. The value corresponding to the chosen percentile is set as
+ the **maximum absolute value**, and the `step_size` is calculated accordingly. By tuning percentile, user can
+ control trade-off between quantization accuracy and outlier removal.
Attributes:
- quantizer (FakeQuantizer): The `FakeQuantizer` module to be calibrated.
- percentile (float): The desired percentile value, ranging from 0 to 100.
+ quantizer (`FakeQuantizer`): The `FakeQuantizer` module to be calibrated.
+ percentile (`float`): The desired percentile value, ranging from 0 to 100.
"""
@@ -33,7 +38,7 @@ def __init__(self, quantizer: "FakeQuantizer", percentile: float):
self.percentile = percentile
def update(self) -> None:
- # update step_size using "percentile"
+ """Updates step_size using "`percentile`"."""
super().update()
assert isinstance(self.hook_handler, RemovableHandle)
diff --git a/owlite/calibrators.py b/src/owlite/calibrators.py
similarity index 55%
rename from owlite/calibrators.py
rename to src/owlite/calibrators.py
index b11e6ee..0849aaf 100644
--- a/owlite/calibrators.py
+++ b/src/owlite/calibrators.py
@@ -3,18 +3,17 @@
from torch.nn.parallel import DataParallel, DistributedDataParallel
-from owlite_core.logger import log
-
from .backend.fx.types import GraphModuleOrDataParallel
from .enums import OwLiteStatus
from .nn import FakeQuantizer
+from .owlite_core.logger import log
-def prepare_for_calibration(model: GraphModuleOrDataParallel) -> None:
+def _prepare_for_calibration(model: GraphModuleOrDataParallel) -> None:
"""Create a calibrator and prepare calibration according to opt.
Args:
- model(GraphModule): graph module to calibrate.
+ model(`GraphModuleOrDataParallel`): graph module to calibrate.
"""
log.info("Preparing for calibration") # UX
for _, module in model.named_modules(remove_duplicate=True):
@@ -25,11 +24,11 @@ def prepare_for_calibration(model: GraphModuleOrDataParallel) -> None:
log.info("Calibrating the model") # UX
-def update_fake_quantizers(model: GraphModuleOrDataParallel) -> None:
+def _update_fake_quantizers(model: GraphModuleOrDataParallel) -> None:
"""Calculate step size and zero point using data of calibrator and enabling quantization
Args:
- model(GraphModuleOrDataParallel): model to calibrate.
+ model(`GraphModuleOrDataParallel`): model to calibrate.
"""
log.info("Updating fake quantizers based on collected data")
for name, module in model.named_modules(remove_duplicate=True):
@@ -61,7 +60,7 @@ def __init__(self, model: GraphModuleOrDataParallel):
self.model = model
def __enter__(self) -> GraphModuleOrDataParallel:
- prepare_for_calibration(self.model)
+ _prepare_for_calibration(self.model)
return self.model
def __exit__(
@@ -70,22 +69,50 @@ def __exit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
- update_fake_quantizers(self.model)
+ _update_fake_quantizers(self.model)
def calibrate(model: GraphModuleOrDataParallel) -> CalibrationContext:
"""Calibration is performed using the supplied data within a 'with' statement.
- Set the step_size and zero_point of the fake quantizers using the calibrator that the fake quantizers.
- with calibrate(model):
- ... # feed data to model and store information from it.
- ... # calculate fake quantizers step_sizes and zero_points
+ `owlite.calibrate` performs Post-Training Quantization (PTQ) calibration on a model converted with the
+ `OwLite.convert`. It is required to preserve the model's accuracy by carefully selecting the quantization
+ hyperparameters (the scale and zero-point). PTQ calibration typically requires only a subset of the training data.
+ Please review the
+ [Calibrator](https://squeezebits.gitbook.io/owlite/python-api/owlite.calibrators/owlite.calib.calibrator)
+ for technical details.
Args:
- model: GraphModule or DataParallel model to calibrate.
+ model(`GraphModuleOrDataParallel`): GraphModule or DataParallel model to calibrate.
Returns:
CalibrationContext
+
+ ### Usage
+
+ `owlite.calibrate` returns an `owlite.CalibratorContext` object from the OwLite library can be used with a `with`
+ statement to perform calibration. The `CalibratorContext` prepares the model for calibration and updates
+ the model's fake quantizers after calibration is complete.
+
+ **Example**
+
+ ```python
+ with owlite.calibrate(model):
+ for i, data in enumerate(train_loader):
+ model(*data) # feed data to model and store information from it.
+ # calculate fake quantizers step_sizes and zero_points
+
+ # You should use the `model` outside of the block after the calibration
+ torch.save(model.state_dict())
+ ```
+
+ In this example, the `owlite.calibrate` creates an `owlite.CalibratorContext`,
+ referenced by the variable `calibrator`. The training data fetched from `train_loader`
+ are then passed to the `calibrator` to perform calibration.
+
+ Note that you should continue writing your code outside of the `with` block since the fake quantizers
+ in the model are updated as the `with` block exits.
+
"""
return CalibrationContext(model)
diff --git a/owlite/compress.py b/src/owlite/compression.py
similarity index 76%
rename from owlite/compress.py
rename to src/owlite/compression.py
index 66d1f7b..ef4c6ed 100644
--- a/owlite/compress.py
+++ b/src/owlite/compression.py
@@ -1,14 +1,30 @@
+r"""Quantization is a powerful technique used to reduce the storage and computational requirements of deep learning
+models. However, this reduction in precision can potentially hurt model accuracy. Calibration is a crucial step in
+quantization that helps mitigate this accuracy loss.
+
+Calibration involves measuring the distributions of the activations in the model and using this information
+to determine the optimal quantization parameters. This process involves:
+
+1. Collecting data: A representative dataset, called the **calibration dataset**, is used to evaluate
+the trained floating-point model.
+
+2. Analyzing data: Statistics about the activation or weight distributions are collected.
+Understanding how the data is spread across different values within each layer.
+
+3. Selecting quantization parameters: These parameters, such as the quantization step\_size and zero\_point,
+are determined using one of several optimization objectives.
+The goal is to find the best balance between minimizing quantization error and preserving model accuracy.
+"""
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
-from owlite_core.logger import log
-
from .backend.fx.node import find_constant_nodes
from .backend.fx.node_configurator import NodeConfigurator
from .backend.fx.transforms import fuse_linear_bn_with_quantized_bias
from .enums import OwLiteStatus
from .nn import FakeQuantizer, enable_quantizers
from .options.compression_option import CompressionOptions, FakeQuantizerConfig
+from .owlite_core.logger import log
def compress(model: GraphModule, option: CompressionOptions) -> GraphModule:
diff --git a/owlite/enums/__init__.py b/src/owlite/enums/__init__.py
similarity index 100%
rename from owlite/enums/__init__.py
rename to src/owlite/enums/__init__.py
diff --git a/owlite/enums/dtype.py b/src/owlite/enums/dtype.py
similarity index 100%
rename from owlite/enums/dtype.py
rename to src/owlite/enums/dtype.py
diff --git a/owlite/enums/option_key_type.py b/src/owlite/enums/option_key_type.py
similarity index 100%
rename from owlite/enums/option_key_type.py
rename to src/owlite/enums/option_key_type.py
diff --git a/owlite/enums/owlite_status.py b/src/owlite/enums/owlite_status.py
similarity index 100%
rename from owlite/enums/owlite_status.py
rename to src/owlite/enums/owlite_status.py
diff --git a/owlite/enums/param_status.py b/src/owlite/enums/param_status.py
similarity index 100%
rename from owlite/enums/param_status.py
rename to src/owlite/enums/param_status.py
diff --git a/owlite/enums/ptq_calibration_type.py b/src/owlite/enums/ptq_calibration_type.py
similarity index 100%
rename from owlite/enums/ptq_calibration_type.py
rename to src/owlite/enums/ptq_calibration_type.py
diff --git a/owlite/enums/qat_backward_type.py b/src/owlite/enums/qat_backward_type.py
similarity index 100%
rename from owlite/enums/qat_backward_type.py
rename to src/owlite/enums/qat_backward_type.py
diff --git a/owlite/nn/__init__.py b/src/owlite/nn/__init__.py
similarity index 100%
rename from owlite/nn/__init__.py
rename to src/owlite/nn/__init__.py
diff --git a/src/owlite/nn/functions/__init__.py b/src/owlite/nn/functions/__init__.py
new file mode 100644
index 0000000..e2186f0
--- /dev/null
+++ b/src/owlite/nn/functions/__init__.py
@@ -0,0 +1,50 @@
+"""Quantization Aware Training (QAT) is a technique that allows the model to learn the quantization error during the
+training process. QAT aims to minimize the loss of accuracy during the quantization process, thus making the model
+smaller and faster while maintaining as much of its accuracy as possible. OwLite makes QAT easier, requiring only
+minimal changes to an existing training code.
+
+Please review the subdocuments for technical details.
+
+## Usage
+
+To use QAT with OwLite, you can follow your standard training procedure, keeping in mind two aspects:
+
+* QAT is a process that needs to be performed after the
+[convert](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.convert) stage, where
+you have applied the compression configuration in experiment mode using OwLite.
+* If the optimizer for training was set before calling the convert method, you should set the optimizer again with
+the new parameter of the converted mode
+
+Please note that the model converted by OwLite has a fixed batch size. Therefore, you need to set `drop_last=True`
+when creating your [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
+object.
+
+For example:
+
+```python
+DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
+ batch_sampler=None, num_workers=0, collate_fn=None,
+ pin_memory=False, drop_last=True, timeout=0,
+ worker_init_fn=None, *, prefetch_factor=2,
+ persistent_workers=False)
+```
+
+This ensures that the DataLoader will discard the last remaining batch if the dataset size is not divisible
+by the batch size.
+
+## Tips for Better Results
+
+If you are getting unsatisfactory results from your training, consider adjusting the learning rate or the weight decay.
+Lowering the learning rate can help the model converge more smoothly while reducing the weight decay can help prevent
+the model from over-fitting.
+
+* **Adjust the Learning Rate**: If the training loss fluctuates, consider reducing the learning rate to stabilize
+the training of the compressed model. In this way, the model learns more effectively, leading to better performance.
+
+* **Reduce Weight Decay**: Similarly, if the learning process is fluctuating, consider reducing the weight decay
+to stabilize the training of the compressed model. In this way, the model generalizes better for unseen data.
+"""
+from .clq import clq_function
+from .clq_plus import clq_plus_function
+from .fake_quantize import FakeQuantizeSignature, fake_quantize
+from .ste import ste_function
diff --git a/owlite/nn/functions/clq.py b/src/owlite/nn/functions/clq.py
similarity index 66%
rename from owlite/nn/functions/clq.py
rename to src/owlite/nn/functions/clq.py
index 8472f4a..6a7a588 100644
--- a/owlite/nn/functions/clq.py
+++ b/src/owlite/nn/functions/clq.py
@@ -11,7 +11,21 @@
# mypy: disable-error-code=override
# pylint: disable-next=abstract-method
class CLQFunction(Function):
- """An implementation of QAT function using CLQ (Constrained Learned Quantization)"""
+ r"""An implementation of QAT function using CLQ (Constrained Learned Quantization)
+ In **CLQ(Constrained Learned Quantization)** method, instead of using a fixed set of quantization levels,
+ this method adapts the scales during training to minimize the impact on model performance. Learnable step_size
+ allows the model to be better adapted to the distribution of fed data.
+ ### Gradient of step\_size
+
+ When $$x$$ is input of $$FakeQuantize$$ and $$s$$ is step\_size of $$FakeQuantize$$
+
+ $$
+ \dfrac{\partial \hat{x}}{\partial s}= \begin{cases} \left( -\dfrac{x}{|s|}+\left\lceil{\dfrac{x}{|s|}}
+ \right\rfloor \right) \cdot \text{sign}(s) & \text{if, } \text{quant\_min} < \dfrac{x}{|s|} < \text{qant\_max}
+ \\ \\ \text{quant\_min} \cdot \text{sign}(s) &\text{if, }\dfrac{x}{|s|}\leq \text{quant\_min} \\
+ \\ \text{quant\_max}\cdot \text{sign}(s) &\text{if, } \dfrac{x}{|s|}\geq \text{quant\_max} \end{cases}
+ $$
+ """
@staticmethod # pylint: disable-next=arguments-differ
def forward(
diff --git a/owlite/nn/functions/clq_plus.py b/src/owlite/nn/functions/clq_plus.py
similarity index 100%
rename from owlite/nn/functions/clq_plus.py
rename to src/owlite/nn/functions/clq_plus.py
diff --git a/src/owlite/nn/functions/fake_quantize.py b/src/owlite/nn/functions/fake_quantize.py
new file mode 100644
index 0000000..d889383
--- /dev/null
+++ b/src/owlite/nn/functions/fake_quantize.py
@@ -0,0 +1,84 @@
+from typing import Callable, Optional
+
+import torch
+from torch import Tensor
+
+
+def fake_quantize(
+ inputs: Tensor,
+ step_size: Tensor,
+ zero_point: Tensor,
+ quant_min: int,
+ quant_max: int,
+ axis: Optional[int] = None,
+) -> torch.Tensor:
+ r"""Same as `torch.fake_quantize_per_channel_affine` if `per_channel` is `True`, otherwise
+ `torch.fake_quantize_per_tensor_affine`
+
+ In OwLite, quantization is simulated through the following mathematical expression:
+
+ $$
+ \small
+
+ \text{FakeQuantize}(\text{input})= \text{clip} \left( {\lfloor \frac{\text{input} - \text{zero\_point}}{\text
+ {step\_size}} \rceil }, \text{quant\_min}, \text{quant\_max} \right) \cdot \text{step\_size} + \text{zero\_point}
+
+ $$
+
+ The primary objective of exporting to the Open Neural Network Exchange (ONNX) format is to facilitate deployment
+ on TensorRT rather than the ONNX runtime. Consequently, the export process is confined to transforming the model
+ into a format compatible with TensorRT, specifically one that supports fake quantization.
+ The incorporation of fake quantization involves the decomposition of the model into `QuantizeLinear` and
+ `DequantizeLinear` operations within the ONNX specification. Subsequently, TensorRT is entrusted with the task
+ of ingesting the resultant ONNX graph and executing it in INT8 format, optimizing the process to the fullest extent
+ of its capabilities. For more information, see the [TensorRT Developer Guide's section on Explicit
+ Quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#qat-models-work).
+
+ Args:
+ inputs (`torch.Tensor`): A tensor to quantize.
+ step_size (`torch.Tensor`): The quantization scale, determining the magnitude of each quantization interval.
+ zero_point (`torch.Tensor`): The quantization zero\\_point. It may be expressed as a float in the context of
+ asymmetric quantization, while for symmetric quantization, it is fixed at 0.
+ quant_min (`int`): The lower bound of the quantized domain, specified as an integer.
+ quant_max (`int`): The upper bound of the quantized domain in as an integer.
+ axis (`int`, optional): Channel axis. Only used when `per_channel` is `True`. Defaults to 0.
+
+ Returns:
+ torch.Tensor: fake-quantized tensor
+
+ """
+ if axis is not None:
+ return torch.fake_quantize_per_channel_affine(
+ inputs,
+ step_size,
+ zero_point,
+ axis,
+ quant_min,
+ quant_max,
+ )
+
+ return torch.fake_quantize_per_tensor_affine(
+ inputs,
+ step_size,
+ # `torch.fake_quantize_per_tensor_affine` expects `zero_point` to be either int32 or int64
+ # (See https://pytorch.org/docs/stable/generated/torch.fake_quantize_per_tensor_affine.html)
+ # while `torch.fake_quantize_per_channel_affine` doesn't
+ zero_point,
+ quant_min=quant_min,
+ quant_max=quant_max,
+ )
+
+
+FakeQuantizeSignature = Callable[
+ [
+ Tensor, # inputs
+ Tensor, # step_size
+ Tensor, # zp
+ float, # grad_scale
+ int, # quant_min
+ int, # quant_max
+ Optional[int], # axis
+ bool, # compensate_zp
+ ],
+ Tensor,
+]
diff --git a/owlite/nn/functions/ste.py b/src/owlite/nn/functions/ste.py
similarity index 68%
rename from owlite/nn/functions/ste.py
rename to src/owlite/nn/functions/ste.py
index c50047a..caf51d1 100644
--- a/owlite/nn/functions/ste.py
+++ b/src/owlite/nn/functions/ste.py
@@ -11,10 +11,27 @@
# mypy: disable-error-code=override
# pylint: disable-next=abstract-method
class STEFunction(Function):
- """fake quantizing function for QAT using STE (Straight-Through Estimator)
+ r"""Fake quantizing function for QAT using STE (Straight-Through Estimator).
For quant_min <= input <= quant_max the gradient passes straight through,
otherwise the gradient is zero
+
+ In **STE(Straight Through Estimation)** method, the gradient of the round
+ function used in fake quantization is approximated as 1, and
+ backpropagation is performed based on this approximation. As a result,
+ the gradient of the input entering the fake quantizer is propagated as is
+ when it falls between $$ quant\_min $$ and $$ quant\_max $$, while gradients outside
+ this range become 0. However, since the gradient propagated to $$ step\_size $$
+ is 0, $$ step\_size $$ is fixed.
+
+ When $$x$$ is input of FakeQuantize .
+
+ $$
+ \\hat{x} = \text{FakeQuantize}(x) \\
+ $$
+
+
+ ![STE image](https://github.com/SqueezeBits/owlite/assets/116608095/2d0e071b-394c-4cd1-a68e-33b9a6e18ae6)
"""
@staticmethod # pylint: disable-next=arguments-differ
diff --git a/owlite/nn/modules/__init__.py b/src/owlite/nn/modules/__init__.py
similarity index 100%
rename from owlite/nn/modules/__init__.py
rename to src/owlite/nn/modules/__init__.py
diff --git a/owlite/nn/modules/fake_quantizer.py b/src/owlite/nn/modules/fake_quantizer.py
similarity index 92%
rename from owlite/nn/modules/fake_quantizer.py
rename to src/owlite/nn/modules/fake_quantizer.py
index a0f6e8d..ad30132 100644
--- a/owlite/nn/modules/fake_quantizer.py
+++ b/src/owlite/nn/modules/fake_quantizer.py
@@ -7,13 +7,12 @@
import torch
from typing_extensions import Self
-from owlite_core.logger import log
-
from ...calib import PercentileCalibrator
from ...enums import PTQCalibrationType, QATBackwardType
from ...nn.functions import FakeQuantizeSignature, clq_function
from ...options.channel import Channel
from ...options.fake_quantizer_options import FakeQuantizerOptions
+from ...owlite_core.logger import log
if TYPE_CHECKING:
from ...calib.calibrator import Calibrator
@@ -45,7 +44,30 @@ def setter(instance: Any, value: Any) -> None:
# pylint: disable=too-many-instance-attributes
class FakeQuantizer(torch.nn.Module, ABC):
- """An implementation of fake quantization (a.k.a. quantization simulation)"""
+ """An implementation of fake quantization (a.k.a. quantization simulation)
+
+ ### Attributes
+ - __step_size__ (`torch.Tensor`): The quantization scale, determining the magnitude of each quantization interval.
+ - __zero_point__ (`torch.Tensor`) : The quantization zero_point. It may be expressed as a float in the context of
+ asymmetric quantization, while for symmetric quantization, it is fixed at 0.
+ - __precision__ (`torch.IntTensor`): The number of bits used for quantization.
+ - __symmetric__ (`torch.BoolTensor`): Whether symmetric quantization is applied.
+ - __unsigned__ (`torch.BoolTensor`): Whether unsigned quantization is applied
+ - __per_channel__ (`torch.BoolTensor`): Whether per-channel quantization or per-tensor quantization is applied
+ - __learn_zero_point__ (`torch.BoolTensor`): whether the zero point is learnable.
+ - __grad_scale__ (`torch.FloatTensor`): The gradient scaling factor of quantization parameters.
+ - __ptq_calibration__
+ - __qat_backward_type__
+
+ ### ReadOnly properties
+ - __qat_function__ (`Callable`): The autograd function providing forward and backward methods of this
+ fake quantizer for the quantization-aware training.
+ - __quant_min__ (`int`): lower bound of the quantized domain
+ - __quant_max__ (`int`): upper bound of the quantized domain
+ - __narrow__ (`bool`): whether to quantize with a narrow range.
+
+ e.g., if `True`, using quantization with range [-127, 127] instead of [-128, 127] when `precision=8`.
+ """
def __init__(
self,
diff --git a/owlite/nn/modules/qconv.py b/src/owlite/nn/modules/qconv.py
similarity index 82%
rename from owlite/nn/modules/qconv.py
rename to src/owlite/nn/modules/qconv.py
index 36749f2..461e430 100644
--- a/owlite/nn/modules/qconv.py
+++ b/src/owlite/nn/modules/qconv.py
@@ -12,6 +12,26 @@
# mypy: disable-error-code=misc
class _QConvNd(_ConvNd, UnaryNeuralQModuleMixin):
+ """Base class of quantized convolution layer inherited from [torch.nn.modules.conv._ConvNd]
+ (https://github.com/pytorch/pytorch/blob/4c55dc50355d5e923642c59ad2a23d6ad54711e7/torch/nn/modules/conv.py).
+ It performs convolution operations using the input and fake-quantized weights. Its weights and biases are copied
+ from the original convolution instance.
+
+ ### Attributes
+ - __input_quantizer__ (`FakeQuantizer`): fake quantizer used for the input.
+ - __weight_quantizer__ (`FakeQuantizer`): fake quantizer used for the weights.
+ # `class owlite.nn.Qconv1d`
+ Quantized 1D convolution module
+ # `class owlite.nn.Qconv2d`
+ Quantized 2D convolution module
+ # `class owlite.nn.Qconv3d`
+ Quantized 3D convolution module
+
+ Args:
+ _ConvNd (_type_): _description_
+ UnaryNeuralQModuleMixin (_type_): _description_
+ """
+
def __init__(
self,
conv: _ConvNd,
diff --git a/owlite/nn/modules/qlinear.py b/src/owlite/nn/modules/qlinear.py
similarity index 76%
rename from owlite/nn/modules/qlinear.py
rename to src/owlite/nn/modules/qlinear.py
index 1141233..dbb4c0a 100644
--- a/owlite/nn/modules/qlinear.py
+++ b/src/owlite/nn/modules/qlinear.py
@@ -12,8 +12,15 @@
# mypy: disable-error-code=misc
class QLinear(torch.nn.Linear, UnaryNeuralQModuleMixin):
- """Applies a linear transformation to the incoming data: :math:`y = xA_q^T + b`,
- where :math:`A_q` represents the fake-quantized weight.
+ """Applies a linear transformation to the incoming data: $$ y = xA_q^T + b $$,
+ where $$ A_q $$ represents the fake-quantized weight.
+
+ Additionally, fake-quantization is applicable to both the bias and bias addition:
+ $$y = \text{quant}(xW_q^T) + \text{quant}(b)$$, where represents $$\text{quant}$$ the fake-quantize function.
+ The module copies the weights and biases from the original linear instance.
+
+ Quantized linear layer inherited from
+ [torch.nn.Linear](https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py).
"""
def __init__(
@@ -24,8 +31,8 @@ def __init__(
"""Initializes instance from an existing `torch.nn.Linear` instance, copying the weights and bias if it exists.
Args:
- linear (torch.nn.Linear): An existing `torch.nn.Linear` instance.
- weight_opts (Optional[FakeQuantizerOptions], optional): Option for the fake weight quantizer.
+ linear (`torch.nn.Linear`): An existing `torch.nn.Linear` instance.
+ weight_opts (`Optional[FakeQuantizerOptions]`, optional): Option for the fake weight quantizer.
Defaults to None.
"""
super().__init__(
diff --git a/owlite/nn/modules/qmodule_mixins.py b/src/owlite/nn/modules/qmodule_mixins.py
similarity index 99%
rename from owlite/nn/modules/qmodule_mixins.py
rename to src/owlite/nn/modules/qmodule_mixins.py
index 80c6acd..de30042 100644
--- a/owlite/nn/modules/qmodule_mixins.py
+++ b/src/owlite/nn/modules/qmodule_mixins.py
@@ -2,8 +2,7 @@
import torch
-from owlite_core.logger import log
-
+from ...owlite_core.logger import log
from .fake_quantizer import FakeQuantizer
diff --git a/owlite/options/__init__.py b/src/owlite/options/__init__.py
similarity index 100%
rename from owlite/options/__init__.py
rename to src/owlite/options/__init__.py
diff --git a/owlite/options/channel.py b/src/owlite/options/channel.py
similarity index 100%
rename from owlite/options/channel.py
rename to src/owlite/options/channel.py
diff --git a/owlite/options/compression_option.py b/src/owlite/options/compression_option.py
similarity index 99%
rename from owlite/options/compression_option.py
rename to src/owlite/options/compression_option.py
index a621656..a8f4dd7 100644
--- a/owlite/options/compression_option.py
+++ b/src/owlite/options/compression_option.py
@@ -4,8 +4,7 @@
from typing_extensions import Self
-from owlite_core.logger import log
-
+from ..owlite_core.logger import log
from .channel import Channel
from .fake_quantizer_options import FakeQuantizerOptions
from .options_dict import OptionsDict
diff --git a/owlite/options/dynamic_input_options.py b/src/owlite/options/dynamic_input_options.py
similarity index 100%
rename from owlite/options/dynamic_input_options.py
rename to src/owlite/options/dynamic_input_options.py
diff --git a/owlite/options/fake_quantizer_options.py b/src/owlite/options/fake_quantizer_options.py
similarity index 98%
rename from owlite/options/fake_quantizer_options.py
rename to src/owlite/options/fake_quantizer_options.py
index 11b7545..aa0bc41 100644
--- a/owlite/options/fake_quantizer_options.py
+++ b/src/owlite/options/fake_quantizer_options.py
@@ -3,10 +3,9 @@
from typing_extensions import Self
-from owlite_core.logger import log, suppress_owlite_warnings
-
from ..enums.ptq_calibration_type import PTQCalibrationType
from ..enums.qat_backward_type import QATBackwardType
+from ..owlite_core.logger import log, suppress_owlite_warnings
from .options_mixin import OptionsMixin
diff --git a/owlite/options/generic_type_checking.py b/src/owlite/options/generic_type_checking.py
similarity index 98%
rename from owlite/options/generic_type_checking.py
rename to src/owlite/options/generic_type_checking.py
index 48ff658..10baee5 100644
--- a/owlite/options/generic_type_checking.py
+++ b/src/owlite/options/generic_type_checking.py
@@ -22,7 +22,7 @@ def generic_isinstance(obj: Any, type_hint: Union[type, tuple[type]]) -> bool:
raise NotImplementedError(f"generic_isinstance for {type_hint} is not implemented.")
-def generic_issubclass(type_hint: type, superclass: Union[type, tuple[type]]):
+def generic_issubclass(type_hint: type, superclass: Union[type, tuple[type]]) -> bool:
"""An extension for the builtin function `issubclass` for type hint checking."""
if isinstance(superclass, tuple):
return any(generic_issubclass(type_hint, s) for s in superclass)
diff --git a/owlite/options/load.py b/src/owlite/options/load.py
similarity index 100%
rename from owlite/options/load.py
rename to src/owlite/options/load.py
diff --git a/owlite/options/onnx_export_options.py b/src/owlite/options/onnx_export_options.py
similarity index 100%
rename from owlite/options/onnx_export_options.py
rename to src/owlite/options/onnx_export_options.py
diff --git a/owlite/options/options_dict.py b/src/owlite/options/options_dict.py
similarity index 99%
rename from owlite/options/options_dict.py
rename to src/owlite/options/options_dict.py
index 13a2fa6..7365f56 100644
--- a/owlite/options/options_dict.py
+++ b/src/owlite/options/options_dict.py
@@ -4,8 +4,7 @@
from yacs.config import CfgNode
-from owlite_core.logger import log
-
+from ..owlite_core.logger import log
from .generic_type_checking import generic_isinstance
from .load import check_version, load_json_or_yaml
from .options_mixin import OptionsMixin
diff --git a/owlite/options/options_mixin.py b/src/owlite/options/options_mixin.py
similarity index 97%
rename from owlite/options/options_mixin.py
rename to src/owlite/options/options_mixin.py
index d5016d3..1848043 100644
--- a/owlite/options/options_mixin.py
+++ b/src/owlite/options/options_mixin.py
@@ -6,8 +6,7 @@
from typing_extensions import Self
from yacs.config import CfgNode
-from owlite_core.logger import log
-
+from ..owlite_core.logger import log
from .generic_type_checking import generic_isinstance, is_optional, unwrap_optional
from .load import check_version, load_json_or_yaml
@@ -95,8 +94,9 @@ def __setattr__(self, name: str, value: Any) -> None:
if name not in annotations:
raise KeyError(f"No such property in {cls_name}: {name}")
field_type = annotations[name]
- value = self._deserialize(value, field_type)
- if not generic_isinstance(value, field_type):
+ if field_type is not Any:
+ value = self._deserialize(value, field_type)
+ if not (field_type is Any or generic_isinstance(value, field_type)):
raise ValueError(
f"Expected a value of type {field_type}, "
f"but received {value} of type {type(value)} for {name} in {cls_name}"
diff --git a/owlite/options/quantization_options.py b/src/owlite/options/quantization_options.py
similarity index 100%
rename from owlite/options/quantization_options.py
rename to src/owlite/options/quantization_options.py
diff --git a/owlite/options/tensor_type.py b/src/owlite/options/tensor_type.py
similarity index 100%
rename from owlite/options/tensor_type.py
rename to src/owlite/options/tensor_type.py
diff --git a/src/owlite/owlite.py b/src/owlite/owlite.py
new file mode 100644
index 0000000..5e3bfa9
--- /dev/null
+++ b/src/owlite/owlite.py
@@ -0,0 +1,1038 @@
+# pylint: disable=too-many-lines
+import json
+import os
+import re
+from dataclasses import dataclass, field
+from typing import Any, Optional, Union
+
+import torch
+from packaging.version import Version
+from torch.fx.graph_module import GraphModule
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from .api import Baseline, Experiment, Project
+from .backend.fx.trace import symbolic_trace
+from .backend.signature import DynamicSignature, update_dynamic_signature
+from .compression import compress
+from .options import DynamicAxisOptions, DynamicInputOptions, ONNXExportOptions
+from .owlite_core.constants import OWLITE_REPORT_URL, OWLITE_VERSION
+from .owlite_core.github_utils import get_latest_version_from_github
+from .owlite_core.logger import log
+from .owlite_core.owlite_settings import OWLITE_SETTINGS
+
+
+@dataclass
+class OwLite:
+ """Class handling OwLite project, baseline, and experiment configurations.
+
+ The OwLite class manages project, baseline, and experiment configurations within the OwLite system.
+ It allows users to create or load projects, set baselines, create or duplicate experiments, convert models,
+ and benchmark models against the specified configurations.
+ """
+
+ target: Union[Baseline, Experiment]
+ module_args: Optional[tuple[Any, ...]] = field(default=None)
+ module_kwargs: Optional[dict[str, Any]] = field(default=None)
+
+ def convert(
+ self,
+ model: torch.nn.Module,
+ *args: Any,
+ **kwargs: Any,
+ ) -> GraphModule:
+ """Converts your model into a `torch.fx.GraphModule` object using the example input(s) provided.
+
+ {% hint style="warning” %}
+ The example input(s) provided for `owl.convert` will also be used by
+ [`owl.export`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.export) for
+ the ONNX and TensorRT conversion afterward. Therefore, it is crucial to provide appropriate example input(s)
+ to ensure the correct behavior of your model.
+ {% endhint %}
+
+ Args:
+ model (`torch.nn.Module`): The model to be compressed. Note that it must be an instance of
+ `torch.nn.Module`, but not `torch.nn.DataParallel` or `torch.nn.DistributedDataParallel`.
+ See [troubleshooting - Models wrapped with `torch.nn.DataParallel` or
+ `torch.nn.parallel.DistributedDataParallel`](https://squeezebits.gitbook.io/owlite/troubleshooting/
+ troubleshooting#models-wrapped-with-torch.nn.dataparallel-or-torch.nn.parallel.distributeddataparallel)
+ for more details.
+ *args, **kwargs: the example input(s) that would be passed to the model’s forward method.
+ These example inputs are required to convert the model into a [`torch.fx.GraphModule`]
+ (https://pytorch.org/docs/stable/fx.html) instance. Each input must be one of the following:
+ * A `torch.Tensor` object
+ * A tuple of `torch.Tensor` objects
+ * A dictionary whose keys are strings and values are `torch.Tensor` objects.
+
+ Returns:
+ GraphModule: The `torch.fx.GraphModule` object converted from the `model`.
+
+ Raises:
+ HTTPError: When request for compression configuration was not successful.
+
+ ### Behavior in each mode
+
+ `owl.convert` behaves differently depending on the
+ [mode](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.init#two-different-modes-triggered-by-owlite.init)
+ triggered by [`owlite.init`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.init).
+
+ 1. **Baseline Mode**: In this mode, `owl.convert` traces the input model with the example input(s).
+
+ 2. **Experiment Mode**: In this mode, the converted `torch.fx.GraphModule` object will be further modified
+ according to the compression configuration from the experiment. This configuration could have been created by
+ the user on the OwLite website, or copied from another experiment (in 'duplicate from’ mode). If there’s no
+ compression configuration, it returns the same model as in baseline mode. For dynamic batch size baseline
+ model without compression, create an experiment.
+
+ ### Workflow
+
+ The `owl.convert` function goes through the following steps:
+
+ 1. **Tracing**: `owl.convert` traces the input model with the example input(s) to a GraphModule. If the model
+ cannot be traced, it throws an error with a message.
+
+ 2. **Compression**: If in experiment mode, `owl.convert` applies a compression configuration to the traced
+ model. `owl.convert` doesn’t compress the model if there’s no compression configuration created on the web.
+
+ 3. **Model Return**: `owl.convert` returns the converted model.
+
+ By following these steps, the `convert` function effectively converts the input model
+ to a compressed GraphModule.
+
+ ### Examples
+
+ **Baseline Mode**
+
+ ```python
+ import torch
+ import owlite
+
+ owl = owlite.init(project="testProject”, baseline="sampleModel”)
+
+ # Create a sample model
+ class SampleModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = torch.nn.Conv2d(3, 64, 3)
+ self.pool1 = torch.nn.MaxPool2d(2, 2)
+ self.conv2 = torch.nn.Conv2d(64, 128, 3)
+ self.pool2 = torch.nn.MaxPool2d(2, 2)
+ self.fc1 = torch.nn.Linear(128 * 7 * 7, 10)
+
+ # Create a model instance
+ model = SampleModel()
+
+ # Convert the model
+ model = owl.convert(model, torch.randn(4,3,64,64))
+
+ # Print the model
+ print(model)
+ ```
+
+ This code will create a sample model, convert it to a GraphModule in baseline mode, and export it to ONNX.
+ The output of the code is as follows:
+
+ ```bash
+ OwLite [INFO] Connected device: NVIDIA RTX A6000
+ OwLite [WARNING] Existing local directory found at /home/sqzb/workspace/owlite/testProject/sampleModel/sample
+ Model. Continuing this code will overwrite the data
+ OwLite [INFO] Created new project 'testProject’
+ OwLite [INFO] Created new baseline 'sampleModel’ at project 'testProject’
+ OwLite [INFO] Converted the model
+ GraphModule(
+ (self_conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
+ (self_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+ (self_conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
+ (self_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+ (self_fc1): Linear(in_features=6272, out_features=10, bias=True)
+ )
+
+
+ def forward(self, x : torch.Tensor):
+ sqzb_module_device_canary = self.sqzb_module_device_canary
+ getattr_1 = sqzb_module_device_canary.device; sqzb_module_device_canary = None
+ self_conv1 = self.self_conv1(x); x = None
+ relu = torch.nn.functional.relu(self_conv1); self_conv1 = None
+ self_pool1 = self.self_pool1(relu); relu = None
+ self_conv2 = self.self_conv2(self_pool1); self_pool1 = None
+ relu_1 = torch.nn.functional.relu(self_conv2); self_conv2 = None
+ self_pool2 = self.self_pool2(relu_1); relu_1 = None
+ view = self_pool2.view(-1, 6272); self_pool2 = None
+ self_fc1 = self.self_fc1(view); view = None
+ output_adapter = owlite_backend_fx_trace_output_adapter((self_fc1,)); self_fc1 = None
+ return output_adapter
+ ```
+
+ **Experiment Mode**
+
+ ```python
+ import torch
+ import owlite
+
+ owl = owlite.init(project="testProject”, baseline="sampleModel”, experiment="conv”)
+
+ # Create a sample model
+ class SampleModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = torch.nn.Conv2d(3, 64, 3)
+ self.pool1 = torch.nn.MaxPool2d(2, 2)
+ self.conv2 = torch.nn.Conv2d(64, 128, 3)
+ self.pool2 = torch.nn.MaxPool2d(2, 2)
+ self.fc1 = torch.nn.Linear(128 * 7 * 7, 10)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = torch.nn.functional.relu(x)
+ x = self.pool1(x)
+
+ x = self.conv2(x)
+ x = torch.nn.functional.relu(x)
+ x = self.pool2(x)
+
+ x = x.view(-1, 128 * 7 * 7)
+ x = self.fc1(x)
+
+ return x
+
+ # Create a model instance
+ model = SampleModel()
+
+ # Convert the model
+ model = owl.convert(model, torch.randn(4,3,64,64))
+
+ # Print the model
+ print(model)
+ ```
+
+ This code will create a sample model, convert it to a GraphModule in experiment mode, and apply the compression
+ configuration from the `init` function. The output of the code is as follows:
+
+ ```bash
+ OwLite [INFO] Connected device: NVIDIA RTX A6000
+ OwLite [INFO] Experiment data will be saved in /home/sqzb/workspace/owlite/testProject/sampleModel/conv
+ OwLite [INFO] Loaded existing project 'testProject’
+ OwLite [INFO] Existing compression configuration for 'conv’ found
+ OwLite [INFO] Model conversion initiated
+ OwLite [INFO] Compression configuration found for 'conv’
+ OwLite [INFO] Applying compression configuration
+ OwLite [INFO] Converted the model
+ GraphModule(
+ (self_conv1): QConv2d(
+ 3, 64, kernel_size=(3, 3), stride=(1, 1)
+ (weight_quantizer): FakeQuantizer(ste(precision: 8, per_channel, quant_min: -127, quant_max: 127,
+ is_enabled: True, calib: AbsmaxCalibrator))
+ (input_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ q zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ )
+ (self_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+ (self_conv2): QConv2d(
+ 64, 128, kernel_size=(3, 3), stride=(1, 1)
+ (weight_quantizer): FakeQuantizer(ste(precision: 8, per_channel, quant_min: -127, quant_max: 127,
+ is_enabled: True, calib: AbsmaxCalibrator))
+ (input_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ )
+ (self_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
+ (self_fc1): QLinear(
+ in_features=6272, out_features=10, bias=True
+ (weight_quantizer): FakeQuantizer(ste(precision: 8, per_channel, quant_min: -127, quant_max: 127,
+ is_enabled: True, calib: AbsmaxCalibrator))
+ (input_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ )
+ (self_conv1_0_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ (self_pool1_0_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ (self_conv2_0_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ (self_pool2_0_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ (self_fc1_0_quantizer): FakeQuantizer(ste(precision: 8, per_tensor, quant_min: -128, quant_max: 127,
+ zero_point: 0.0, is_zero_point_folded: False, is_enabled: True, calib: AbsmaxCalibrator))
+ )
+
+
+
+ def forward(self, x : torch.Tensor):
+ self_conv1_0_quantizer = self.self_conv1_0_quantizer(x); x = None
+ self_conv1 = self.self_conv1(self_conv1_0_quantizer); self_conv1_0_quantizer = None
+ relu = torch.nn.functional.relu(self_conv1); self_conv1 = None
+ self_pool1_0_quantizer = self.self_pool1_0_quantizer(relu); relu = None
+ self_pool1 = self.self_pool1(self_pool1_0_quantizer); self_pool1_0_quantizer = None
+ self_conv2_0_quantizer = self.self_conv2_0_quantizer(self_pool1); self_pool1 = None
+ self_conv2 = self.self_conv2(self_conv2_0_quantizer); self_conv2_0_quantizer = None
+ relu_1 = torch.nn.functional.relu(self_conv2); self_conv2 = None
+ self_pool2_0_quantizer = self.self_pool2_0_quantizer(relu_1); relu_1 = None
+ self_pool2 = self.self_pool2(self_pool2_0_quantizer); self_pool2_0_quantizer = None
+ view = self_pool2.view(-1, 6272); self_pool2 = None
+ self_fc1_0_quantizer = self.self_fc1_0_quantizer(view); view = None
+ self_fc1 = self.self_fc1(self_fc1_0_quantizer); self_fc1_0_quantizer = None
+ output_adapter = owlite_backend_fx_trace_output_adapter((self_fc1,)); self_fc1 = None
+ return output_adapter
+ ```
+ """
+
+ log.info("Model conversion initiated")
+ try:
+ model = symbolic_trace(model, *args, **kwargs)
+ except Exception as e: # pylint: disable=broad-exception-caught
+ log.error(
+ "Failed to convert the model. This means that\n"
+ "i) your model might have some codes that cannot be handled by `torch.compile`; or\n"
+ "ii) the inputs provided for the model are incompatible with your model's 'forward' method.\n"
+ "Check the full error message below and make changes accordingly. "
+ f"Should the problem persist, please report the issue at {OWLITE_REPORT_URL} for further assistance"
+ ) # UX
+ raise e
+
+ self.module_args = args
+ self.module_kwargs = kwargs
+
+ if isinstance(self.target, Experiment) and self.target.has_config:
+ model = compress(model, self.target.config)
+ log.info("Applied compression configuration") # UX
+
+ return model
+
+ def export(
+ self,
+ model: GraphModule,
+ onnx_export_options: Optional[ONNXExportOptions] = None,
+ dynamic_axis_options: Optional[Union[DynamicAxisOptions, dict[str, dict[str, int]]]] = None,
+ ) -> None:
+ """Exports the model converted by `owl.convert` to ONNX format.
+
+ {% hint style=“warning” %}
+ The ONNX model created by `owl.export` will also be used by
+ [`owl.benchmark`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.benchmark)
+ for the TensorRT conversion afterward. Therefore, it is crucial to provide an appropriate pre-trained or
+ calibrated model to ensure the correct behavior of your model.
+
+ Generally, you can export any model with `owl.export` whether it is trained or not.
+ However, keep in mind that some graph-level optimizations performed while building the TensorRT engine
+ depend on the values of your model’s weight.
+
+ For example, when you benchmark a quantized model without calibration, the `step_size` parameter of
+ the fake quantizers in the model would be all initialized to zeros. These zero `step_size` values can make
+ the behavior of the graph-level optimization different, leading to a different latency from a calibrated
+ model’s latency when you benchmark.
+
+ Therefore, we **strongly recommend**
+
+ 1. to export for benchmarking a pre-trained model in the baseline mode; and
+ 2. to perform either [PTQ calibration](https://squeezebits.gitbook.io/owlite/python-api/owlite.calibrators) or
+ [QAT](https://squeezebits.gitbook.io/owlite/python-api/owlite.nn.function) in experiment mode.
+ {% endhint %}
+
+ Args:
+ model (`torch.fx.GraphModule`): The model converted by `owl.convert`, but not `torch.nn.DataParallel`
+ or `torch.nn.DistributedDataParallel`.
+ See [troubleshooting - Models wrapped with `torch.nn.DataParallel` or
+ `torch.nn.parallel.DistributedDataParallel`](https://squeezebits.gitbook.io/owlite/troubleshooting/troubleshooting#models-wrapped-with-torch.nn.dataparallel-or-torch.nn.parallel.distributeddataparallel)
+ for more details.
+
+ onnx_export_options (`owlite.ONNXExportOptions`, `optional`): Additional options for exporting ONNX.
+
+ * OwLite exports your model into ONNX during the conversion using
+ [torch.onnx.export](https://pytorch.org/docs/stable/onnx_torchscript.html#torch.onnx.export)
+ behind the scenes. You can control some of the behaviors of `torch.onnx.export` by passing an
+ `owlite.ONNXExportOptions` object to the `onnx_export_options` argument of `owlite.export`.
+ Currently, you can only set `opset_version`, which defaults to 17. Other parameters of
+ `torch.onnx.export` might be added in the future.
+
+ dynamic_axis_options (`dict[str, dict[str, int]]`, `optional`): By default, the exported model will have the
+ shapes of all input tensors set to match exactly those given when calling convert. To specify the axis of
+ tensors as dynamic (i.e., known only at run-time), set `dynamic_export_options` to a dictionary with schema:
+
+ KEY (`str`): an input name.
+
+ VALUE (`dict[int, dict[str, int]]`): a single item dictionary whose key is dynamic dimension of input
+ and value is also a single item dictionary whose key is "axis" and value is axis to dynamic.
+
+ **Example: dynamic_axis_options**
+
+ ```python
+ import owlite
+
+ owl = owlite.init( ... )
+
+ class SumModule(torch.nn.Module):
+ def forward(self, x):
+ return torch.sum(x, dim=1)
+
+ model = owl.convert( ... )
+
+ ...
+
+ # set first(0-th) dimension of input x
+ owl.export(
+ model,
+ dynamic_axis_options={
+ "x": {"axis": 0},
+ },
+ )
+ ```
+
+ Raises:
+ TypeError: When the `model` is an instance of `torch.nn.DataParallel` or `torch.nn.DistributedDataParallel`.
+ RuntimeError: When `dynamic_axes` is set for baseline export.
+ ValueError: When invalid `dynamic_axes` is given.
+
+ ### Behavior in each mode
+
+ `owl.export` behaves differently depending on the
+ [mode](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.init#two-different-modes-triggered-by-owlite.init)
+ triggered by [`owlite.init`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.init).
+
+ 1. **Baseline Mode**: In this mode, `owl.export` traces the input model with the example input(s) and exports
+ it to ONNX. Then, it sends the ONNX graph and the model to the server. This allows users to view the model
+ graph on the web and apply compression.
+ 2. **Experiment Mode**: In this mode, `owl.export` exports the model after applying the compression
+ configuration from the experiment or dynamic export options.
+
+ ### Workflow
+
+ `owl.export` performs the following steps:
+
+ 1. **Exporting ONNX**: `owl.export` exports the input model to ONNX and saves it in your local workspace.
+ In experiment mode, dynamic axes are applied to the model if provided.
+ 2. **Uploading ONNX**: `owl.export` then uploads the ONNX (without weights) to the server.
+
+ ### Examples
+
+ **Baseline Mode**
+
+ ```python
+ # after model converted
+ model = owl.export(model)
+ ```
+
+ ```bash
+ OwLite [INFO] Model conversion initiated
+ ============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
+ verbose: False, log level: Level.ERROR
+ ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
+
+ OwLite [INFO] Saving exported ONNX proto at /home/sqzb/workspace/owlite/testProject/sampleModel/sampleModel/
+ testProject_sampleModel_sampleModel.onnx with external data testProject_sampleModel_sampleModel.bin
+ OwLite [INFO] Baseline ONNX saved at /home/sqzb/workspace/owlite/testProject/sampleModel/sampleModel/
+ testProject_sampleModel_sampleModel.onnx
+ OwLite [INFO] Uploaded the model excluding parameters
+ ```
+
+ **Experiment Mode with dynamic batch**
+
+ ```python
+ # after model converted
+ model = owl.export(
+ model,
+ dynamic_axis_options={
+ "x": {
+ "axis": 0
+ }
+ }
+ )
+ ```
+
+ ```bash
+ OwLite [INFO] Model conversion initiated
+ ============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
+ verbose: False, log level: Level.ERROR
+ ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
+
+ OwLite [WARNING] ONNX simplifier failed with error: Your model ir_version is higher than the checker's.
+ OwLite [INFO] Saving exported ONNX proto at /home/sqzb/workspace/owlite/testProject/sampleModel/dynamic/
+ testProject_sampleModel_dynamic.onnx with external data testProject_sampleModel_dynamic.bin
+ OwLite [INFO] Experiment ONNX saved at /home/sqzb/workspace/owlite/testProject/sampleModel/dynamic/
+ testProject_sampleModel_dynamic.onnx
+ OwLite [INFO] Uploading /home/sqzb/workspace/owlite/testProject/sampleModel/dynamic/
+ testProject_sampleModel_dynamic.onnx
+ 100%|████████████████████████████████████████████████████████████████████████████████| 2.11k/2.11k
+ [00:00<00:00, 113kiB/s]
+ OwLite [INFO] Uploading done
+ ```
+
+ OwLite will create ONNX graph file and parameter file with the hierarchical structure below.
+
+ ```bash
+ - owlite
+ - testProject
+ - SampleModel
+ - dynamic
+ - testProject_SampleModel_dynamic.onnx
+ - testProject_SampleModel_dynamic.bin
+ ```
+ """
+ if isinstance(model, (DataParallel, DistributedDataParallel)):
+ model_type = f"torch.nn.parallel.{type(model).__name__}"
+ log.error(
+ f"{model_type} is not supported. Please use the attribute module "
+ f"by unwrapping the model from {model_type}. Try owl.export(model.module)"
+ ) # UX
+ raise TypeError(f"{model_type} is not supported by export")
+ if not isinstance(model, GraphModule):
+ model_type = f"{type(model).__module__}.{type(model).__name__}"
+ raise TypeError(f"Expected GraphModule, but got model of type {model_type}")
+
+ if isinstance(dynamic_axis_options, dict):
+ dynamic_axis_options = DynamicAxisOptions(dynamic_axis_options)
+ keys_repr = ", ".join(f"'{key}'" for key in dynamic_axis_options.keys())
+ log.info(f"`dynamic_axis_options` provided for the following inputs: {keys_repr}") # UX
+
+ if isinstance(self.target, Baseline):
+ if dynamic_axis_options is not None:
+ log.warning(
+ "The `dynamic_axis_options` provided for baseline will be ignored. "
+ "To export baseline model with dynamic input, "
+ "please create an experiment without compression configuration "
+ "and export it with `dynamic_axis_options`"
+ ) # UX
+ proto = self.target.export(
+ model,
+ self.module_args,
+ self.module_kwargs,
+ onnx_export_options=onnx_export_options,
+ )
+ self.target.upload(proto, model)
+ else:
+ proto = self.target.export(
+ model,
+ self.module_args,
+ self.module_kwargs,
+ dynamic_axis_options=dynamic_axis_options,
+ onnx_export_options=onnx_export_options,
+ )
+ self.target.upload(
+ proto,
+ dynamic_axis_options=dynamic_axis_options,
+ )
+
+ def benchmark(
+ self,
+ dynamic_input_options: Optional[Union[DynamicInputOptions, dict[str, dict[str, int]]]] = None,
+ ) -> None:
+ """Executes the benchmark for the converted model on a connected device.
+
+ `owl.benchmark` uses the ONNX created by `owl.export`. The ONNX is sent to the connected device and converted
+ to a TensorRT engine, which is benchmarked behind the scenes. If the benchmark finishes successfully, the
+ benchmark summary will be displayed on the terminal. The converted engine file will also be downloaded into
+ the workspace. You can find more information about the benchmark results from the project page in
+ [OwLite Web UI](https://owlite.ai/project).
+
+ {% hint style="warning” %}
+ In general, any model generated by `owl.export` can be benchmarked with `owl.benchmark`, regardless of
+ whether it is trained or not. Additionally, the model to be benchmarked is already determined
+ when `owl.export` is executed.
+
+ To ensure accurate latency measurements, especially for quantized models, we strongly recommend using
+ a pre-trained or calibrated model before using owl.export.
+
+ For details on model preparation, please refer to the
+ [owl.export](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.export).
+ {% endhint %}
+
+ Args:
+ dynamic_input_options (`dict[str, dict[str, int]]`):By default, the exported model will have the shapes
+ of all input tensors set to exactly match those given when calling convert. To specify axes of tensors
+ as dynamic (i.e. known only at run-time), set `dynamic_benchmark_options` to a dictionary with schema:
+
+ KEY (`str`): an input name.
+
+ VALUE (`dict[str, int]`): a single item that is a dynamic range setting dictionary containing
+ `"min”`, `"opt”`, `"max”`, `"test”` dimension size settings.
+
+ **Example: dynamic_input_options**
+
+ ```python
+ import owlite
+
+ owl = owlite.init( ... )
+
+ class SumModule(torch.nn.Module):
+ def forward(self, x):
+ return torch.sum(x, dim=1)
+
+ model = owl.convert( ... )
+
+ ...
+
+ # set input x to be dynamic within the range of 1 ~ 8
+ # optimize for 4 and benchmark for 5
+ owl.benchmark(
+ model,
+ dynamic_input_options={
+ "x": {
+ "min": 1,
+ "opt": 4,
+ "max": 8,
+ "test": 5,
+ },
+ },
+ )
+ ```
+
+ Raises:
+ TypeError: When the `model` is an instance of `torch.nn.DataParallel` or `torch.nn.DistributedDataParallel`.
+ RuntimeError: When `dynamic_axes` is set for baseline benchmark.
+ ValueError: When invalid `dynamic_axes` is given.
+
+ ### Workflow
+
+ `owl.benchmark` goes through the following steps:
+
+ 1. **File Transfer**: `owl.benchmark` transfers the ONNX binary file to the connected device.
+
+ 2. **Engine Export and Benchmark**: On the device, `owl.benchmark` exports the model to a TensorRT engine and
+ benchmarks it. It returns the latency information and displays it on the terminal.
+
+ **Interrupting the Benchmarking Process**
+ If the benchmarking process appears to be time-consuming, an interruption can be initiated with
+ ctrl-c. This action triggers an exit message, indicating the cessation of the current experiment on your
+ end. However, the benchmarking process continues on the connected device.
+
+ A URL link is also provided, guiding to the OwLite website for further project configuration.
+
+ Please note that the benchmark is still accessible on the connected device after the interruption,
+ enabling a return to the process when convenient. However, manual retrieval of the engine will
+ not be possible after the interruption.
+
+
+ 3. **Engine File Download**: The converted engine file is downloaded to the user’s workspace.
+
+ Following these steps, `owl.benchmark` effectively benchmarks the converted model
+ and provides latency information.
+
+ ### Examples
+
+ **Baseline Mode (or Experiment Mode with Static Batch Size)**
+
+ ```python
+ # after owl.export(model)
+ owl.benchmark()
+ ```
+
+ **Experiment Mode with Dynamic Batch Size**
+
+ ```python
+ # after owl.export(model, dynamic_axis_options={"x”: {"axis": 0})
+ owl.benchmark(
+ model,
+ dynamic_benchmark_options={
+ "x": {
+ "min": 1,
+ "opt": 4,
+ "max": 8,
+ "test": 5,
+ },
+ },
+ )
+ ```
+
+ ```bash
+ OwLite [INFO] Benchmark initiated for the experiment 'dynamic' for the baseline 'sampleModel'
+ in the project 'testProject'
+ OwLite [INFO] TensorRT benchmark requested
+ OwLite [INFO] Polling for benchmark result. You are free to CTRL-C away. When it is done, you can find the
+ results at https://owlite.ai/project/detail/65a7194af0e4c784fb1f443c
+ Your position in the queue: 0
+ OwLite [INFO] Uploading ONNX model weight to optimize the TensorRT engine
+ OwLite [INFO] Uploading /home/sqzb/workspace/owlite/testProject/sampleModel/dynamic/
+ testProject_sampleModel_dynamic.bin
+ 100%|█████████████████████████████████████████████████████████████████████████████████| 541k/541k
+ [00:00<00:00, 2.26MiB/s]
+ OwLite [INFO] Uploading done
+ [.........🦉..........]
+ Benchmarking done
+ OwLite [INFO] Experiment: dynamic
+ Latency: 0.0245361 (ms) on A6000ONPREM
+ For more details, visit https://owlite.ai/project/detail/id
+ OwLite [INFO] Downloading file at /home/sqzb/workspace/owlite/testProject/sampleModel/dynamic/
+ testProject_sampleModel_dynamic.engine
+ 100%|█████████████████████████████████████████████████████████████████████████████████| 554k/554k
+ [00:00<00:00, 9.51MiB/s]
+ OwLite [INFO] Downloading done
+ ```
+
+ OwLite will create TensorRT engine file with the hierarchical structure below.
+
+ ```bash
+ - owlite
+ - testProject
+ - SampleModel
+ - dynamic
+ - testProject_SampleModel_dynamic.onnx # created by owlite.export()
+ - testProject_SampleModel_dynamic.bin # created by owlite.export()
+ - testProject_SampleModel_dynamic.engine
+ ```
+
+ **Free plan user**
+
+ However, please note that the Free plan does not allow you to export TensorRT engine files with the model's
+ weights. Instead, a random weight engine will be created and you can only query its latency.
+ You will not be able to get the generated engine.
+
+ ```bash
+ OwLite [INFO] Benchmark initiated for the experiment 'dynamic' for the baseline '"sampleModel”'
+ in the project 'testProject'
+ OwLite [INFO] TensorRT benchmark requested
+ OwLite [INFO] Polling for benchmark result. You are free to CTRL-C away. When it is done,
+ you can find the results at https://owlite.ai/project/detail/id
+ [.........🦉..........]
+ Benchmarking done
+ OwLite [INFO] Experiment: dynamic
+ Latency: 0.0327148 (ms) on NVIDIA RTX A6000
+ For more details, visit https://owlite.ai/project/detail/id
+ OwLite [INFO] The free plan doesn't support TensorRT engine download. Upgrade to a higher plan to download
+ the engine through OwLite with a seamless experience. Even so, OwLite still provides you ONNX
+ so that you can generate TensorRT independently
+ ```
+
+ """
+
+ if isinstance(self.target, Experiment) and isinstance(self.target.input_signature, DynamicSignature):
+ if dynamic_input_options is None:
+ log.error(
+ "The `dynamic_input_options` for the experiment has `dynamic_input_options`. "
+ "Try `owl.benchmark(dynamic_input_options={...})`"
+ ) # UX
+ raise RuntimeError("Dynamic options failed")
+ if dynamic_input_options is not None:
+ dynamic_input_options = DynamicInputOptions(dynamic_input_options)
+ self.target.input_signature = update_dynamic_signature(self.target.input_signature, dynamic_input_options)
+
+ self.target.orchestrate_trt_benchmark()
+
+ def log(self, **kwargs: Any) -> None:
+ """Records and sends specific metrics to the server.
+
+ These metrics can then be reviewed and analyzed on the web, along with other project data.
+ This function can be used anytime after the initialization (`init`) step.
+
+ Raises:
+ TypeError: When data is not JSON serializable or not allowed logging.
+
+ ### Usage
+
+ The `log` function is used for logging metrics such as accuracy, loss, etc. for the model.
+ `owl.log` can take any number or string of keyword arguments,
+ where each argument represents a different metric for the model.
+
+ ### Example
+
+ ```python
+ ...
+
+ owl = owlite.init(...)
+
+ ...
+
+ owl.log(accuracy=0.72, loss=1.2)
+ ```
+
+
+ ### Notes
+
+ * All arguments to the `log` function should be JSON serializable. If a provided argument is not serializable,
+ a `TypeError` will be raised.
+
+ * It's recommended to log your metrics near `owl.benchmark` call, as the state of the model at this point is
+ closest to the deployed model. However, you can call the `log` function at any point after the `init` function
+ is called, where the state of the model is expected to be the closest to the deployment.
+
+ * You can update the logged metrics by calling the `log` function again with the new values.
+
+ """
+ if not all(isinstance(value, (int, str, float)) for value in kwargs.values()):
+ log.error("Invalied value given to `owl.log`. The value for logging must be `int`, `str`, `float`") # UX
+ raise TypeError("Invalid value")
+ try:
+ self.target.log(json.dumps(kwargs))
+ except TypeError as e:
+ log.error("Invalid value given to `owl.log`. The metrics for logging must be JSON-serializable") # UX
+ raise e
+
+
+# pylint: disable-next=too-many-branches
+def init(
+ project: str,
+ baseline: str,
+ experiment: Optional[str] = None,
+ duplicate_from: Optional[str] = None,
+ description: Optional[str] = None,
+) -> OwLite:
+ r"""Initializes your projects, baselines, and/or experiments.
+
+ * A project comprises one or more baselines, the unmodified models you want to compress.
+
+ * For each baseline in a project, you can create one or more experiments
+ to benchmark various compression configurations for the baseline.
+
+ * The project, baseline, or experiment name must only include alphanumeric characters
+ and special characters among ()-_@:*&.
+
+ ![Baseline-Experiment Hierarchy](https://github.com/SqueezeBits/owlite/assets/116608095/5bb3d540-4930-4f75-af84-6b4b609db392)
+
+ Args:
+ project (`str`): The name of a new (or an existing) project.
+ baseline (`str`): The name of a new baseline.
+ experiment (`str`. `optional`): The name of the experiment you want to create or load.
+ If `experiment` is not provided, the process defaults to `baseline mode`; however,
+ if `experiment` is specified, the process operates in `experiment mode`.
+ duplicate_from (`str`. `optional`): The name of the experiment you want to clone.
+ description (`str`. `optional`): A brief description of your project within 140 characters.
+ (Required only for creating a new project.)
+
+ Raises:
+ RuntimeError: When deprecated or not authenticated.
+ ValueError: When invalid experiment name or baseline name is given.
+
+ Returns:
+ OwLite: An `owlite.OwLite` object configured for the designated project, baseline, and/or experiment.
+
+
+ ### Two different modes triggered by `owlite.init`
+
+ 1. **Baseline mode** : Creating or loading a project and its baseline
+
+
+ If you want to create a new project named "my_project” with a new baseline named "my_model”,
+ add the following line in your code (provided that you have added the import statement `import owlite`):
+
+
+ ```python
+ owl = owlite.init(project="my_project”, baseline="my_model”)
+ ```
+
+ This function call can behave in different ways depending on the circumstances.
+
+ * If the project named `"my_project”` already exists, the existing one will be loaded.
+ * In contrast, if the baseline `"my_model”` already exists in the project `"my_project”`,
+ it will still create a new baseline. The name of the newly created baseline will be renamed
+ automatically by appending an appropriate postfix (e.g., `"my_model_1"` or `"my_model_2”`)
+
+
+ 2. **Experiment mode** : Creating or loading an experiment
+
+ After creating a compression configuration at [owlite.ai](http://owlite.ai), you can benchmark the (compressed)
+ model from your experiment as follows:
+
+ ```python
+ owl = owlite.init(project="my_project”, baseline="my_model”, experiment="my_experiment”)
+ ```
+
+ This function call can behave in different ways depending on the circumstances.
+
+ * If the experiment `"my_experiment”` is not found, it will create a new one. In this case, the compression
+ configuration for the newly created experiment will be empty. By calling
+ [`owl.convert`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.benchmark) and
+ [`owl.benchmark`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.benchmark)
+ for this experiment, you can benchmark the baseline.
+
+ * If the experiment `"my_experiment”` already exists, it downloads the compression configuration from the
+ experiment. By calling
+ [`owl.convert`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.convert) and
+ [`owl.benchmark`](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.benchmark),
+ you can benchmark the compressed model from the experiment.
+
+ Furthermore, you can clone an existing experiment by providing its name to `duplicate_from`.
+
+ ```python
+ owl = owlite.init(project="my_project”, baseline="my_model”, experiment="new_experiment”,
+ duplicate_from="existing_experiment”)
+ ```
+
+ If `"new_experiment”` already exists, the newly created experiment will be renamed appropriately
+ (e.g., `"new_experiment_1"` or `"new_experiment_2”`.)
+
+ By performing these tasks, the `init` function ensures that the necessary setup is done for
+ the project, baseline, and experiment within OwLite.
+
+ ### Examples:
+
+ **Baseline Mode**
+
+ ```python
+ import owlite
+
+ owl = owlite.init(project="testProject”, baseline="sampleModel”)
+ ```
+
+ This code creates a new project named `"testProject”` and a new baseline named `"sampleModel”` provided
+ that the project with the same name does not already exist. `owlite.init` returns an `owlite.OwLite` object,
+ which you will need for converting or benchmarking your baseline model.
+
+ A typical output of this code is as follows:
+
+ ```bash
+ OwLite [INFO] Connected device: NVIDIA RTX A6000
+ OwLite [WARNING] Existing local directory found at /home/sqzb/workspace/owlite/testProject/sampleModel/sampleModel.
+ Continuing this code will overwrite the data
+ OwLite [INFO] Created new project 'testProject’
+ OwLite [INFO] Created new baseline 'sampleModel’ at project 'testProject’
+ ```
+
+ **Experiment Mode**
+
+ ```python
+ import torch
+ import owlite
+
+ owl = owlite.init(project="testProject”, baseline="sampleModel”, experiment="conv”)
+ ```
+
+ This code loads the experiment named `"conv”` for the baseline `"sampleModel` in the project `"testProject”`.
+ Likewise, `owlite.init` returns an `owlite.OwLite` object, which you will need for benchmarking the experiment.
+
+ A typical output of this code is as follows:
+
+ ```bash
+ OwLite [INFO] Connected device: NVIDIA RTX A6000
+ OwLite [INFO] Experiment data will be saved in /home/sqzb/workspace/owlite/testProject/sampleModel/conv
+ OwLite [INFO] Loaded existing project 'testProject’
+ OwLite [INFO] Existing compression configuration for 'conv’ found
+ ```
+
+ OwLite stores files, such as ONNX or TensorRT engine, generated from your code at
+ `${OWLITE_HOME}///`, where OWLITE_HOME is an environment variable
+ that defaults to the current working directory ` . `.
+
+ ### Warning messages:
+
+ **No device connected**
+
+ When there is no device connected, you might see the following warning messages:
+
+ ```bash
+ OwLite [WARNING] Connected device not found. Please connect the device by 'owlite device connect --name (name)’
+ ```
+
+ If you see the warning message above, you will encounter a failure in benchmark initialization if you have called
+ `owl.benchmark`. (See
+ [USER GUIDE/how to use/benchmark](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.benchmark)
+ for more details.) Other features such as `owl.convert` and `owl.export` will not be affected.
+
+ **Experiment directory exists**
+
+
+ When the local directory for your baseline or experiment already exists, OwLite will overwrite existing files.
+
+ ```bash
+ OwLite [WARNING] Existing local directory found at /home/sqzb/workspace/owlite/testProject/sampleModel/conv.
+ Continuing this code will overwrite the data
+ ```
+
+ """
+ owlite_latest_version = Version(get_latest_version_from_github())
+
+ current_version = Version(OWLITE_VERSION)
+ if current_version.major < owlite_latest_version.major:
+ log.error(
+ f"Your current version ({current_version}) is not supported. "
+ "Please update the package to the latest version with the following command: "
+ "pip install owlite --extra-index-url https://pypi.squeezebits.com/ --upgrade "
+ ) # UX
+ raise RuntimeError("Version is not supported")
+ if current_version < owlite_latest_version:
+ log.warning(
+ "A new version of OwLite is available. "
+ "To ensure the best usage, please update the package to the latest version with the following command: "
+ "pip install owlite --extra-index-url https://pypi.squeezebits.com/ --upgrade "
+ ) # UX
+
+ if OWLITE_SETTINGS.tokens is None:
+ log.error("Please log in using 'owlite login'. Account not found on this device") # UX
+ raise RuntimeError("OwLite token not found")
+
+ if OWLITE_SETTINGS.connected_device is None:
+ log.warning(
+ "Connected device not found. "
+ "You will be automatically connected to the default NEST device as you are subscribed to the free plan. "
+ "Please connect to a specific device using 'owlite device connect --name (name)' if needed"
+ ) # UX
+
+ else:
+ log.info(f"Connected device: {OWLITE_SETTINGS.connected_device.name}") # UX
+
+ _validate_names(
+ project=project,
+ baseline=baseline,
+ experiment=experiment,
+ duplicate_from=duplicate_from,
+ )
+ if description and len(description) > 140:
+ log.error(
+ "The project description should consist of at most 140 characters. "
+ "Note that the description is not required for loading an existing project"
+ ) # UX
+ raise ValueError("Description length exceeds limit")
+
+ if experiment == baseline:
+ log.error(
+ f"Experiment name '{baseline}' is reserved for the baseline. Please try with a different experiment name"
+ ) # UX
+ raise ValueError("Invalid experiment name")
+
+ proj: Project = Project.load_or_create(project, description=description)
+
+ target: Union[Baseline, Experiment]
+ if experiment is None:
+ if duplicate_from:
+ log.warning(
+ f"`duplicate_from='{duplicate_from}'` will be ignored as no value for `experiment` was provided"
+ ) # UX
+ target = Baseline.create(proj, baseline)
+ else:
+ existing_baseline = Baseline.load(proj, baseline)
+ if existing_baseline is None:
+ log.error(
+ f"No such baseline: {baseline}. "
+ f"Please check if the baseline name for the experiment '{experiment}' is correct"
+ ) # UX
+ raise ValueError("Invalid baseline name")
+ if duplicate_from is None:
+ target = Experiment.load_or_create(existing_baseline, experiment)
+ else:
+ existing_experiment = Experiment.load(existing_baseline, duplicate_from)
+ if existing_experiment is None:
+ log.error(
+ f"The experiment '{duplicate_from}' to duplicate from is not found. "
+ "Please check if the project name provided for `duplicate_from` argument is correct"
+ ) # UX
+ raise ValueError("Invalid experiment name")
+ target = existing_experiment.clone(experiment)
+
+ if os.path.exists(target.home):
+ log.warning(
+ f"Existing local directory found at {target.home}. Continuing this code will overwrite the data"
+ ) # UX
+ else:
+ os.makedirs(target.home, exist_ok=True)
+ log.info(f"Experiment data will be saved in {target.home}") # UX
+
+ return OwLite(target)
+
+
+def _validate_names(**kwargs: Any) -> None:
+ """Validate a list of names.
+
+ Args:
+ **kwargs: A dictionary where keys are identifiers and values are names to validate.
+
+ Raises:
+ ValueError: If any name is invalid.
+ """
+ invalid_keys = []
+ regex = r"^[a-zA-Z0-9()\-_@:*&]+$"
+ for key, name in kwargs.items():
+ if name is None:
+ continue
+ if not re.fullmatch(regex, name):
+ invalid_keys.append(key)
+ if len(invalid_keys) > 0:
+ invalid_items = ", ".join(f"{key}={kwargs[key]}" for key in invalid_keys)
+ log.error(
+ f"The following names do not meet the requirement: {invalid_items}. "
+ "A valid name must consist of alphanumeric characters or special characters chosen from ()-_@:*&"
+ ) # UX
+ raise ValueError("Invalid name")
diff --git a/owlite_core/__init__.py b/src/owlite/owlite_core/__init__.py
similarity index 100%
rename from owlite_core/__init__.py
rename to src/owlite/owlite_core/__init__.py
diff --git a/owlite_core/api_base.py b/src/owlite/owlite_core/api_base.py
similarity index 100%
rename from owlite_core/api_base.py
rename to src/owlite/owlite_core/api_base.py
diff --git a/owlite_core/api_enums.py b/src/owlite/owlite_core/api_enums.py
similarity index 100%
rename from owlite_core/api_enums.py
rename to src/owlite/owlite_core/api_enums.py
diff --git a/owlite_core/cache/__init__.py b/src/owlite/owlite_core/cache/__init__.py
similarity index 100%
rename from owlite_core/cache/__init__.py
rename to src/owlite/owlite_core/cache/__init__.py
diff --git a/owlite_core/cache/base_urls.py b/src/owlite/owlite_core/cache/base_urls.py
similarity index 100%
rename from owlite_core/cache/base_urls.py
rename to src/owlite/owlite_core/cache/base_urls.py
diff --git a/owlite_core/cache/device_manager.py b/src/owlite/owlite_core/cache/device_manager.py
similarity index 100%
rename from owlite_core/cache/device_manager.py
rename to src/owlite/owlite_core/cache/device_manager.py
diff --git a/owlite_core/cache/text.py b/src/owlite/owlite_core/cache/text.py
similarity index 100%
rename from owlite_core/cache/text.py
rename to src/owlite/owlite_core/cache/text.py
diff --git a/owlite_core/cache/tokens.py b/src/owlite/owlite_core/cache/tokens.py
similarity index 100%
rename from owlite_core/cache/tokens.py
rename to src/owlite/owlite_core/cache/tokens.py
diff --git a/owlite_core/cli/__init__.py b/src/owlite/owlite_core/cli/__init__.py
similarity index 100%
rename from owlite_core/cli/__init__.py
rename to src/owlite/owlite_core/cli/__init__.py
diff --git a/owlite_core/cli/api/__init__.py b/src/owlite/owlite_core/cli/api/__init__.py
similarity index 100%
rename from owlite_core/cli/api/__init__.py
rename to src/owlite/owlite_core/cli/api/__init__.py
diff --git a/owlite_core/cli/api/device.py b/src/owlite/owlite_core/cli/api/device.py
similarity index 100%
rename from owlite_core/cli/api/device.py
rename to src/owlite/owlite_core/cli/api/device.py
diff --git a/owlite_core/cli/api/login.py b/src/owlite/owlite_core/cli/api/login.py
similarity index 100%
rename from owlite_core/cli/api/login.py
rename to src/owlite/owlite_core/cli/api/login.py
diff --git a/owlite_core/cli/commands/__init__.py b/src/owlite/owlite_core/cli/commands/__init__.py
similarity index 100%
rename from owlite_core/cli/commands/__init__.py
rename to src/owlite/owlite_core/cli/commands/__init__.py
diff --git a/owlite_core/cli/commands/device_commands.py b/src/owlite/owlite_core/cli/commands/device_commands.py
similarity index 100%
rename from owlite_core/cli/commands/device_commands.py
rename to src/owlite/owlite_core/cli/commands/device_commands.py
diff --git a/owlite_core/cli/commands/url_commands.py b/src/owlite/owlite_core/cli/commands/url_commands.py
similarity index 100%
rename from owlite_core/cli/commands/url_commands.py
rename to src/owlite/owlite_core/cli/commands/url_commands.py
diff --git a/owlite_core/cli/commands/user_commands.py b/src/owlite/owlite_core/cli/commands/user_commands.py
similarity index 98%
rename from owlite_core/cli/commands/user_commands.py
rename to src/owlite/owlite_core/cli/commands/user_commands.py
index 179be41..8f8c631 100644
--- a/owlite_core/cli/commands/user_commands.py
+++ b/src/owlite/owlite_core/cli/commands/user_commands.py
@@ -2,8 +2,7 @@
# pylint: disable=unnecessary-lambda, too-few-public-methods
from argparse import Namespace, _SubParsersAction
-from owlite_core.logger import log
-
+from ...logger import log
from .. import BaseOwLiteCLICommand
from ..api.login import whoami
from ..login import login, logout
diff --git a/owlite_core/cli/device.py b/src/owlite/owlite_core/cli/device.py
similarity index 96%
rename from owlite_core/cli/device.py
rename to src/owlite/owlite_core/cli/device.py
index 80af08d..abb7845 100644
--- a/owlite_core/cli/device.py
+++ b/src/owlite/owlite_core/cli/device.py
@@ -49,6 +49,9 @@ def add_manager(name: str, url: str) -> None:
get_devices(url)
OWLITE_SETTINGS.add_manager(DeviceManager(name=name, url=url))
+ connected_device = OWLITE_SETTINGS.connected_device
+ if connected_device and name == connected_device.manager.name and url != connected_device.manager.url:
+ disconnect_device()
print_manager_list()
diff --git a/owlite_core/cli/login.py b/src/owlite/owlite_core/cli/login.py
similarity index 100%
rename from owlite_core/cli/login.py
rename to src/owlite/owlite_core/cli/login.py
diff --git a/owlite_core/cli/owlite_cli.py b/src/owlite/owlite_core/cli/owlite_cli.py
similarity index 100%
rename from owlite_core/cli/owlite_cli.py
rename to src/owlite/owlite_core/cli/owlite_cli.py
diff --git a/owlite_core/cli/url.py b/src/owlite/owlite_core/cli/url.py
similarity index 72%
rename from owlite_core/cli/url.py
rename to src/owlite/owlite_core/cli/url.py
index 434ca34..238c202 100644
--- a/owlite_core/cli/url.py
+++ b/src/owlite/owlite_core/cli/url.py
@@ -2,22 +2,22 @@
This module handles the caching of base URLs used in OwLite APIs."""
-from owlite_core.logger import log
-
+from ...owlite_core.logger import log
from ..owlite_settings import OWLITE_SETTINGS
+from .device import disconnect_device
URL_NAME_LIST = ["FRONT", "MAIN", "DOVE", "NEST"]
def save_base_url(name: str, url: str) -> None:
- """Saves url in cache.
+ """Saves the base URL for an API in the cache.
Args:
name (str): A name of a URL.
url (str): URL to save.
Raises:
- HTTPError: When login request was not successful.
+ ValueError: If the API name is invalid.
"""
if name not in URL_NAME_LIST:
log.error(f"Invalid API base name: '{name}'. Valid API base names are {URL_NAME_LIST}") # UX
@@ -26,6 +26,14 @@ def save_base_url(name: str, url: str) -> None:
base_urls.set(name, url)
OWLITE_SETTINGS.base_url = base_urls
+ if (
+ name == "NEST"
+ and OWLITE_SETTINGS.connected_device
+ and OWLITE_SETTINGS.connected_device.manager.name == name
+ and OWLITE_SETTINGS.connected_device.manager.url != url
+ ):
+ disconnect_device()
+
log.info(f"The {name} API base is set to {url}") # UX
@@ -50,3 +58,6 @@ def delete_base_url(name: str) -> None:
base_urls.set(name)
OWLITE_SETTINGS.base_url = base_urls
log.info(f"Deleted the {name} API base") # UX
+
+ if name == "NEST" and OWLITE_SETTINGS.connected_device and OWLITE_SETTINGS.connected_device.manager.name == name:
+ disconnect_device()
diff --git a/owlite_core/constants.py b/src/owlite/owlite_core/constants.py
similarity index 90%
rename from owlite_core/constants.py
rename to src/owlite/owlite_core/constants.py
index 9dbf653..181165e 100644
--- a/owlite_core/constants.py
+++ b/src/owlite/owlite_core/constants.py
@@ -1,4 +1,3 @@
-"""Constants for file, paths."""
import os
OWLITE_HOME = os.path.join(os.getenv("OWLITE_HOME", os.path.join(os.getcwd(), "owlite")))
@@ -17,4 +16,4 @@
FX_CONFIGURATION_FORMAT_VERSION = "1.1"
OWLITE_SETTINGS_FORMAT_VERSION = "1.1"
-OWLITE_VERSION = "1.2.0"
+OWLITE_VERSION = "1.2.1"
diff --git a/owlite_core/exceptions.py b/src/owlite/owlite_core/exceptions.py
similarity index 100%
rename from owlite_core/exceptions.py
rename to src/owlite/owlite_core/exceptions.py
diff --git a/owlite_core/github_utils.py b/src/owlite/owlite_core/github_utils.py
similarity index 100%
rename from owlite_core/github_utils.py
rename to src/owlite/owlite_core/github_utils.py
diff --git a/owlite_core/logger.py b/src/owlite/owlite_core/logger.py
similarity index 100%
rename from owlite_core/logger.py
rename to src/owlite/owlite_core/logger.py
diff --git a/owlite_core/owlite_settings.py b/src/owlite/owlite_core/owlite_settings.py
similarity index 100%
rename from owlite_core/owlite_settings.py
rename to src/owlite/owlite_core/owlite_settings.py