From 2b984970a837f0b1a3a192115ada2733accbc540 Mon Sep 17 00:00:00 2001 From: Yiran-ASU Date: Wed, 9 Oct 2024 11:31:02 -0700 Subject: [PATCH] assertation added --- experiments/CamemBERT/model/CamemBERT.py | 2 ++ experiments/VIT/model/vit.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/experiments/CamemBERT/model/CamemBERT.py b/experiments/CamemBERT/model/CamemBERT.py index 1eb4998..23e0412 100644 --- a/experiments/CamemBERT/model/CamemBERT.py +++ b/experiments/CamemBERT/model/CamemBERT.py @@ -46,6 +46,7 @@ def forward(self, input): print("Parsing sentence tokens.") example_input = prepare_sentence_tokens(model_name, sentence) print("example_input shape: ", example_input.shape) +assert example_input.shape == (1, 7, 768), f"Expected shape (1,7,768), but got {example_input.shape}" # The original example_input shape is [1, 7, 768], now we reshape it into [1, 7*768] example_input = example_input.reshape(1, 7*768) @@ -53,6 +54,7 @@ def forward(self, input): print("Instantiating model.") model = OnlyLogitsHuggingFaceModel(model_name) +assert example_input.shape == (1, 7*768), f"Expected the reshaped (1,7*768), but got {example_input.shape}" print(model(example_input).shape) linalg_on_tensors_mlir = torch_mlir.compile( diff --git a/experiments/VIT/model/vit.py b/experiments/VIT/model/vit.py index 381a4c6..ba1eb00 100644 --- a/experiments/VIT/model/vit.py +++ b/experiments/VIT/model/vit.py @@ -38,6 +38,8 @@ def forward(self,x): model = prepare().eval() example_input = model(inputs) print(example_input.shape) +assert example_input.shape == (1, 197, 768), f"Expected shape (1,197,768), but got {example_input.shape}" + # The original example_input shape is [1, 197, 768], now we reshape it into [1, 197*768] example_input = example_input.reshape(1, 197*768) @@ -45,6 +47,7 @@ def forward(self,x): vit_model = vit().eval() +assert example_input.shape == (1, 197*768), f"Expected the reshaped (1,197*768), but got {example_input.shape}" output = vit_model(example_input) print(output.shape)