diff --git a/core.py b/core.py index c9c3d3f..cb47362 100644 --- a/core.py +++ b/core.py @@ -1,6 +1,7 @@ import tensorflow.compat.v1 as tf import collections import time +import itertools from typing import Dict, List, Set, Optional, Any as AnyType, Tuple, Union from .utils.logger import ( logger as logging, @@ -582,7 +583,9 @@ def match_once( if node.name in replaced_node_names: continue - candidates = self.pattern_index.get(node.op, []) + self.wildcard_patterns + candidates = itertools.chain( + self.pattern_index.get(node.op, []), self.wildcard_patterns + ) found_match = False for pattern, rewriter in candidates: