diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index aebf50e008a..a4f70602a6f 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -59,6 +59,16 @@ class ModuleTest: XCTestCase { let inputs = [Tensor([1], dataType: .float), Tensor([1], dataType: .float)] var outputs: [Value]? XCTAssertNoThrow(outputs = try module.forward(inputs)) - XCTAssertEqual(outputs?[0].tensor, Tensor([2], dataType: .float, shapeDynamism: .static)) + XCTAssertEqual(outputs?.first?.tensor, Tensor([2], dataType: .float, shapeDynamism: .static)) + + let inputs2 = [Tensor([2], dataType: .float), Tensor([3], dataType: .float)] + var outputs2: [Value]? + XCTAssertNoThrow(outputs2 = try module.forward(inputs2)) + XCTAssertEqual(outputs2?.first?.tensor, Tensor([5], dataType: .float, shapeDynamism: .static)) + + let inputs3 = [Tensor([13.25], dataType: .float), Tensor([29.25], dataType: .float)] + var outputs3: [Value]? + XCTAssertNoThrow(outputs3 = try module.forward(inputs3)) + XCTAssertEqual(outputs3?.first?.tensor, Tensor([42.5], dataType: .float, shapeDynamism: .static)) } }