diff --git a/pysilcam/silcam_classify.py b/pysilcam/silcam_classify.py index 86e6d8d1..3f493346 100644 --- a/pysilcam/silcam_classify.py +++ b/pysilcam/silcam_classify.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd import scipy -import tensorflow.keras as keras ''' SilCam TensorFlow analysis for classification of particle types @@ -67,6 +66,8 @@ def load_model(model_path): Returns: model (tf model object) : loaded tf.keras model from load_model() ''' + import tensorflow.keras as keras + path, filename = os.path.split(model_path) header = pd.read_csv(os.path.join(path, 'header.tfl.txt')) class_labels = header.columns diff --git a/pysilcam/tests/test_z_classify.py b/pysilcam/tests/test_z_classify.py index 07943646..06e0f220 100644 --- a/pysilcam/tests/test_z_classify.py +++ b/pysilcam/tests/test_z_classify.py @@ -4,7 +4,6 @@ import os import numpy as np import unittest -import tensorflow as tf # Get user-defined path to unittest data folder ROOTPATH = os.environ.get('UNITTEST_DATA_PATH', None) @@ -25,6 +24,7 @@ def test_classify(): @todo include more advanced testing of the classification feks. assert values in a confusion matrix. ''' + import tensorflow as tf # location of the training data database_path = os.path.join(ROOTPATH, 'silcam_classification_database')