diff --git a/pyproject.toml b/pyproject.toml index 07eb8d0..b778b8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,3 +61,5 @@ artifacts = [ [tool.hatch.build.targets.wheel.sources] "src/**" = "infragraph" +[project.scripts] +infragraph = "infragraph.__main__:app" diff --git a/requirements.txt b/requirements.txt index 0d4e5e5..b59900e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ mkdocs-material mkdocs-include-markdown-plugin mkdocs-print-site-plugin jupytext -nbconvert \ No newline at end of file +nbconvert +typer \ No newline at end of file diff --git a/src/infragraph/__main__.py b/src/infragraph/__main__.py index 7bda82f..144e23c 100644 --- a/src/infragraph/__main__.py +++ b/src/infragraph/__main__.py @@ -1,53 +1,56 @@ -import click +import typer from infragraph.translators.translator_handler import run_translator -from .visualizer.visualize import run_visualizer - - -@click.group() -def cli(): - pass - -@cli.command() -@click.argument("tool", type=click.Choice(["lstopo", "lspci"])) -@click.option("-i", "--input", "input_file", help="Input file path") -@click.option("-o", "--output", "output_file", default="dev.yaml", help="Output file path") -@click.option("--dump", type=click.Choice(["json", "yaml"]), default="yaml") - -def translator(tool, input_file, output_file, dump): - """Run selected translator""" +from infragraph.visualizer.visualize import run_visualizer + +app = typer.Typer() + +@app.command() +def translate( + tool = typer.Argument(..., help="Translator to use"), + input_file = typer.Option(None, "--input", "-i", help="Input file Path"), + output_file = typer.Option("dev.yaml","--output", "-o", help="Output file path"), + dump = typer.Option("yaml", "--dump", help="Dump format (json or yaml)") +): + run_translator(tool, input_file, output_file, dump) - -@cli.command() -@click.option( - "--input", "-i", - "input_path", - required=True, - type=click.Path(exists=True, dir_okay=False, readable=True, path_type=str), - help="Path to the InfraGraph infrastructure file.", -) -@click.option( - "--hosts", - default="", - help="Comma-separated instance names that are hosts (e.g., 'dgx1,dgx2'). Only used in visualizer mode.", -) -@click.option( - "--switches", - default="", - help="Comma-separated switch names (e.g., 'sw1,sw2'). Only used in visualizer mode.", -) -@click.option( - "--output", "-o", - "output_dir", - required=True, - type=click.Path(file_okay=False, writable=True, path_type=str), - help="Output directory path where results will be generated.", -) - -def visualize(input_path: str, hosts: str, switches: str, output_dir: str): - """Visualize the graph""" - run_visualizer(input_file=input_path, hosts=hosts, switches=switches, output=output_dir) - - - + +@app.command() +def visualize( + input_path: str = typer.Option( + ..., + "--input", "-i", + help="Path to the InfraGraph infrastructure yaml/json file.", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + ), + hosts: str = typer.Option( + "", + "--hosts", + help="Comma-separated instance names that are hosts (used in visualizer).", + ), + switches: str = typer.Option( + "", + "--switches", + help="Comma-separated switch names (used in visualizer).", + ), + output_dir: str = typer.Option( + ..., + "--output", "-o", + help="Output directory path where results will be generated.", + file_okay=False, + writable=True, + ), +): + """Visualize the graph""" + run_visualizer( + input_file=input_path, + hosts=hosts, + switches=switches, + output=output_dir, + ) + + if __name__ == "__main__": - cli() \ No newline at end of file + app() \ No newline at end of file diff --git a/src/infragraph/translators/lstopo_translator.py b/src/infragraph/translators/lstopo_translator.py index c54e0a5..e3b71e6 100644 --- a/src/infragraph/translators/lstopo_translator.py +++ b/src/infragraph/translators/lstopo_translator.py @@ -592,9 +592,11 @@ def run_lstopo_parser( with open(output_file, "w", encoding="utf-8") as f: f.write(serialized_data) + print("translated output file", output_file) # delete temp file if created if tmp_xml and tmp_xml.exists(): tmp_xml.unlink() print("removed /tmp/lstopo_output.xml") - return serialized_data \ No newline at end of file + return serialized_data +