Skip to content

Commit de5a9fd

Browse files
pr feedback; refactor test
1 parent 6e524bb commit de5a9fd

File tree

1 file changed

+66
-49
lines changed

1 file changed

+66
-49
lines changed

tests/utils/test_metaprogramming.py

Lines changed: 66 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,18 @@ class DataClass:
8383
x: int
8484

8585

86+
class ReferencedClass:
87+
def __init__(self, value: int):
88+
self.value = value
89+
90+
def get_value(self) -> int:
91+
return self.value
92+
93+
8694
class MyClass:
95+
def __init__(self, x: int):
96+
self.helper = ReferencedClass(x * 2)
97+
8798
@staticmethod
8899
def foo():
89100
return KLASS_X
@@ -95,6 +106,13 @@ def bar(cls):
95106
def baz(self):
96107
return KLASS_Z
97108

109+
def use_referenced(self, value: int) -> int:
110+
ref = ReferencedClass(value)
111+
return ref.get_value()
112+
113+
def compute_with_reference(self) -> int:
114+
return self.helper.get_value() + 10
115+
98116

99117
def other_func(a: int) -> int:
100118
import sqlglot
@@ -103,7 +121,8 @@ def other_func(a: int) -> int:
103121
pd.DataFrame([{"x": 1}])
104122
to_table("y")
105123
my_lambda() # type: ignore
106-
return X + a + W
124+
obj = MyClass(a)
125+
return X + a + W + obj.compute_with_reference()
107126

108127

109128
@contextmanager
@@ -131,7 +150,7 @@ def function_with_custom_decorator():
131150
def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2) -> int:
132151
"""DOC STRING"""
133152
sqlglot.parse_one("1")
134-
MyClass()
153+
MyClass(47)
135154
DataClass(x=y)
136155
normalize_model_name("test" + SQLGLOT_META)
137156
fetch_data()
@@ -177,6 +196,7 @@ def test_func_globals() -> None:
177196
assert func_globals(other_func) == {
178197
"X": 1,
179198
"W": 0,
199+
"MyClass": MyClass,
180200
"my_lambda": my_lambda,
181201
"pd": pd,
182202
"to_table": to_table,
@@ -202,7 +222,7 @@ def test_normalize_source() -> None:
202222
== """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
203223
):
204224
sqlglot.parse_one('1')
205-
MyClass()
225+
MyClass(47)
206226
DataClass(x=y)
207227
normalize_model_name('test' + SQLGLOT_META)
208228
fetch_data()
@@ -223,7 +243,8 @@ def closure(z: int):
223243
pd.DataFrame([{'x': 1}])
224244
to_table('y')
225245
my_lambda()
226-
return X + a + W"""
246+
obj = MyClass(a)
247+
return X + a + W + obj.compute_with_reference()"""
227248
)
228249

229250

@@ -252,7 +273,7 @@ def test_serialize_env() -> None:
252273
payload="""def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2
253274
):
254275
sqlglot.parse_one('1')
255-
MyClass()
276+
MyClass(47)
256277
DataClass(x=y)
257278
normalize_model_name('test' + SQLGLOT_META)
258279
fetch_data()
@@ -295,6 +316,9 @@ class DataClass:
295316
path="test_metaprogramming.py",
296317
payload="""class MyClass:
297318
319+
def __init__(self, x: int):
320+
self.helper = ReferencedClass(x * 2)
321+
298322
@staticmethod
299323
def foo():
300324
return KLASS_X
@@ -304,7 +328,26 @@ def bar(cls):
304328
return KLASS_Y
305329
306330
def baz(self):
307-
return KLASS_Z""",
331+
return KLASS_Z
332+
333+
def use_referenced(self, value: int):
334+
ref = ReferencedClass(value)
335+
return ref.get_value()
336+
337+
def compute_with_reference(self):
338+
return self.helper.get_value() + 10""",
339+
),
340+
"ReferencedClass": Executable(
341+
kind=ExecutableKind.DEFINITION,
342+
name="ReferencedClass",
343+
path="test_metaprogramming.py",
344+
payload="""class ReferencedClass:
345+
346+
def __init__(self, value: int):
347+
self.value = value
348+
349+
def get_value(self):
350+
return self.value""",
308351
),
309352
"dataclass": Executable(
310353
payload="from dataclasses import dataclass", kind=ExecutableKind.IMPORT
@@ -341,7 +384,8 @@ def sample_context_manager():
341384
pd.DataFrame([{'x': 1}])
342385
to_table('y')
343386
my_lambda()
344-
return X + a + W""",
387+
obj = MyClass(a)
388+
return X + a + W + obj.compute_with_reference()""",
345389
),
346390
"sample_context_manager": Executable(
347391
payload="""@contextmanager
@@ -424,6 +468,21 @@ def function_with_custom_decorator():
424468
assert all(is_metadata for (_, is_metadata) in env.values())
425469
assert serialized_env == expected_env
426470

471+
# Check that class references inside init are captured
472+
init_globals = func_globals(MyClass.__init__)
473+
assert "ReferencedClass" in init_globals
474+
475+
env = {}
476+
build_env(other_func, env=env, name="other_func_test", path=path)
477+
serialized_env = serialize_env(env, path=path)
478+
479+
assert "MyClass" in serialized_env
480+
assert "ReferencedClass" in serialized_env
481+
482+
prepared_env = prepare_env(serialized_env)
483+
result = eval("other_func_test(2)", prepared_env)
484+
assert result == 17
485+
427486

428487
def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
429488
path = Path("tests/utils")
@@ -460,48 +519,6 @@ def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
460519
assert serialized_env == expected_env
461520

462521

463-
class ReferencedClass:
464-
def __init__(self, value: int):
465-
self.value = value
466-
467-
def get_value(self) -> int:
468-
return self.value
469-
470-
471-
class ClassThatReferencesAnother:
472-
def __init__(self, x: int):
473-
self.helper = ReferencedClass(x * 2)
474-
475-
def compute(self) -> int:
476-
return self.helper.get_value() + 10
477-
478-
479-
def function_using_class_with_reference(y: int) -> int:
480-
obj = ClassThatReferencesAnother(y)
481-
return obj.compute()
482-
483-
484-
def test_serialize_env_with_class_referencing_another_class() -> None:
485-
# firstly we can confirm that func_globals picks up the reference
486-
init_globals = func_globals(ClassThatReferencesAnother.__init__)
487-
assert "ReferencedClass" in init_globals
488-
489-
path = Path("tests/utils")
490-
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
491-
492-
# build ajd serialize environment for the function that uses the class
493-
build_env(function_using_class_with_reference, env=env, name="test_func", path=path)
494-
serialized_env = serialize_env(env, path=path)
495-
496-
# both classes should be in the serialized environment
497-
assert "ClassThatReferencesAnother" in serialized_env
498-
assert "ReferencedClass" in serialized_env
499-
500-
prepared_env = prepare_env(serialized_env)
501-
result = eval("test_func(33)", prepared_env)
502-
assert result == 76
503-
504-
505522
def test_dict_sort_basic_types():
506523
"""Test dict_sort with basic Python types."""
507524
# Test basic types that should use standard repr

0 commit comments

Comments
 (0)