Skip to content

Commit

Permalink
Add CoreML tests. (#6203)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #6203

.

Reviewed By: metascroy

Differential Revision: D64359459

fbshipit-source-id: acfa3990b1b90fd300ead0f47e71ebe82d70e7f9
  • Loading branch information
shoumikhin authored and facebook-github-bot committed Oct 15, 2024
1 parent 3a7056e commit 5c8b115
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
03DD00B22C8FE44600FE4619 /* backend_mps.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */; };
03DD00B32C8FE44600FE4619 /* executorch.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A32C8FE44600FE4619 /* executorch.xcframework */; settings = {ATTRIBUTES = (Required, ); }; };
03DD00B52C8FE44600FE4619 /* kernels_quantized.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */; };
03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */ = {isa = PBXBuildFile; fileRef = 03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */; };
03ED6D0F2C8AAFE900F2D6EE /* libsqlite3.0.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */; };
03ED6D112C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */; };
03ED6D132C8AAFF700F2D6EE /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */; };
Expand Down Expand Up @@ -90,6 +91,7 @@
03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_mps.xcframework; path = Frameworks/backend_mps.xcframework; sourceTree = "<group>"; };
03DD00A32C8FE44600FE4619 /* executorch.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = executorch.xcframework; path = Frameworks/executorch.xcframework; sourceTree = "<group>"; };
03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_quantized.xcframework; path = Frameworks/kernels_quantized.xcframework; sourceTree = "<group>"; };
03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CoreMLTests.mm; sourceTree = "<group>"; };
03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.0.tbd; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/usr/lib/libsqlite3.0.tbd; sourceTree = DEVELOPER_DIR; };
03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShadersGraph.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShadersGraph.framework; sourceTree = DEVELOPER_DIR; };
03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = DEVELOPER_DIR; };
Expand Down Expand Up @@ -232,6 +234,7 @@
isa = PBXGroup;
children = (
032A73C92CAFBA8600932D36 /* LLaMA */,
03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */,
03B2D3792C8A515C0046936E /* GenericTests.mm */,
03B019502C8A80D30044D558 /* Tests.xcconfig */,
037C96A02C8A570B00B3DF38 /* Tests.xctestplan */,
Expand Down Expand Up @@ -388,6 +391,7 @@
032A741E2CAFBB7800932D36 /* tiktoken.cpp in Sources */,
032A741F2CAFBB7800932D36 /* sampler.cpp in Sources */,
03B011912CAD114E00054791 /* ResourceTestCase.m in Sources */,
03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */,
032A74232CAFC1B300932D36 /* runner.cpp in Sources */,
03B2D37A2C8A515C0046936E /* GenericTests.mm in Sources */,
032A73CA2CAFBA8600932D36 /* LLaMATests.mm in Sources */,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
</BuildAction>
<TestAction
buildConfiguration = "Release"
selectedDebuggerIdentifier = ""
selectedLauncherIdentifier = "Xcode.IDEFoundation.Launcher.PosixSpawn"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
shouldUseLaunchSchemeArgsEnv = "YES">
<TestPlans>
<TestPlanReference
Expand Down
105 changes: 105 additions & 0 deletions extension/benchmark/apple/Benchmark/Tests/CoreMLTests.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#import "ResourceTestCase.h"

#import <CoreML/CoreML.h>

static MLMultiArray *DummyMultiArrayForFeature(MLFeatureDescription *feature, NSError **error) {
MLMultiArray *array = [[MLMultiArray alloc] initWithShape:feature.multiArrayConstraint.shape
dataType:feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
error:error];
for (auto index = 0; index < array.count; ++index) {
array[index] = feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? @1 : @1.0;
}
return array;
}

static NSMutableDictionary *DummyInputsForModel(MLModel *model, NSError **error) {
NSMutableDictionary *inputs = [NSMutableDictionary dictionary];
NSDictionary<NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription.inputDescriptionsByName;

for (NSString *inputName in inputDescriptions) {
MLFeatureDescription *feature = inputDescriptions[inputName];

switch (feature.type) {
case MLFeatureTypeMultiArray: {
MLMultiArray *array = DummyMultiArrayForFeature(feature, error);
inputs[inputName] = [MLFeatureValue featureValueWithMultiArray:array];
break;
}
case MLFeatureTypeInt64:
inputs[inputName] = [MLFeatureValue featureValueWithInt64:1];
break;
case MLFeatureTypeDouble:
inputs[inputName] = [MLFeatureValue featureValueWithDouble:1.0];
break;
case MLFeatureTypeString:
inputs[inputName] = [MLFeatureValue featureValueWithString:@"1"];
break;
default:
break;
}
}
return inputs;
}

@interface CoreMLTests : ResourceTestCase
@end

@implementation CoreMLTests

+ (NSArray<NSString *> *)directories {
return @[@"Resources"];
}

+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
return @{ @"model" : ^BOOL(NSString *filename) {
return [filename hasSuffix:@".mlpackage"];
}};
}

+ (NSDictionary<NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources:(NSDictionary<NSString *, NSString *> *)resources {
NSString *modelPath = resources[@"model"];

return @{
@"prediction" : ^(XCTestCase *testCase) {
NSError *error = nil;
NSURL *compiledModelURL = [MLModel compileModelAtURL:[NSURL fileURLWithPath:modelPath] error:&error];
if (error || !compiledModelURL) {
XCTFail(@"Failed to compile model: %@", error.localizedDescription);
return;
}
MLModel *model = [MLModel modelWithContentsOfURL:compiledModelURL error:&error];
if (error || !model) {
XCTFail(@"Failed to load model: %@", error.localizedDescription);
return;
}
NSMutableDictionary *inputs = DummyInputsForModel(model, &error);
if (error || !inputs) {
XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription);
return;
}
MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error];
if (error || !featureProvider) {
XCTFail(@"Failed to create input provider: %@", error.localizedDescription);
return;
}
[testCase measureWithMetrics:@[[XCTClockMetric new], [XCTMemoryMetric new]]
block:^{
NSError *error = nil;
id<MLFeatureProvider> prediction = [model predictionFromFeatures:featureProvider error:&error];
if (error || !prediction) {
XCTFail(@"Prediction failed: %@", error.localizedDescription);
}
}];
}
};
}

@end

0 comments on commit 5c8b115

Please sign in to comment.