forked from RustPython/RustPython
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_patch_spec.py
More file actions
362 lines (308 loc) · 10.7 KB
/
test_patch_spec.py
File metadata and controls
362 lines (308 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""Tests for patch_spec.py - core patch extraction and application."""
import ast
import unittest
from update_lib.patch_spec import (
COMMENT,
PatchSpec,
UtMethod,
_find_import_insert_line,
apply_patches,
extract_patches,
iter_tests,
)
class TestIterTests(unittest.TestCase):
    """Tests for iter_tests function.

    Each yielded item is indexed as ``[0]`` -> the class and ``[1]`` -> the
    test method; both expose ``.name``.  The source snippets are kept
    flush-left because ``ast.parse`` rejects module-level code that starts
    with unexpected indentation.
    """

    def test_iter_tests_simple(self):
        """Test iterating over test methods in a class."""
        code = """
class TestFoo(unittest.TestCase):
    def test_one(self):
        pass
    def test_two(self):
        pass
"""
        tree = ast.parse(code)
        results = list(iter_tests(tree))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0][0].name, "TestFoo")
        self.assertEqual(results[0][1].name, "test_one")
        self.assertEqual(results[1][1].name, "test_two")

    def test_iter_tests_multiple_classes(self):
        """Test iterating over multiple test classes."""
        code = """
class TestFoo(unittest.TestCase):
    def test_foo(self):
        pass
class TestBar(unittest.TestCase):
    def test_bar(self):
        pass
"""
        tree = ast.parse(code)
        results = list(iter_tests(tree))
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0][0].name, "TestFoo")
        self.assertEqual(results[1][0].name, "TestBar")
        # Each method must be paired with the class that actually defines it.
        self.assertEqual(results[0][1].name, "test_foo")
        self.assertEqual(results[1][1].name, "test_bar")

    def test_iter_tests_async(self):
        """Test iterating over async test methods."""
        code = """
class TestAsync(unittest.TestCase):
    async def test_async(self):
        pass
"""
        tree = ast.parse(code)
        results = list(iter_tests(tree))
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0][0].name, "TestAsync")
        self.assertEqual(results[0][1].name, "test_async")


class TestExtractPatches(unittest.TestCase):
    """Tests for extract_patches function.

    ``extract_patches`` returns a mapping ``{class_name: {method_name: [spec,
    ...]}}``.  Only decorators marked with ``COMMENT`` (as a comment or inside
    the reason string) are expected to be picked up.  Snippets are kept
    flush-left so they parse as module-level code.
    """

    def test_extract_expected_failure(self):
        """Test extracting @unittest.expectedFailure decorator."""
        code = f"""
class TestFoo(unittest.TestCase):
    # {COMMENT}
    @unittest.expectedFailure
    def test_one(self):
        pass
"""
        patches = extract_patches(code)
        self.assertIn("TestFoo", patches)
        self.assertIn("test_one", patches["TestFoo"])
        specs = patches["TestFoo"]["test_one"]
        self.assertEqual(len(specs), 1)
        self.assertEqual(specs[0].ut_method, UtMethod.ExpectedFailure)

    def test_extract_expected_failure_inline_comment(self):
        """Test extracting expectedFailure with inline comment."""
        code = f"""
class TestFoo(unittest.TestCase):
    @unittest.expectedFailure  # {COMMENT}
    def test_one(self):
        pass
"""
        patches = extract_patches(code)
        self.assertIn("TestFoo", patches)
        self.assertIn("test_one", patches["TestFoo"])
        # The inline form must yield the same spec as the own-line comment form.
        specs = patches["TestFoo"]["test_one"]
        self.assertEqual(len(specs), 1)
        self.assertEqual(specs[0].ut_method, UtMethod.ExpectedFailure)

    def test_extract_skip_with_reason(self):
        """Test extracting @unittest.skip with reason."""
        code = f'''
class TestFoo(unittest.TestCase):
    @unittest.skip("{COMMENT}; not implemented")
    def test_one(self):
        pass
'''
        patches = extract_patches(code)
        self.assertIn("TestFoo", patches)
        specs = patches["TestFoo"]["test_one"]
        self.assertEqual(specs[0].ut_method, UtMethod.Skip)
        self.assertIn("not implemented", specs[0].reason)

    def test_extract_skip_if(self):
        """Test extracting @unittest.skipIf decorator."""
        code = f'''
class TestFoo(unittest.TestCase):
    @unittest.skipIf(sys.platform == "win32", "{COMMENT}; windows issue")
    def test_one(self):
        pass
'''
        patches = extract_patches(code)
        specs = patches["TestFoo"]["test_one"]
        self.assertEqual(specs[0].ut_method, UtMethod.SkipIf)
        # ast.unparse normalizes quotes to single quotes, so only check
        # substrings of the condition rather than its exact spelling.
        self.assertIn("sys.platform", specs[0].cond)
        self.assertIn("win32", specs[0].cond)
        # The human-readable part of the reason must survive, as for skip.
        self.assertIn("windows issue", specs[0].reason)

    def test_no_patches_without_comment(self):
        """Test that decorators without COMMENT are not extracted."""
        code = """
class TestFoo(unittest.TestCase):
    @unittest.expectedFailure
    def test_one(self):
        pass
"""
        patches = extract_patches(code)
        self.assertEqual(patches, {})

    def test_multiple_patches_same_method(self):
        """Test extracting multiple decorators on same method."""
        code = f'''
class TestFoo(unittest.TestCase):
    # {COMMENT}
    @unittest.expectedFailure
    @unittest.skip("{COMMENT}; reason")
    def test_one(self):
        pass
'''
        patches = extract_patches(code)
        specs = patches["TestFoo"]["test_one"]
        self.assertEqual(len(specs), 2)
        # Order-insensitive: both decorator kinds must be recognized.
        self.assertCountEqual(
            [spec.ut_method for spec in specs],
            [UtMethod.ExpectedFailure, UtMethod.Skip],
        )


class TestApplyPatches(unittest.TestCase):
    """Tests for apply_patches function.

    Each test feeds a small flush-left module plus a
    ``{class_name: {method_name: [PatchSpec, ...]}}`` mapping and inspects
    the returned source text.
    """

    def test_apply_expected_failure(self):
        """Test applying @unittest.expectedFailure."""
        code = """import unittest
class TestFoo(unittest.TestCase):
    def test_one(self):
        pass
"""
        patches = {
            "TestFoo": {"test_one": [PatchSpec(UtMethod.ExpectedFailure, None, "")]}
        }
        result = apply_patches(code, patches)
        self.assertIn("@unittest.expectedFailure", result)
        self.assertIn(COMMENT, result)

    def test_apply_skip_with_reason(self):
        """Test applying @unittest.skip with reason."""
        code = """import unittest
class TestFoo(unittest.TestCase):
    def test_one(self):
        pass
"""
        patches = {
            "TestFoo": {"test_one": [PatchSpec(UtMethod.Skip, None, "not ready")]}
        }
        result = apply_patches(code, patches)
        self.assertIn("@unittest.skip", result)
        self.assertIn("not ready", result)

    def test_apply_skip_if(self):
        """Test applying @unittest.skipIf."""
        code = """import unittest
class TestFoo(unittest.TestCase):
    def test_one(self):
        pass
"""
        patches = {
            "TestFoo": {
                "test_one": [
                    PatchSpec(UtMethod.SkipIf, "sys.platform == 'win32'", "windows")
                ]
            }
        }
        result = apply_patches(code, patches)
        self.assertIn("@unittest.skipIf", result)
        # The condition is given with single quotes but expected back with
        # double quotes, i.e. apply_patches is expected to rewrite the quoting.
        self.assertIn('sys.platform == "win32"', result)

    def test_apply_preserves_existing_decorators(self):
        """Test that existing decorators are preserved."""
        code = """import unittest
class TestFoo(unittest.TestCase):
    @some_decorator
    def test_one(self):
        pass
"""
        patches = {
            "TestFoo": {"test_one": [PatchSpec(UtMethod.ExpectedFailure, None, "")]}
        }
        result = apply_patches(code, patches)
        self.assertIn("@some_decorator", result)
        self.assertIn("@unittest.expectedFailure", result)

    def test_apply_inherited_method(self):
        """Test applying patch to inherited method (creates override)."""
        code = """import unittest
class TestFoo(unittest.TestCase):
    pass
"""
        patches = {
            "TestFoo": {
                "test_inherited": [PatchSpec(UtMethod.ExpectedFailure, None, "")]
            }
        }
        result = apply_patches(code, patches)
        self.assertIn("def test_inherited(self):", result)
        self.assertIn("return super().test_inherited()", result)

    def test_apply_adds_unittest_import(self):
        """Test that unittest import is added if missing."""
        code = """import sys
class TestFoo:
    def test_one(self):
        pass
"""
        patches = {
            "TestFoo": {"test_one": [PatchSpec(UtMethod.ExpectedFailure, None, "")]}
        }
        result = apply_patches(code, patches)
        # Exactly one import is added, and it lands after the existing imports.
        self.assertEqual(result.count("import unittest"), 1)
        self.assertLess(
            result.index("import sys"), result.index("import unittest")
        )

    def test_apply_no_duplicate_import(self):
        """Test that unittest import is not duplicated."""
        code = """import unittest
class TestFoo(unittest.TestCase):
    def test_one(self):
        pass
"""
        patches = {
            "TestFoo": {"test_one": [PatchSpec(UtMethod.ExpectedFailure, None, "")]}
        }
        result = apply_patches(code, patches)
        # Count occurrences of 'import unittest'
        count = result.count("import unittest")
        self.assertEqual(count, 1)


class TestPatchSpec(unittest.TestCase):
    """Tests for PatchSpec class.

    Every generated decorator is checked for the ``COMMENT`` marker.
    ``extract_patches`` ignores decorators without it, as
    TestExtractPatches.test_no_patches_without_comment shows.
    """

    def test_as_decorator_expected_failure(self):
        """Test generating expectedFailure decorator string."""
        spec = PatchSpec(UtMethod.ExpectedFailure, None, "reason")
        decorator = spec.as_decorator()
        self.assertIn("@unittest.expectedFailure", decorator)
        self.assertIn(COMMENT, decorator)
        self.assertIn("reason", decorator)

    def test_as_decorator_skip(self):
        """Test generating skip decorator string."""
        spec = PatchSpec(UtMethod.Skip, None, "not ready")
        decorator = spec.as_decorator()
        self.assertIn("@unittest.skip", decorator)
        self.assertIn(COMMENT, decorator)
        self.assertIn("not ready", decorator)

    def test_as_decorator_skip_if(self):
        """Test generating skipIf decorator string."""
        spec = PatchSpec(UtMethod.SkipIf, "condition", "reason")
        decorator = spec.as_decorator()
        self.assertIn("@unittest.skipIf", decorator)
        self.assertIn("condition", decorator)
        self.assertIn(COMMENT, decorator)
        self.assertIn("reason", decorator)


class TestRoundTrip(unittest.TestCase):
    """Tests for extract -> apply round trip."""

    def test_round_trip_expected_failure(self):
        """Test that extracted patches can be re-applied."""
        original = f"""import unittest
class TestFoo(unittest.TestCase):
    # {COMMENT}
    @unittest.expectedFailure
    def test_one(self):
        pass
"""
        # Extract patches. Fail early and clearly if nothing was found,
        # instead of on the less specific assertions further down.
        patches = extract_patches(original)
        self.assertIn("test_one", patches.get("TestFoo", {}))
        # Apply to clean code
        clean = """import unittest
class TestFoo(unittest.TestCase):
    def test_one(self):
        pass
"""
        result = apply_patches(clean, patches)
        # Should have the decorator
        self.assertIn("@unittest.expectedFailure", result)
        self.assertIn(COMMENT, result)
        # Close the loop: the patched output must extract back to the same spec.
        reextracted = extract_patches(result)
        self.assertEqual(
            [spec.ut_method for spec in reextracted["TestFoo"]["test_one"]],
            [UtMethod.ExpectedFailure],
        )


class TestFindImportInsertLine(unittest.TestCase):
    """Tests for _find_import_insert_line function.

    The expected values below are line numbers after which a new import would
    be inserted: the last import's line, else the docstring's line, else 0.
    Class bodies in the snippets must stay indented or ``ast.parse`` fails.
    """

    def test_with_imports(self):
        """Test finding line after imports."""
        code = """import os
import sys
class Foo:
    pass
"""
        tree = ast.parse(code)
        line = _find_import_insert_line(tree)
        self.assertEqual(line, 2)

    def test_no_imports_with_docstring(self):
        """Test fallback to after docstring when no imports."""
        code = '''"""Module docstring."""
class Foo:
    pass
'''
        tree = ast.parse(code)
        line = _find_import_insert_line(tree)
        self.assertEqual(line, 1)

    def test_no_imports_no_docstring(self):
        """Test fallback to line 0 when no imports and no docstring."""
        code = """class Foo:
    pass
"""
        tree = ast.parse(code)
        line = _find_import_insert_line(tree)
        self.assertEqual(line, 0)


if __name__ == "__main__":
    # Running this file directly executes every TestCase defined above.
    unittest.main()