diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py index a729b217ee7fb..eb191761bfc3c 100644 --- a/python/taichi/lang/_ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -51,7 +51,7 @@ def gen(d, prefix): yield from gen(0, ()) def grouped(self): - return GroupedNDRange(self) + return Grouped(self) def ndrange(*args) -> Iterable: @@ -137,7 +137,7 @@ def ndrange(*args) -> Iterable: return _Ndrange(*args) -class GroupedNDRange: +class Grouped: def __init__(self, r): self.r = r diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index ba7890d9b6b79..28298236e712d 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -13,7 +13,7 @@ from taichi._lib import core as _ti_core from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh from taichi.lang import ops as ti_ops -from taichi.lang._ndrange import _Ndrange, ndrange +from taichi.lang._ndrange import _Ndrange, Grouped, ndrange from taichi.lang.argpack import ArgPackType from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus, ReturnStatus from taichi.lang.ast.symbol_resolver import ASTResolver @@ -1312,7 +1312,7 @@ def build_range_for(ctx, node): @staticmethod def build_ndrange_for(ctx, node): with ctx.variable_scope_guard(): - ndrange_var = impl.expr_init(build_stmt(ctx, node.iter)) + ndrange_var = impl.expr_init(node.iter.ptr) ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32) ndrange_end = ti_ops.cast( expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)), @@ -1355,7 +1355,7 @@ def build_ndrange_for(ctx, node): @staticmethod def build_grouped_ndrange_for(ctx, node): with ctx.variable_scope_guard(): - ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0])) + ndrange_var = impl.expr_init(node.iter.ptr.r) ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32) ndrange_end = ti_ops.cast( expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)), @@ -1400,7 +1400,7 @@ def build_struct_for(ctx, node, is_grouped): if len(targets) != 1: raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}") target = targets[0] - loop_var = build_stmt(ctx, node.iter) + loop_var = node.iter.ptr.r loop_indices = expr.make_var_list(size=len(loop_var.shape), ast_builder=ctx.ast_builder) expr_group = expr.make_expr_group(loop_indices) impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var) @@ -1479,6 +1479,7 @@ def build_nested_mesh_for(ctx, node): def build_For(ctx, node): if node.orelse: raise TaichiSyntaxError("'else' clause for 'for' not supported in Taichi kernels") + decorator = ASTTransformer.get_decorator(ctx, node.iter) double_decorator = "" if decorator != "" and len(node.iter.args) == 1: @@ -1489,37 +1490,26 @@ def build_For(ctx, node): raise TaichiSyntaxError("'ti.static' cannot be nested") with ctx.loop_scope_guard(is_static=True): return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped") + elif isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name) and node.iter.func.id == "range": + with ctx.loop_scope_guard(is_static=True): + return ASTTransformer.build_range_for(ctx, node) + with ctx.loop_scope_guard(): - if decorator == "ndrange": - if double_decorator != "": - raise TaichiSyntaxError("No decorator is allowed inside 'ti.ndrange") + iterable = build_stmt(ctx, node.iter) + if isinstance(iterable, _Ndrange): return ASTTransformer.build_ndrange_for(ctx, node) - if decorator == "grouped": - if double_decorator == "static": - raise TaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'") - elif double_decorator == "ndrange": + elif isinstance(iterable, Grouped): + if isinstance(iterable.r, _Ndrange): return ASTTransformer.build_grouped_ndrange_for(ctx, node) - elif double_decorator == "grouped": - raise TaichiSyntaxError("'ti.grouped' cannot be nested") else: return ASTTransformer.build_struct_for(ctx, node, is_grouped=True) - elif ( - isinstance(node.iter, ast.Call) - and isinstance(node.iter.func, ast.Name) - and node.iter.func.id == "range" - ): - return ASTTransformer.build_range_for(ctx, node) + elif isinstance(iterable, mesh.MeshElementField): + if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh): + raise Exception("Backend " + str(impl.default_cfg().arch) + " doesn't support MeshTaichi extension") + return ASTTransformer.build_mesh_for(ctx, node) + elif isinstance(iterable, mesh.MeshRelationAccessProxy): + return ASTTransformer.build_nested_mesh_for(ctx, node) else: - build_stmt(ctx, node.iter) - if isinstance(node.iter.ptr, mesh.MeshElementField): - if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh): - raise Exception( - "Backend " + str(impl.default_cfg().arch) + " doesn't support MeshTaichi extension" - ) - return ASTTransformer.build_mesh_for(ctx, node) - if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy): - return ASTTransformer.build_nested_mesh_for(ctx, node) - # Struct for return ASTTransformer.build_struct_for(ctx, node, is_grouped=False) @staticmethod diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 7648cb0749a3a..626ba956c9183 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -6,7 +6,7 @@ from taichi._lib import core as _ti_core from taichi._snode.fields_builder import FieldsBuilder from taichi.lang._ndarray import ScalarNdarray -from taichi.lang._ndrange import GroupedNDRange, _Ndrange +from taichi.lang._ndrange import Grouped, _Ndrange from taichi.lang._texture import RWTextureAccessor from taichi.lang.any_array import AnyArray from taichi.lang.enums import SNodeGradType @@ -1109,7 +1109,7 @@ def static(x, *xs) -> Any: list, tuple, enumerate, - GroupedNDRange, + Grouped, _Ndrange, zip, filter, @@ -1152,9 +1152,7 @@ def grouped(x): >>> print(I) prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2] """ - if isinstance(x, _Ndrange): - return x.grouped() - return x + return Grouped(x) def stop_grad(x):