From 5b527eaf072448782dc57417bfda27e2f0e91477 Mon Sep 17 00:00:00 2001 From: whitingyan <1712428442@qq.com> Date: Mon, 10 Mar 2025 02:04:37 +0800 Subject: [PATCH 1/2] init --- python/taichi/lang/_ndrange.py | 4 +- python/taichi/lang/ast/ast_transformer.py | 55 +++++++++++------------ python/taichi/lang/impl.py | 8 ++-- 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py index a729b217ee7fb..eb191761bfc3c 100644 --- a/python/taichi/lang/_ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -51,7 +51,7 @@ def gen(d, prefix): yield from gen(0, ()) def grouped(self): - return GroupedNDRange(self) + return Grouped(self) def ndrange(*args) -> Iterable: @@ -137,7 +137,7 @@ def ndrange(*args) -> Iterable: return _Ndrange(*args) -class GroupedNDRange: +class Grouped: def __init__(self, r): self.r = r diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index ba7890d9b6b79..22755dec9376d 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -13,7 +13,7 @@ from taichi._lib import core as _ti_core from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh from taichi.lang import ops as ti_ops -from taichi.lang._ndrange import _Ndrange, ndrange +from taichi.lang._ndrange import _Ndrange, Grouped, ndrange from taichi.lang.argpack import ArgPackType from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus, ReturnStatus from taichi.lang.ast.symbol_resolver import ASTResolver @@ -1312,7 +1312,7 @@ def build_range_for(ctx, node): @staticmethod def build_ndrange_for(ctx, node): with ctx.variable_scope_guard(): - ndrange_var = impl.expr_init(build_stmt(ctx, node.iter)) + ndrange_var = impl.expr_init(node.iter.ptr) ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32) ndrange_end = ti_ops.cast( expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)), @@ -1355,7 +1355,7 @@ def build_ndrange_for(ctx, node): @staticmethod def build_grouped_ndrange_for(ctx, node): with ctx.variable_scope_guard(): - ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0])) + ndrange_var = impl.expr_init(node.iter.ptr.r) ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32) ndrange_end = ti_ops.cast( expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)), @@ -1400,7 +1400,7 @@ def build_struct_for(ctx, node, is_grouped): if len(targets) != 1: raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}") target = targets[0] - loop_var = build_stmt(ctx, node.iter) + loop_var = node.iter.ptr.r loop_indices = expr.make_var_list(size=len(loop_var.shape), ast_builder=ctx.ast_builder) expr_group = expr.make_expr_group(loop_indices) impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var) @@ -1479,6 +1479,7 @@ def build_nested_mesh_for(ctx, node): def build_For(ctx, node): if node.orelse: raise TaichiSyntaxError("'else' clause for 'for' not supported in Taichi kernels") + decorator = ASTTransformer.get_decorator(ctx, node.iter) double_decorator = "" if decorator != "" and len(node.iter.args) == 1: @@ -1489,37 +1490,33 @@ def build_For(ctx, node): raise TaichiSyntaxError("'ti.static' cannot be nested") with ctx.loop_scope_guard(is_static=True): return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped") + elif ( + isinstance(node.iter, ast.Call) + and isinstance(node.iter.func, ast.Name) + and node.iter.func.id == "range" + ): + with ctx.loop_scope_guard(is_static=True): + return ASTTransformer.build_range_for(ctx, node) + with ctx.loop_scope_guard(): - if decorator == "ndrange": - if double_decorator != "": - raise TaichiSyntaxError("No decorator is allowed inside 'ti.ndrange") + iterable = build_stmt(ctx, node.iter) + if isinstance(iterable, _Ndrange): return ASTTransformer.build_ndrange_for(ctx, node) - if decorator == "grouped": - if double_decorator == "static": - raise TaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'") - elif double_decorator == "ndrange": + elif isinstance(iterable, Grouped): + if isinstance(iterable.r, _Ndrange): return ASTTransformer.build_grouped_ndrange_for(ctx, node) - elif double_decorator == "grouped": - raise TaichiSyntaxError("'ti.grouped' cannot be nested") else: return ASTTransformer.build_struct_for(ctx, node, is_grouped=True) - elif ( - isinstance(node.iter, ast.Call) - and isinstance(node.iter.func, ast.Name) - and node.iter.func.id == "range" - ): - return ASTTransformer.build_range_for(ctx, node) + elif isinstance(iterable, mesh.MeshElementField): + if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh): + raise Exception( + "Backend " + str(impl.default_cfg().arch) + + " doesn't support MeshTaichi extension" + ) + return ASTTransformer.build_mesh_for(ctx, node) + elif isinstance(iterable, mesh.MeshRelationAccessProxy): + return ASTTransformer.build_nested_mesh_for(ctx, node) else: - build_stmt(ctx, node.iter) - if isinstance(node.iter.ptr, mesh.MeshElementField): - if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh): - raise Exception( - "Backend " + str(impl.default_cfg().arch) + " doesn't support MeshTaichi extension" - ) - return ASTTransformer.build_mesh_for(ctx, node) - if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy): - return ASTTransformer.build_nested_mesh_for(ctx, node) - # Struct for return ASTTransformer.build_struct_for(ctx, node, is_grouped=False) @staticmethod diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 7648cb0749a3a..626ba956c9183 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -6,7 +6,7 @@ from taichi._lib import core as _ti_core from taichi._snode.fields_builder import FieldsBuilder from taichi.lang._ndarray import ScalarNdarray -from taichi.lang._ndrange import GroupedNDRange, _Ndrange +from taichi.lang._ndrange import Grouped, _Ndrange from taichi.lang._texture import RWTextureAccessor from taichi.lang.any_array import AnyArray from taichi.lang.enums import SNodeGradType @@ -1109,7 +1109,7 @@ def static(x, *xs) -> Any: list, tuple, enumerate, - GroupedNDRange, + Grouped, _Ndrange, zip, filter, @@ -1152,9 +1152,7 @@ def grouped(x): >>> print(I) prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2] """ - if isinstance(x, _Ndrange): - return x.grouped() - return x + return Grouped(x) def stop_grad(x): From 351c23fd29aec7a30b7a2ad1e83458feb365c8f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 9 Mar 2025 18:28:16 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 22755dec9376d..28298236e712d 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -1490,11 +1490,7 @@ def build_For(ctx, node): raise TaichiSyntaxError("'ti.static' cannot be nested") with ctx.loop_scope_guard(is_static=True): return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped") - elif ( - isinstance(node.iter, ast.Call) - and isinstance(node.iter.func, ast.Name) - and node.iter.func.id == "range" - ): + elif isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name) and node.iter.func.id == "range": with ctx.loop_scope_guard(is_static=True): return ASTTransformer.build_range_for(ctx, node) @@ -1509,10 +1505,7 @@ def build_For(ctx, node): return ASTTransformer.build_struct_for(ctx, node, is_grouped=True) elif isinstance(iterable, mesh.MeshElementField): if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh): - raise Exception( - "Backend " + str(impl.default_cfg().arch) + - " doesn't support MeshTaichi extension" - ) + raise Exception("Backend " + str(impl.default_cfg().arch) + " doesn't support MeshTaichi extension") return ASTTransformer.build_mesh_for(ctx, node) elif isinstance(iterable, mesh.MeshRelationAccessProxy): return ASTTransformer.build_nested_mesh_for(ctx, node)