Skip to content

Commit 5f439e9

Browse files
committed
implement tests
1 parent 3b49a8d commit 5f439e9

File tree

1 file changed

+135
-11
lines changed

1 file changed

+135
-11
lines changed

Tests/MLXLMTests/UserInputTests.swift

+135-11
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,140 @@
11
import Foundation
2-
import XCTest
3-
import MLXLMCommon
42
import MLX
3+
import MLXLMCommon
4+
import MLXVLM
5+
import XCTest
6+
7+
func assertEqual(
8+
_ v1: Any, _ v2: Any, path: [String] = [], file: StaticString = #filePath, line: UInt = #line
9+
) {
10+
switch (v1, v2) {
11+
case let (v1, v2) as (String, String):
12+
XCTAssertEqual(v1, v2, file: file, line: line)
13+
14+
case let (v1, v2) as ([Any], [Any]):
15+
XCTAssertEqual(
16+
v1.count, v2.count, "Arrays not equal size at \(path)", file: file, line: line)
17+
18+
for (index, (v1v, v2v)) in zip(v1, v2).enumerated() {
19+
assertEqual(v1v, v2v, path: path + [index.description], file: file, line: line)
20+
}
21+
22+
case let (v1, v2) as ([String: Any], [String: Any]):
23+
XCTAssertEqual(
24+
v1.keys.sorted(), v2.keys.sorted(),
25+
"\(String(describing: v1.keys.sorted())) and \(String(describing: v2.keys.sorted())) not equal at \(path)",
26+
file: file, line: line)
27+
28+
for (k, v1v) in v1 {
29+
if let v2v = v2[k] {
30+
assertEqual(v1v, v2v, path: path + [k], file: file, line: line)
31+
} else {
32+
XCTFail("Missing value for \(k) at \(path)", file: file, line: line)
33+
}
34+
}
35+
default:
36+
XCTFail(
37+
"Unable to compare \(String(describing: v1)) and \(String(describing: v2)) at \(path)",
38+
file: file, line: line)
39+
}
40+
}
41+
42+
public class UserInputTests: XCTestCase {
43+
44+
public func testStandardConversion() {
45+
let chat: [Chat.Message] = [
46+
.system("You are a useful agent."),
47+
.user("Tell me a story."),
48+
]
549

6-
public class MLXLMCommonTests: XCTestCase {
7-
8-
public func testExample() {
9-
let x = UserInput(prompt: "foo")
10-
print(x)
11-
12-
let a = MLXArray(10)
13-
print(a + 1)
50+
let messages = DefaultMessageGenerator().generate(messages: chat)
51+
52+
let expected = [
53+
[
54+
"role": "system",
55+
"content": "You are a useful agent.",
56+
],
57+
[
58+
"role": "user",
59+
"content": "Tell me a story.",
60+
],
61+
]
62+
63+
XCTAssertEqual(expected, messages as? [[String: String]])
64+
}
65+
66+
public func testQwen2ConversionText() {
67+
let chat: [Chat.Message] = [
68+
.system("You are a useful agent."),
69+
.user("Tell me a story."),
70+
]
71+
72+
let messages = Qwen2VLMessageGenerator().generate(messages: chat)
73+
74+
let expected = [
75+
[
76+
"role": "system",
77+
"content": [
78+
[
79+
"type": "text",
80+
"text": "You are a useful agent.",
81+
]
82+
],
83+
],
84+
[
85+
"role": "user",
86+
"content": [
87+
[
88+
"type": "text",
89+
"text": "Tell me a story.",
90+
]
91+
],
92+
],
93+
]
94+
95+
assertEqual(expected, messages)
1496
}
15-
97+
98+
public func testQwen2ConversionImage() {
99+
let chat: [Chat.Message] = [
100+
.system("You are a useful agent."),
101+
.user(
102+
"What is this?",
103+
images: [
104+
.url(
105+
URL(
106+
string: "https://opensource.apple.com/images/projects/mlx.f5c59d8b.png")!
107+
)
108+
]),
109+
]
110+
111+
let messages = Qwen2VLMessageGenerator().generate(messages: chat)
112+
113+
let expected = [
114+
[
115+
"role": "system",
116+
"content": [
117+
[
118+
"type": "text",
119+
"text": "You are a useful agent.",
120+
]
121+
],
122+
],
123+
[
124+
"role": "user",
125+
"content": [
126+
[
127+
"type": "text",
128+
"text": "What is this?",
129+
],
130+
[
131+
"type": "image"
132+
],
133+
],
134+
],
135+
]
136+
137+
assertEqual(expected, messages)
138+
}
139+
16140
}

0 commit comments

Comments
 (0)