Skip to content

Commit

Permalink
bugfix deserializer,test_compare_onnx_model
Browse files Browse the repository at this point in the history
  • Loading branch information
Taniya-Das committed Nov 14, 2024
1 parent 6a03098 commit bf9f13a
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 94 deletions.
26 changes: 12 additions & 14 deletions docs/Examples/tf_image_classification_Indoorscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,23 @@

############################################################################

openml_tensorflow.config.epoch = 10 # small epoch for test runs
openml_tensorflow.config.epoch = 1 # small epoch for test runs

IMG_SIZE = (128, 128)
IMG_SHAPE = IMG_SIZE + (3,)
base_learning_rate = 0.0001

# dataset = openml.datasets.get_dataset(45936, download_all_files=True, download_data=True)

# Toy example
datagen = ImageDataGenerator(
rotation_range=25,
width_shift_range=0.01,
height_shift_range=0.01,
brightness_range=(0.9, 1.1),
zoom_range=0.1,
horizontal_flip=True,
vertical_flip=True,
)

openml_tensorflow.config.datagen = datagen
openml_tensorflow.config.dir = openml.config.get_cache_directory()+'/datasets/45936/Images/'
openml_tensorflow.config.x_col = "Filename"
Expand All @@ -69,16 +77,6 @@
############################################################################
# Large CNN

datagen = ImageDataGenerator(
rotation_range=25,
width_shift_range=0.01,
height_shift_range=0.01,
brightness_range=(0.9, 1.1),
zoom_range=0.1,
horizontal_flip=True,
vertical_flip=True,
)

IMG_SIZE = 128
NUM_CLASSES = 67

Expand Down Expand Up @@ -144,7 +142,7 @@
# Careful to not call this function when another run_model_on_task is called in between,
# as during publish later, only the last trained model (from last run_model_on_task call) is uploaded.
run = openml_tensorflow.add_onnx_to_run(run)
# breakpoint()

run.publish()

print('URL for run: %s/run/%d?api_key=%s' % (openml.config.server, run.run_id, openml.config.apikey))
Expand Down
Binary file added model.onnx
Binary file not shown.
4 changes: 2 additions & 2 deletions openml_tensorflow/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ def _from_parameters(self, parameters: 'OrderedDict[str, Any]') -> Any:
# Recover loss functions and metrics
loss = training_config['loss']
metrics = training_config['metrics']
sample_weight_mode = training_config['sample_weight_mode']
loss_weights = training_config['loss_weights']
sample_weight_mode = training_config.get('sample_weight_mode', None)
loss_weights = training_config.get('loss_weights', None)

# Compile model
model.compile(optimizer=optimizer,
Expand Down
35 changes: 28 additions & 7 deletions tests/test_compare_onnx_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
"""
Below pytest test compares two onnx models for identical structure and parameters.
"""
import onnx
import numpy as np
import os
import pytest

# Paths to ONNX models
MODEL_PATH1 = "model1.onnx"
MODEL_PATH2 = "model2.onnx"

# Helper function to compare graph structures
def compare_graphs(graph1, graph2):
nodes1 = graph1.node
nodes2 = graph2.node
Expand All @@ -17,12 +28,12 @@ def compare_graphs(graph1, graph2):
print("Graph structures are identical.")
return True

def compare_models(model_path1, model_path2):
# Helper function to compare model parameters
def compare_parameters(model_path1, model_path2):
# Load ONNX models
model1 = onnx.load(model_path1)
model2 = onnx.load(model_path2)

# Compare graph structures
# Compare graph structures
if not compare_graphs(model1.graph, model2.graph):
return False
Expand All @@ -40,9 +51,19 @@ def compare_models(model_path1, model_path2):
print("Models are identical.")
return True

# Paths to ONNX models
model_path1 = "model1.onnx"
model_path2 = "model2.onnx"

# Compare models
compare_models(model_path1, model_path2)
@pytest.mark.skipif(
not (os.path.exists("model1.onnx") and os.path.exists("model2.onnx")),
reason="ONNX models are not available"
)
def test_compare_onnx_models():
# Load ONNX models
model1 = onnx.load(MODEL_PATH1)
model2 = onnx.load(MODEL_PATH2)

# Perform comparisons
assert compare_graphs(model1.graph, model2.graph), "Graph structures are different."
assert compare_parameters(model1, model2), "Model parameters are different."

print("ONNX models are identical.")

71 changes: 0 additions & 71 deletions tests/test_keras_extension.py

This file was deleted.

0 comments on commit bf9f13a

Please sign in to comment.