diff --git a/poetry.lock b/poetry.lock index 43b834b3..b69674ff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,16 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand. + +[[package]] +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] [[package]] name = "black" @@ -136,19 +148,6 @@ files = [ {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, ] -[[package]] -name = "legacy-cgi" -version = "2.6.4" -description = "Fork of the standard library cgi and cgitb modules removed in Python 3.13" -optional = false -python-versions = ">=3.8" -groups = ["main"] -markers = "python_version >= \"3.13\"" -files = [ - {file = "legacy_cgi-2.6.4-py3-none-any.whl", hash = "sha256:7e235ce58bf1e25d1fc9b2d299015e4e2cd37305eccafec1e6bac3fc04b878cd"}, - {file = "legacy_cgi-2.6.4.tar.gz", hash = "sha256:abb9dfc7835772f7c9317977c63253fd22a7484b5c9bbcdca60a29dcce97c577"}, -] - [[package]] name = "mako" version = "1.3.10" @@ -373,16 +372,161 @@ dev = ["pre-commit", "tox"] testing = ["coverage", "pytest", "pytest-benchmark"] [[package]] -name = "pydal" -version = "20200714.1" -description = "a pure Python Database Abstraction Layer (for python version 2.7 and 3.x)" +name = "pydantic" +version = "2.12.5" +description = "Data validation using Python type hints" optional = false -python-versions = "*" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "pydal-20200714.1.tar.gz", hash = 
"sha256:dd35b8ecb009099cce7efa72a40707d2e9bdcdf85924f30683a52d5172d1242f"}, + {file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"}, + {file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"}, ] +[package.dependencies] +annotated-types = ">=0.6.0" +pydantic-core = "2.41.5" +typing-extensions = ">=4.14.1" +typing-inspection = ">=0.4.2" + +[package.extras] +email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] + +[[package]] +name = "pydantic-core" +version = "2.41.5" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pydantic_core-2.41.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:77b63866ca88d804225eaa4af3e664c5faf3568cea95360d21f4725ab6e07146"}, + {file = "pydantic_core-2.41.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dfa8a0c812ac681395907e71e1274819dec685fec28273a28905df579ef137e2"}, + {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5921a4d3ca3aee735d9fd163808f5e8dd6c6972101e4adbda9a4667908849b97"}, + {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25c479382d26a2a41b7ebea1043564a937db462816ea07afa8a44c0866d52f9"}, + {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f547144f2966e1e16ae626d8ce72b4cfa0caedc7fa28052001c94fb2fcaa1c52"}, + {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f52298fbd394f9ed112d56f3d11aabd0d5bd27beb3084cc3d8ad069483b8941"}, + {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:100baa204bb412b74fe285fb0f3a385256dad1d1879f0a5cb1499ed2e83d132a"}, + {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:05a2c8852530ad2812cb7914dc61a1125dc4e06252ee98e5638a12da6cc6fb6c"}, + {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:29452c56df2ed968d18d7e21f4ab0ac55e71dc59524872f6fc57dcf4a3249ed2"}, + {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:d5160812ea7a8a2ffbe233d8da666880cad0cbaf5d4de74ae15c313213d62556"}, + {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:df3959765b553b9440adfd3c795617c352154e497a4eaf3752555cfb5da8fc49"}, + {file = "pydantic_core-2.41.5-cp310-cp310-win32.whl", hash = "sha256:1f8d33a7f4d5a7889e60dc39856d76d09333d8a6ed0f5f1190635cbec70ec4ba"}, + {file = "pydantic_core-2.41.5-cp310-cp310-win_amd64.whl", hash = "sha256:62de39db01b8d593e45871af2af9e497295db8d73b085f6bfd0b18c83c70a8f9"}, + {file = "pydantic_core-2.41.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a3a52f6156e73e7ccb0f8cced536adccb7042be67cb45f9562e12b319c119da6"}, + {file = "pydantic_core-2.41.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7f3bf998340c6d4b0c9a2f02d6a400e51f123b59565d74dc60d252ce888c260b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:378bec5c66998815d224c9ca994f1e14c0c21cb95d2f52b6021cc0b2a58f2a5a"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7b576130c69225432866fe2f4a469a85a54ade141d96fd396dffcf607b558f8"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cb58b9c66f7e4179a2d5e0f849c48eff5c1fca560994d6eb6543abf955a149e"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88942d3a3dff3afc8288c21e565e476fc278902ae4d6d134f1eeda118cc830b1"}, 
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f31d95a179f8d64d90f6831d71fa93290893a33148d890ba15de25642c5d075b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1df3d34aced70add6f867a8cf413e299177e0c22660cc767218373d0779487b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4009935984bd36bd2c774e13f9a09563ce8de4abaa7226f5108262fa3e637284"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:34a64bc3441dc1213096a20fe27e8e128bd3ff89921706e83c0b1ac971276594"}, + {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c9e19dd6e28fdcaa5a1de679aec4141f691023916427ef9bae8584f9c2fb3b0e"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win32.whl", hash = "sha256:2c010c6ded393148374c0f6f0bf89d206bf3217f201faa0635dcd56bd1520f6b"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win_amd64.whl", hash = "sha256:76ee27c6e9c7f16f47db7a94157112a2f3a00e958bc626e2f4ee8bec5c328fbe"}, + {file = "pydantic_core-2.41.5-cp311-cp311-win_arm64.whl", hash = "sha256:4bc36bbc0b7584de96561184ad7f012478987882ebf9f9c389b23f432ea3d90f"}, + {file = "pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7"}, + {file = "pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c"}, + {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294"}, + {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815"}, + {file = "pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3"}, + {file = "pydantic_core-2.41.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:941103c9be18ac8daf7b7adca8228f8ed6bb7a1849020f643b3a14d15b1924d9"}, + {file = "pydantic_core-2.41.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:112e305c3314f40c93998e567879e887a3160bb8689ef3d2c04b6cc62c33ac34"}, + {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbaad15cb0c90aa221d43c00e77bb33c93e8d36e0bf74760cd00e732d10a6a0"}, + {file = 
"pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:03ca43e12fab6023fc79d28ca6b39b05f794ad08ec2feccc59a339b02f2b3d33"}, + {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc799088c08fa04e43144b164feb0c13f9a0bc40503f8df3e9fde58a3c0c101e"}, + {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97aeba56665b4c3235a0e52b2c2f5ae9cd071b8a8310ad27bddb3f7fb30e9aa2"}, + {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:406bf18d345822d6c21366031003612b9c77b3e29ffdb0f612367352aab7d586"}, + {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b93590ae81f7010dbe380cdeab6f515902ebcbefe0b9327cc4804d74e93ae69d"}, + {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:01a3d0ab748ee531f4ea6c3e48ad9dac84ddba4b0d82291f87248f2f9de8d740"}, + {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:6561e94ba9dacc9c61bce40e2d6bdc3bfaa0259d3ff36ace3b1e6901936d2e3e"}, + {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:915c3d10f81bec3a74fbd4faebe8391013ba61e5a1a8d48c4455b923bdda7858"}, + {file = "pydantic_core-2.41.5-cp313-cp313-win32.whl", hash = "sha256:650ae77860b45cfa6e2cdafc42618ceafab3a2d9a3811fcfbd3bbf8ac3c40d36"}, + {file = "pydantic_core-2.41.5-cp313-cp313-win_amd64.whl", hash = "sha256:79ec52ec461e99e13791ec6508c722742ad745571f234ea6255bed38c6480f11"}, + {file = "pydantic_core-2.41.5-cp313-cp313-win_arm64.whl", hash = "sha256:3f84d5c1b4ab906093bdc1ff10484838aca54ef08de4afa9de0f5f14d69639cd"}, + {file = "pydantic_core-2.41.5-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:3f37a19d7ebcdd20b96485056ba9e8b304e27d9904d233d7b1015db320e51f0a"}, + {file = "pydantic_core-2.41.5-cp314-cp314-macosx_11_0_arm64.whl", hash = 
"sha256:1d1d9764366c73f996edd17abb6d9d7649a7eb690006ab6adbda117717099b14"}, + {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e1c2af0fce638d5f1988b686f3b3ea8cd7de5f244ca147c777769e798a9cd1"}, + {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:506d766a8727beef16b7adaeb8ee6217c64fc813646b424d0804d67c16eddb66"}, + {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4819fa52133c9aa3c387b3328f25c1facc356491e6135b459f1de698ff64d869"}, + {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b761d210c9ea91feda40d25b4efe82a1707da2ef62901466a42492c028553a2"}, + {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22f0fb8c1c583a3b6f24df2470833b40207e907b90c928cc8d3594b76f874375"}, + {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2782c870e99878c634505236d81e5443092fba820f0373997ff75f90f68cd553"}, + {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:0177272f88ab8312479336e1d777f6b124537d47f2123f89cb37e0accea97f90"}, + {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:63510af5e38f8955b8ee5687740d6ebf7c2a0886d15a6d65c32814613681bc07"}, + {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:e56ba91f47764cc14f1daacd723e3e82d1a89d783f0f5afe9c364b8bb491ccdb"}, + {file = "pydantic_core-2.41.5-cp314-cp314-win32.whl", hash = "sha256:aec5cf2fd867b4ff45b9959f8b20ea3993fc93e63c7363fe6851424c8a7e7c23"}, + {file = "pydantic_core-2.41.5-cp314-cp314-win_amd64.whl", hash = "sha256:8e7c86f27c585ef37c35e56a96363ab8de4e549a95512445b85c96d3e2f7c1bf"}, + {file = "pydantic_core-2.41.5-cp314-cp314-win_arm64.whl", hash = 
"sha256:e672ba74fbc2dc8eea59fb6d4aed6845e6905fc2a8afe93175d94a83ba2a01a0"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:8566def80554c3faa0e65ac30ab0932b9e3a5cd7f8323764303d468e5c37595a"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b80aa5095cd3109962a298ce14110ae16b8c1aece8b72f9dafe81cf597ad80b3"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3006c3dd9ba34b0c094c544c6006cc79e87d8612999f1a5d43b769b89181f23c"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:72f6c8b11857a856bcfa48c86f5368439f74453563f951e473514579d44aa612"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cb1b2f9742240e4bb26b652a5aeb840aa4b417c7748b6f8387927bc6e45e40d"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd3d54f38609ff308209bd43acea66061494157703364ae40c951f83ba99a1a9"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff4321e56e879ee8d2a879501c8e469414d948f4aba74a2d4593184eb326660"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0d2568a8c11bf8225044aa94409e21da0cb09dcdafe9ecd10250b2baad531a9"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:a39455728aabd58ceabb03c90e12f71fd30fa69615760a075b9fec596456ccc3"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:239edca560d05757817c13dc17c50766136d21f7cd0fac50295499ae24f90fdf"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = 
"sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c"}, + {file = "pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008"}, + {file = "pydantic_core-2.41.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:8bfeaf8735be79f225f3fefab7f941c712aaca36f1128c9d7e2352ee1aa87bdf"}, + {file = "pydantic_core-2.41.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:346285d28e4c8017da95144c7f3acd42740d637ff41946af5ce6e5e420502dd5"}, + {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a75dafbf87d6276ddc5b2bf6fae5254e3d0876b626eb24969a574fff9149ee5d"}, + {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7b93a4d08587e2b7e7882de461e82b6ed76d9026ce91ca7915e740ecc7855f60"}, + {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8465ab91a4bd96d36dde3263f06caa6a8a6019e4113f24dc753d79a8b3a3f82"}, + {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:299e0a22e7ae2b85c1a57f104538b2656e8ab1873511fd718a1c1c6f149b77b5"}, + {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:707625ef0983fcfb461acfaf14de2067c5942c6bb0f3b4c99158bed6fedd3cf3"}, + {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f41eb9797986d6ebac5e8edff36d5cef9de40def462311b3eb3eeded1431e425"}, + {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0384e2e1021894b1ff5a786dbf94771e2986ebe2869533874d7e43bc79c6f504"}, + {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_armv7l.whl", hash = 
"sha256:f0cd744688278965817fd0839c4a4116add48d23890d468bc436f78beb28abf5"}, + {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:753e230374206729bf0a807954bcc6c150d3743928a73faffee51ac6557a03c3"}, + {file = "pydantic_core-2.41.5-cp39-cp39-win32.whl", hash = "sha256:873e0d5b4fb9b89ef7c2d2a963ea7d02879d9da0da8d9d4933dee8ee86a8b460"}, + {file = "pydantic_core-2.41.5-cp39-cp39-win_amd64.whl", hash = "sha256:e4f4a984405e91527a0d62649ee21138f8e3d0ef103be488c1dc11a80d7f184b"}, + {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:b96d5f26b05d03cc60f11a7761a5ded1741da411e7fe0909e27a5e6a0cb7b034"}, + {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:634e8609e89ceecea15e2d61bc9ac3718caaaa71963717bf3c8f38bfde64242c"}, + {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93e8740d7503eb008aa2df04d3b9735f845d43ae845e6dcd2be0b55a2da43cd2"}, + {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f15489ba13d61f670dcc96772e733aad1a6f9c429cc27574c6cdaed82d0146ad"}, + {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:7da7087d756b19037bc2c06edc6c170eeef3c3bafcb8f532ff17d64dc427adfd"}, + {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:aabf5777b5c8ca26f7824cb4a120a740c9588ed58df9b2d196ce92fba42ff8dc"}, + {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56"}, + {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b"}, + {file = 
"pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b5819cd790dbf0c5eb9f82c73c16b39a65dd6dd4d1439dcdea7816ec9adddab8"}, + {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5a4e67afbc95fa5c34cf27d9089bca7fcab4e51e57278d710320a70b956d1b9a"}, + {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ece5c59f0ce7d001e017643d8d24da587ea1f74f6993467d85ae8a5ef9d4f42b"}, + {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:16f80f7abe3351f8ea6858914ddc8c77e02578544a0ebc15b4c2e1a0e813b0b2"}, + {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:33cb885e759a705b426baada1fe68cbb0a2e68e34c5d0d0289a364cf01709093"}, + {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:c8d8b4eb992936023be7dee581270af5c6e0697a8559895f527f5b7105ecd36a"}, + {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:242a206cd0318f95cd21bdacff3fcc3aab23e79bba5cac3db5a841c9ef9c6963"}, + {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d3a978c4f57a597908b7e697229d996d77a6d3c94901e9edee593adada95ce1a"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2379fa7ed44ddecb5bfe4e48577d752db9fc10be00a6b7446e9663ba143de26"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:266fb4cbf5e3cbd0b53669a6d1b039c45e3ce651fd5442eff4d07c2cc8d66808"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58133647260ea01e4d0500089a8c4f07bd7aa6ce109682b1426394988d8aaacc"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:287dad91cfb551c363dc62899a80e9e14da1f0e2b6ebde82c806612ca2a13ef1"}, + {file = 
"pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:03b77d184b9eb40240ae9fd676ca364ce1085f203e1b1256f8ab9984dca80a84"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:a668ce24de96165bb239160b3d854943128f4334822900534f2fe947930e5770"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f14f8f046c14563f8eb3f45f499cc658ab8d10072961e07225e507adb700e93f"}, + {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51"}, + {file = "pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e"}, +] + +[package.dependencies] +typing-extensions = ">=4.14.1" + [[package]] name = "pygments" version = "2.19.2" @@ -496,12 +640,27 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["dev"] -markers = "python_version < \"3.11\"" +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, ] +markers = {dev = "python_version < \"3.11\""} + +[[package]] +name = "typing-inspection" +version = "0.4.2" +description = "Runtime typing introspection tools" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"}, + {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"}, +] + +[package.dependencies] +typing-extensions = ">=4.12.0" [[package]] name = "zipp" @@ -526,5 +685,5 @@ type = 
["pytest-mypy"] [metadata] lock-version = "2.1" -python-versions = "^3.9 || ^3.10 || ^3.11" -content-hash = "d13ccd9a0de456c987bd6c6f20034c2f2a71279f65ddd4b2a0d597ef5ca5fd86" +python-versions = "^3.9 || ^3.10 || ^3.11 || ^3.12 || ^3.13 || ^3.14" +content-hash = "f29cad7d87838e9587a1f1747bd5cae2aba51f0c0abac77a653f0ed1b93b0b8b" diff --git a/pyproject.toml b/pyproject.toml index 483f31c9..1fd3a275 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,8 @@ authors = ["pytm Team"] license = "MIT License" [tool.poetry.dependencies] -python = "^3.9 || ^3.10 || ^3.11" -pydal = "~20200714.1" -legacy-cgi = { version = "^2.0", markers = "python_version >= '3.13'" } +python = "^3.9 || ^3.10 || ^3.11 || ^3.12 || ^3.13 || ^3.14" +pydantic = "^2.10.0" [tool.poetry.group.dev.dependencies] pytest = "^8.3.5" diff --git a/pytm/__init__.py b/pytm/__init__.py index eb15b7fa..c8279547 100644 --- a/pytm/__init__.py +++ b/pytm/__init__.py @@ -22,35 +22,46 @@ "SetOfProcesses", "Threat", "TM", + "Controls", + "var", ] import sys from .json import load, loads -from .pytm import ( - TM, - Action, - Actor, - Assumption, - Boundary, - Classification, - Data, - Dataflow, - Datastore, - DatastoreType, - Element, - ExternalEntity, - Finding, - Lambda, - LLM, - Lifetime, - Process, - Server, - SetOfProcesses, - Threat, - TLSVersion, - var, -) +from .pytm import var + +# Import from new Pydantic models +from .enums import Action, Classification, DatastoreType, Lifetime, TLSVersion +from .base import Assumption, Controls +from .element import Element +from .data import Data +from .threat import Threat +from .finding import Finding +from .asset import Asset, Lambda, LLM, Server, ExternalEntity +from .datastore import Datastore +from .actor import Actor +from .process import Process, SetOfProcesses +from .dataflow import Dataflow +from .boundary import Boundary +from .tm import TM + +# Rebuild models to resolve forward references +Element.model_rebuild() +Data.model_rebuild() 
+Finding.model_rebuild() +Asset.model_rebuild() +Lambda.model_rebuild() +LLM.model_rebuild() +Server.model_rebuild() +ExternalEntity.model_rebuild() +Datastore.model_rebuild() +Actor.model_rebuild() +Process.model_rebuild() +SetOfProcesses.model_rebuild() +Dataflow.model_rebuild() +Boundary.model_rebuild() +TM.model_rebuild() def pdoc_overrides(): @@ -62,9 +73,11 @@ def pdoc_overrides(): for i in dir(klass): if i in ("check", "dfd", "seq"): result[f"{name}.{i}"] = False - attr = getattr(klass, i, {}) - if isinstance(attr, var) and attr.doc != "": - result[f"{name}.{i}"] = attr.doc + model_fields = getattr(klass, "model_fields", {}) + if i in model_fields: + description = model_fields[i].description + if description: + result[f"{name}.{i}"] = description return result diff --git a/pytm/actor.py b/pytm/actor.py new file mode 100644 index 00000000..c46c72e5 --- /dev/null +++ b/pytm/actor.py @@ -0,0 +1,102 @@ +"""Actor model - represents entities that initiate actions.""" + +from typing import TYPE_CHECKING, List +from pydantic import Field, field_validator + +from .element import Element +from .base import DataSet + +if TYPE_CHECKING: + from .dataflow import Dataflow + + +class Actor(Element): + """An entity usually initiating actions. + + Actors represent users or external systems that initiate + interactions with the system being modeled. 
+ + Attributes: + port (int): Default TCP port for outgoing data flows + protocol (str): Default network protocol for outgoing data flows + data (DataSet): pytm.Data object(s) in outgoing data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + isAdmin (bool): Indicates whether the actor has administrative privileges + """ + + port: int = Field( + default=-1, description="Default TCP port for outgoing data flows" + ) + protocol: str = Field( + default="", description="Default network protocol for outgoing data flows" + ) + data: DataSet = Field( + default_factory=DataSet, + description="pytm.Data object(s) in outgoing data flows", + ) + inputs: List["Dataflow"] = Field( + default_factory=list, description="Incoming Dataflows" + ) + outputs: List["Dataflow"] = Field( + default_factory=list, description="Outgoing Dataflows" + ) + isAdmin: bool = Field( + default=False, + description="Indicates whether the actor has administrative privileges", + ) + + @field_validator("data", mode="before") + @classmethod + def _coerce_dataset(cls, v): + """Ensure actor data is stored as a DataSet.""" + from .data import Data # Local import to avoid circular dependency + + if isinstance(v, DataSet): + return v + + dataset = DataSet() + + if v is None: + return dataset + + if isinstance(v, Data): + dataset.add(v) + return dataset + + if hasattr(v, "__iter__") and not isinstance(v, (str, bytes)): + for item in v: + if item is None: + continue + dataset.add(item) + return dataset + + dataset.add(v) + return dataset + + def __init__(self, name: str = None, **data): + """ + Initialize an Actor. + + Args: + name (str): Name of the actor. 
+ **data: Optional actor properties: + - port (int): Default TCP port for outgoing data flows + - protocol (str): Default network protocol for outgoing data flows + - data (DataSet): pytm.Data object(s) in outgoing data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - isAdmin (bool): Indicates whether the actor has administrative privileges + """ + super().__init__(name, **data) + # Register with TM actors + self._register_with_tm_actors() + + def _register_with_tm_actors(self): + """Register this actor with the TM class.""" + try: + from .tm import TM + + TM._actors.append(self) + except ImportError: + pass diff --git a/pytm/asset.py b/pytm/asset.py new file mode 100644 index 00000000..06089d16 --- /dev/null +++ b/pytm/asset.py @@ -0,0 +1,382 @@ +"""Asset models - base Asset class and specific asset implementations.""" + +from typing import List, TYPE_CHECKING + +from pydantic import Field, field_validator + +from .element import Element, sev_to_color +from .base import DataSet + +if TYPE_CHECKING: + from .dataflow import Dataflow + + +class Asset(Element): + """An asset with outgoing or incoming dataflows. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this asset on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? 
+ OS (str): Operating system + """ + + port: int = Field( + default=-1, description="Default TCP port for incoming data flows" + ) + protocol: str = Field( + default="", description="Default network protocol for incoming data flows" + ) + data: DataSet = Field( + default_factory=DataSet, + description="pytm.Data object(s) in incoming data flows", + ) + inputs: List["Dataflow"] = Field( + default_factory=list, description="incoming Dataflows" + ) + outputs: List["Dataflow"] = Field( + default_factory=list, description="outgoing Dataflows" + ) + onAWS: bool = Field(default=False, description="Is this asset on AWS?") + handlesResources: bool = Field( + default=False, description="Does this asset handle resources?" + ) + usesEnvironmentVariables: bool = Field( + default=False, description="Does this asset use environment variables?" + ) + OS: str = Field(default="", description="Operating system") + + @field_validator("data", mode="before") + @classmethod + def validate_data(cls, v): + """Coerce incoming values to a DataSet.""" + from .data import Data + + if isinstance(v, DataSet): + return v + + dataset = DataSet() + + if v is None: + return dataset + + if isinstance(v, Data): + dataset.add(v) + return dataset + + if hasattr(v, "__iter__") and not isinstance(v, (str, bytes)): + for item in v: + if item is None: + continue + dataset.add(item) + return dataset + + dataset.add(v) + return dataset + + def __init__(self, name: str = None, **data): + """Initialize an Asset. + + Args: + name (str): Name of the asset. + **data: Optional asset properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this asset on AWS? + - handlesResources (bool): Does this asset handle resources? 
+ - usesEnvironmentVariables (bool): Does this asset use environment variables? + - OS (str): Operating system + """ + super().__init__(name, **data) + # Register with TM assets + self._register_with_tm_assets() + + def _register_with_tm_assets(self): + """Register this asset with the TM class.""" + try: + from .tm import TM + + TM._assets.append(self) + except ImportError: + pass + + +class Lambda(Asset): + """A lambda function running in a Function-as-a-Service (FaaS) environment. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this lambda on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? + OS (str): Operating system + environment (str): Environment for the lambda + implementsAPI (bool): Does this lambda implement an API? + """ + + onAWS: bool = Field(default=True, description="Is this lambda on AWS?") + environment: str = Field(default="", description="Environment for the lambda") + implementsAPI: bool = Field( + default=False, description="Does this lambda implement an API?" + ) + + def __init__(self, name: str = None, **data): + """Initialize a Lambda. + + Args: + name (str): Name of the lambda. + **data: Optional lambda properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this lambda on AWS? + - handlesResources (bool): Does this asset handle resources? + - usesEnvironmentVariables (bool): Does this asset use environment variables? 
+ - OS (str): Operating system + - environment (str): Environment for the lambda + - implementsAPI (bool): Does this lambda implement an API? + """ + super().__init__(name, **data) + + def _dfd_template(self) -> str: + """Template for DFD representation.""" + return """{uniq_name} [ + shape = {shape}; + + color = {color}; + fontcolor = "black"; + label = < + + +
{label}
+ >; +] +""" + + def dfd(self, **kwargs) -> str: + """Generate DFD representation of this element.""" + self.is_drawn = True + + levels = kwargs.get("levels", None) + if levels and not levels & self.levels: + return "" + + color = self._color() + + if kwargs.get("colormap", False): + color = sev_to_color(self.severity) + + return self._dfd_template().format( + uniq_name=self._uniq_name(), + label=self._label(), + color=color, + shape=self._shape(), + ) + + def _shape(self) -> str: + """Get shape for DFD representation.""" + return "rectangle; style=rounded" + + +class Server(Asset): + """An entity processing data. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this asset on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? + OS (str): Operating system + usesSessionTokens (bool): Does this server use session tokens? + usesCache (bool): Does this server use cache? + usesVPN (bool): Does this server use VPN? + usesXMLParser (bool): Does this server use XML parser? + """ + + usesSessionTokens: bool = Field( + default=False, description="Does this server use session tokens?" + ) + usesCache: bool = Field(default=False, description="Does this server use cache?") + usesVPN: bool = Field(default=False, description="Does this server use VPN?") + usesXMLParser: bool = Field( + default=False, description="Does this server use XML parser?" + ) + + def __init__(self, name: str = None, **data): + """Initialize a Server. + + Args: + name (str): Name of the server. 
+ **data: Optional server properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this asset on AWS? + - handlesResources (bool): Does this asset handle resources? + - usesEnvironmentVariables (bool): Does this asset use environment variables? + - OS (str): Operating system + - usesSessionTokens (bool): Does this server use session tokens? + - usesCache (bool): Does this server use cache? + - usesVPN (bool): Does this server use VPN? + - usesXMLParser (bool): Does this server use XML parser? + """ + super().__init__(name, **data) + + def _shape(self) -> str: + """Get shape for DFD representation.""" + return "circle" + + +class ExternalEntity(Asset): + """An external entity that interacts with the system. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this asset on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? + OS (str): Operating system + hasPhysicalAccess (bool): Does this external entity have physical access? + """ + + hasPhysicalAccess: bool = Field( + default=False, description="Does this external entity have physical access?" + ) + + def __init__(self, name: str = None, **data): + """Initialize an ExternalEntity. + + Args: + name (str): Name of the external entity. 
+ **data: Optional external entity properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this asset on AWS? + - handlesResources (bool): Does this asset handle resources? + - usesEnvironmentVariables (bool): Does this asset use environment variables? + - OS (str): Operating system + - hasPhysicalAccess (bool): Does this external entity have physical access? + """ + super().__init__(name, **data) + + +class LLM(Asset): + """A Large Language Model element, either third-party or self-hosted. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this asset on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? + OS (str): Operating system + isThirdParty (bool): Is this LLM a third-party service? + isSelfHosted (bool): Is this LLM self-hosted? + processesPersonalData (bool): Does this LLM process personal data? + retainsUserData (bool): Does this LLM retain user data? + hasAgentCapabilities (bool): Does this LLM have agent capabilities? + hasAccessToSensitiveSystems (bool): Does this LLM have access to sensitive systems? + executesCode (bool): Does this LLM execute code? + hasContentFiltering (bool): Does this LLM have content filtering? + hasSystemPrompt (bool): Does this LLM have a system prompt? + processesUntrustedInput (bool): Does this LLM process untrusted input? + hasRAG (bool): Does this LLM use retrieval-augmented generation? 
+ hasFineTuning (bool): Has this LLM been fine-tuned? + """ + + isThirdParty: bool = Field( + default=True, description="Is this LLM a third-party service?" + ) + isSelfHosted: bool = Field(default=False, description="Is this LLM self-hosted?") + processesPersonalData: bool = Field( + default=False, description="Does this LLM process personal data?" + ) + retainsUserData: bool = Field( + default=False, description="Does this LLM retain user data?" + ) + hasAgentCapabilities: bool = Field( + default=False, description="Does this LLM have agent capabilities?" + ) + hasAccessToSensitiveSystems: bool = Field( + default=False, description="Does this LLM have access to sensitive systems?" + ) + executesCode: bool = Field(default=False, description="Does this LLM execute code?") + hasContentFiltering: bool = Field( + default=False, description="Does this LLM have content filtering?" + ) + hasSystemPrompt: bool = Field( + default=True, description="Does this LLM have a system prompt?" + ) + processesUntrustedInput: bool = Field( + default=True, description="Does this LLM process untrusted input?" + ) + hasRAG: bool = Field( + default=False, description="Does this LLM use retrieval-augmented generation?" + ) + hasFineTuning: bool = Field( + default=False, description="Has this LLM been fine-tuned?" + ) + + def __init__(self, name: str = None, **data): + """Initialize an LLM. + + Args: + name (str): Name of the LLM. + **data: Optional LLM properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this asset on AWS? + - handlesResources (bool): Does this asset handle resources? + - usesEnvironmentVariables (bool): Does this asset use environment variables? 
+ - OS (str): Operating system + - isThirdParty (bool): Is this LLM a third-party service? + - isSelfHosted (bool): Is this LLM self-hosted? + - processesPersonalData (bool): Does this LLM process personal data? + - retainsUserData (bool): Does this LLM retain user data? + - hasAgentCapabilities (bool): Does this LLM have agent capabilities? + - hasAccessToSensitiveSystems (bool): Does this LLM have access to sensitive systems? + - executesCode (bool): Does this LLM execute code? + - hasContentFiltering (bool): Does this LLM have content filtering? + - hasSystemPrompt (bool): Does this LLM have a system prompt? + - processesUntrustedInput (bool): Does this LLM process untrusted input? + - hasRAG (bool): Does this LLM use retrieval-augmented generation? + - hasFineTuning (bool): Has this LLM been fine-tuned? + """ + super().__init__(name, **data) + + def _shape(self) -> str: + """Get shape for DFD representation.""" + return "hexagon" diff --git a/pytm/base.py b/pytm/base.py new file mode 100644 index 00000000..d72214d7 --- /dev/null +++ b/pytm/base.py @@ -0,0 +1,216 @@ +"""Base models and utilities for pytm Pydantic models.""" + +from __future__ import annotations + +from typing import Any, Iterable, List, Set, Union, TYPE_CHECKING + +from pydantic import BaseModel, ConfigDict, Field + +if TYPE_CHECKING: + from .element import Element + from .data import Data + from .threat import Threat + from .finding import Finding + + +class DataSet(set): + """Custom set for Data objects with string lookup capability.""" + + __slots__ = ("_names",) + + def __init__(self, values: Iterable["Data"] | None = None): + super().__init__() + self._names: Set[str] = set() + if values is not None: + self.update(values) + + def __contains__(self, item: object) -> bool: + if isinstance(item, str): + return item in self._names + return super().__contains__(item) + + def __eq__(self, other: object) -> bool: + if isinstance(other, set): + return super().__eq__(other) + if isinstance(other, 
str): + return other in self + return NotImplemented + + def __ne__(self, other: object) -> bool: + if isinstance(other, set): + return super().__ne__(other) + if isinstance(other, str): + return other not in self + return NotImplemented + + def __str__(self) -> str: + return ", ".join(sorted(self._names)) + + def add(self, element: Any) -> None: # type: ignore[override] + super().add(element) + self._register(element) + + def update(self, *others: Iterable[Any]) -> None: # type: ignore[override] + for iterable in others: + for element in iterable: + super().add(element) + self._register(element) + + def discard(self, element: Any) -> None: # type: ignore[override] + if super().__contains__(element): + super().discard(element) + self._unregister(element) + + def remove(self, element: Any) -> None: # type: ignore[override] + super().remove(element) + self._unregister(element) + + def pop(self) -> Any: # type: ignore[override] + element = super().pop() + self._unregister(element) + return element + + def clear(self) -> None: # type: ignore[override] + super().clear() + self._names.clear() + + def _register(self, element: "Data") -> None: + name = getattr(element, "name", None) + if isinstance(name, str): + self._names.add(name) + + def _unregister(self, element: "Data") -> None: + name = getattr(element, "name", None) + if isinstance(name, str): + self._names.discard(name) + + +class Controls(BaseModel): + """Controls implemented by/on an Element.""" + + model_config = ConfigDict(extra="allow", validate_assignment=True) + + authenticatesDestination: bool = Field( + default=False, + description="Verifies the identity of the destination, for example by verifying the authenticity of a digital certificate.", + ) + authenticatesSource: bool = Field(default=False) + authenticationScheme: str = Field(default="") + authorizesSource: bool = Field(default=False) + checksDestinationRevocation: bool = Field( + default=False, + description="Correctly checks the revocation status 
of credentials used to authenticate the destination", + ) + checksInputBounds: bool = Field(default=False) + definesConnectionTimeout: bool = Field(default=False) + disablesDTD: bool = Field(default=False) + disablesiFrames: bool = Field(default=False) + encodesHeaders: bool = Field(default=False) + encodesOutput: bool = Field(default=False) + encryptsCookies: bool = Field(default=False) + encryptsSessionData: bool = Field(default=False) + handlesCrashes: bool = Field(default=False) + handlesInterruptions: bool = Field(default=False) + handlesResourceConsumption: bool = Field(default=False) + hasAccessControl: bool = Field(default=False) + implementsAuthenticationScheme: bool = Field(default=False) + implementsCSRFToken: bool = Field(default=False) + implementsNonce: bool = Field( + default=False, + description="Nonce is an arbitrary number that can be used just once in a cryptographic communication.", + ) + implementsPOLP: bool = Field( + default=False, + description="The principle of least privilege (PoLP) requires that every module must be able to access only the information and resources that are necessary for its legitimate purpose.", + ) + implementsServerSideValidation: bool = Field(default=False) + implementsStrictHTTPValidation: bool = Field(default=False) + invokesScriptFilters: bool = Field(default=False) + isEncrypted: bool = Field( + default=False, description="Requires incoming data flow to be encrypted" + ) + isEncryptedAtRest: bool = Field( + default=False, description="Stored data is encrypted at rest" + ) + isHardened: bool = Field(default=False) + isResilient: bool = Field(default=False) + providesConfidentiality: bool = Field(default=False) + providesIntegrity: bool = Field(default=False) + sanitizesInput: bool = Field(default=False) + tracksExecutionFlow: bool = Field(default=False) + usesCodeSigning: bool = Field(default=False) + usesEncryptionAlgorithm: str = Field(default="") + usesMFA: bool = Field( + default=False, + 
description="Multi-factor authentication is an authentication method in which a computer user is granted access only after successfully presenting two or more pieces of evidence.", + ) + usesParameterizedInput: bool = Field(default=False) + usesSecureFunctions: bool = Field(default=False) + usesStrongSessionIdentifiers: bool = Field(default=False) + usesVPN: bool = Field(default=False) + validatesContentType: bool = Field(default=False) + validatesHeaders: bool = Field(default=False) + validatesInput: bool = Field(default=False) + verifySessionIdentifiers: bool = Field(default=False) + + def _attr_values(self) -> dict: + """Return a dictionary of all attribute values.""" + return self.model_dump() + + def _safeset(self, attr: str, value: Any) -> None: + """Safely set an attribute value.""" + try: + setattr(self, attr, value) + except (ValueError, TypeError): + pass + + +class Assumption(BaseModel): + """Assumption used by an Element. Used to exclude threats on a per-element basis. + + Attributes: + name (str): Name of the assumption + exclude (Set[str]): A set of threat SIDs to exclude for this assumption. For example: INP01 + description (str): An additional description of the assumption + """ + + model_config = ConfigDict(extra="allow") + + name: str = Field(description="Name of the assumption") + exclude: Set[str] = Field( + default_factory=set, + description="A set of threat SIDs to exclude for this assumption. For example: INP01", + ) + description: str = Field( + default="", description="An additional description of the assumption" + ) + + def __init__( + self, name: str = None, exclude: Union[List[str], Set[str]] = None, **kwargs + ): + """Initialize an Assumption. + + Args: + name (str): Name of the assumption. + exclude (Set[str]): A set of threat SIDs to exclude for this assumption. 
For example: INP01 + **kwargs: Optional properties: + - description (str): An additional description of the assumption + """ + if name is not None: + kwargs["name"] = name + if exclude is not None: + # Convert list to set if needed + kwargs["exclude"] = set(exclude) if isinstance(exclude, list) else exclude + super().__init__(**kwargs) + + def __str__(self): + return self.name + + +# Type aliases for complex field types that reference forward declarations +ElementList = List["Element"] +DataList = List["Data"] +ThreatList = List["Threat"] +FindingList = List["Finding"] +ControlsType = Controls +AssumptionList = List[Assumption] diff --git a/pytm/boundary.py b/pytm/boundary.py new file mode 100644 index 00000000..5a870061 --- /dev/null +++ b/pytm/boundary.py @@ -0,0 +1,85 @@ +"""Boundary model - represents trust boundaries in the threat model.""" + +from typing import List, TYPE_CHECKING +from textwrap import indent + +from .element import Element + +if TYPE_CHECKING: + pass + + +class Boundary(Element): + """Trust boundary groups elements and data with the same trust level.""" + + def __init__(self, name: str = None, **data): + super().__init__(name, **data) + # Register with TM boundaries + self._register_with_tm_boundaries() + + def _register_with_tm_boundaries(self): + """Register this boundary with the TM class.""" + try: + from .tm import TM + + if self.name not in TM._boundaries: + TM._boundaries.append(self) + except ImportError: + pass + + def _dfd_template(self) -> str: + """Template for DFD representation.""" + return """subgraph cluster_{uniq_name} {{ + graph [ + fontsize = 10; + fontcolor = black; + style = dashed; + color = {color}; + label = <{label}>; + ] + +{edges} +}} +""" + + def dfd(self, **kwargs) -> str: + """Generate DFD representation of this boundary.""" + if self.is_drawn: + return "" + + self.is_drawn = True + + edges = [] + try: + from .tm import TM + + for e in TM._elements: + if e.inBoundary != self or e.is_drawn: + continue + # The 
content to draw can include Boundary objects + edges.append(e.dfd(**kwargs)) + except ImportError: + pass + + return self._dfd_template().format( + uniq_name=self._uniq_name(), + label=self._label(), + color=self._color(**kwargs), + edges=indent("\n".join(edges), " "), + ) + + def _color(self, **kwargs) -> str: + """Get color for DFD representation.""" + if kwargs.get("colormap", False): + return "black" + else: + return "firebrick2" + + def parents(self) -> List["Boundary"]: + """Get parent boundaries.""" + result = [] + parent = self.inBoundary + while parent is not None: + result.append(parent) + parent = parent.inBoundary + return result diff --git a/pytm/data.py b/pytm/data.py new file mode 100644 index 00000000..c11f24ae --- /dev/null +++ b/pytm/data.py @@ -0,0 +1,133 @@ +"""Data model - represents data that traverses the threat model.""" + +from typing import List, TYPE_CHECKING +from pydantic import BaseModel, Field, ConfigDict + +from .enums import Classification, Lifetime + +if TYPE_CHECKING: + from .element import Element + from .dataflow import Dataflow + + +class Data(BaseModel): + """Represents a single piece of data that traverses the system. + + Attributes: + name (str): Name of the data + description (str): Description of the data + format (str): Format of the data + classification (Classification): Level of classification for this piece of data + isPII (bool): Does the data contain personally identifiable information. Should always be encrypted both in transmission and at rest. + isCredentials (bool): Does the data contain authentication information, like passwords or cryptographic keys, with or without expiration date. Should always be encrypted in transmission. If stored, they should be hashed using a cryptographic hash function. + credentialsLife (Lifetime): Credentials lifetime, describing if and how credentials can be revoked + isStored (bool): Is the data going to be stored by the target or only processed. 
If only derivative data is stored (a hash) it can be set to False. + isDestEncryptedAtRest (bool): Is data encrypted at rest at dest? + isSourceEncryptedAtRest (bool): Is data encrypted at rest at source? + carriedBy (List[Dataflow]): Dataflows that carries this piece of data + processedBy (List[Element]): Elements that store/process this piece of data + """ + + model_config = ConfigDict(extra="allow", validate_assignment=True) + + name: str = Field(description="Name of the data") + description: str = Field(default="", description="Description of the data") + format: str = Field(default="", description="Format of the data") + classification: Classification = Field( + default=Classification.UNKNOWN, + description="Level of classification for this piece of data", + ) + isPII: bool = Field( + default=False, + description="Does the data contain personally identifiable information. Should always be encrypted both in transmission and at rest.", + ) + isCredentials: bool = Field( + default=False, + description="Does the data contain authentication information, like passwords or cryptographic keys, with or without expiration date. Should always be encrypted in transmission. If stored, they should be hashed using a cryptographic hash function.", + ) + credentialsLife: Lifetime = Field( + default=Lifetime.NONE, + description="Credentials lifetime, describing if and how credentials can be revoked", + ) + isStored: bool = Field( + default=False, + description="Is the data going to be stored by the target or only processed. If only derivative data is stored (a hash) it can be set to False.", + ) + isDestEncryptedAtRest: bool = Field( + default=False, description="Is data encrypted at rest at dest?" + ) + isSourceEncryptedAtRest: bool = Field( + default=False, description="Is data encrypted at rest at source?" 
+ ) + carriedBy: List["Dataflow"] = Field( + default_factory=list, description="Dataflows that carries this piece of data" + ) + processedBy: List["Element"] = Field( + default_factory=list, + description="Elements that store/process this piece of data", + ) + + def __init__(self, name: str = None, **data): + """Initialize a Data object. + + Args: + name (str): Name of the data. + **data: Optional data properties: + - description (str): Description of the data + - format (str): Format of the data + - classification (Classification): Level of classification for this piece of data + - isPII (bool): Does the data contain personally identifiable information. Should always be encrypted both in transmission and at rest. + - isCredentials (bool): Does the data contain authentication information, like passwords or cryptographic keys, with or without expiration date. Should always be encrypted in transmission. If stored, they should be hashed using a cryptographic hash function. + - credentialsLife (Lifetime): Credentials lifetime, describing if and how credentials can be revoked + - isStored (bool): Is the data going to be stored by the target or only processed. If only derivative data is stored (a hash) it can be set to False. + - isDestEncryptedAtRest (bool): Is data encrypted at rest at dest? + - isSourceEncryptedAtRest (bool): Is data encrypted at rest at source? 
+ - carriedBy (List[Dataflow]): Dataflows that carries this piece of data + - processedBy (List[Element]): Elements that store/process this piece of data + """ + # Handle positional name argument + if name is not None: + data["name"] = name + super().__init__(**data) + + # Register with TM + self._register_with_tm() + + def _register_with_tm(self): + """Register this data with the TM class.""" + try: + from .tm import TM + + TM._data.append(self) + except ImportError: + # TM might not be available yet during initial setup + pass + + def __repr__(self): + return ( + f"<{self.__module__}.{type(self).__name__}({self.name}) at {hex(id(self))}>" + ) + + def __str__(self): + return f"Data({self.name})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Data): + return NotImplemented + return ( + self.name == other.name + and self.description == other.description + and self.format == other.format + and self.classification == other.classification + ) + + def __hash__(self): + """Make Data objects hashable for use in sets.""" + return hash((self.name, self.description, self.format, self.classification)) + + def _safeset(self, attr: str, value) -> None: + """Safely set an attribute value.""" + try: + setattr(self, attr, value) + except (ValueError, TypeError): + pass diff --git a/pytm/dataflow.py b/pytm/dataflow.py new file mode 100644 index 00000000..3d7315a3 --- /dev/null +++ b/pytm/dataflow.py @@ -0,0 +1,218 @@ +"""Dataflow model - represents data flows between elements.""" + +from typing import Optional +from pydantic import Field, field_validator, model_validator + +from .element import Element, sev_to_color +from .enums import Classification, TLSVersion +from .base import DataSet + + +class Dataflow(Element): + """A data flow from a source to a sink.""" + + source: Element = Field(description="Source element of the data flow") + sink: Element = Field(description="Sink element of the data flow") + isResponse: bool = Field( + default=False, 
description="Is a response to another data flow" + ) + response: Optional["Dataflow"] = Field( + default=None, description="Another data flow that is a response to this one" + ) + responseTo: Optional["Dataflow"] = Field( + default=None, description="Is a response to this data flow" + ) + srcPort: int = Field(default=-1, description="Source TCP port") + dstPort: int = Field(default=-1, description="Destination TCP port") + tlsVersion: TLSVersion = Field( + default=TLSVersion.NONE, description="TLS version used" + ) + protocol: str = Field(default="", description="Protocol used in this data flow") + data: DataSet = Field( + default_factory=DataSet, + description="pytm.Data object(s) in incoming data flows", + ) + order: int = Field( + default=-1, description="Number of this data flow in the threat model" + ) + implementsCommunicationProtocol: bool = Field( + default=False, description="Does this flow implement a communication protocol" + ) + note: str = Field(default="", description="Note about this data flow") + usesVPN: bool = Field(default=False, description="Does this flow use VPN") + usesSessionTokens: bool = Field( + default=False, description="Does this flow use session tokens" + ) + + @field_validator("data", mode="before") + @classmethod + def validate_data(cls, v): + """Convert single Data object to DataSet, handle compatibility.""" + from .data import Data + + if isinstance(v, str): + # Handle legacy string assignment + return DataSet( + [ + Data( + name="undefined", + description=v, + classification=Classification.UNKNOWN, + ) + ] + ) + + if isinstance(v, Data): + # Single Data object + return DataSet([v]) + + if hasattr(v, "__iter__") and not isinstance(v, (str, bytes)): + # Iterable of Data objects + return DataSet(v) + + if isinstance(v, DataSet): + return v + + return DataSet([v]) + + def __setattr__(self, name, value): + """Handle bidirectional response relationships during assignment.""" + # Set the attribute first + super().__setattr__(name, 
value) + + # Skip relationship setup if we're already in the process of linking + if getattr(self, "_updating_relationships", False): + return + + # Set up bidirectional relationships for response/responseTo + if name == "responseTo" and value is not None: + self._link_response_to(value) + elif name == "response" and value is not None: + self._link_response_from(value) + + def _link_response_to(self, target: "Dataflow") -> None: + """Link this dataflow as a response to the target dataflow.""" + object.__setattr__(self, "_updating_relationships", True) + try: + # Mark this as a response + if not self.isResponse: + object.__setattr__(self, "isResponse", True) + # Set reverse link on target + if target.response is None: + object.__setattr__(target, "response", self) + finally: + object.__setattr__(self, "_updating_relationships", False) + + def _link_response_from(self, source: "Dataflow") -> None: + """Link the source dataflow as a response to this dataflow.""" + object.__setattr__(self, "_updating_relationships", True) + try: + # Mark source as a response + if not source.isResponse: + object.__setattr__(source, "isResponse", True) + # Set reverse link on source + if source.responseTo is None: + object.__setattr__(source, "responseTo", self) + finally: + object.__setattr__(self, "_updating_relationships", False) + + @model_validator(mode="after") + def setup_response_relationships(self) -> "Dataflow": + """Set up bidirectional response relationships after validation.""" + # Set up bidirectional response relationship if responseTo is set + if self.responseTo is not None: + if not self.isResponse: + object.__setattr__(self, "isResponse", True) + if self.responseTo.response is None: + object.__setattr__(self.responseTo, "response", self) + + # Handle reverse relationship + if self.response is not None: + if not self.response.isResponse: + object.__setattr__(self.response, "isResponse", True) + if self.response.responseTo is None: + object.__setattr__(self.response, 
"responseTo", self) + + return self + + def __init__(self, source: Element, sink: Element, name: str, **data): + """Create a Dataflow between two elements. + + Args: + source: Source element of the data flow. + sink: Sink element of the data flow. + name: Name of this data flow. + **data: Additional field values (e.g. protocol, tlsVersion, srcPort). + """ + # Handle positional arguments + data["source"] = source + data["sink"] = sink + data["name"] = name + super().__init__(**data) + # Register with TM flows + self._register_with_tm_flows() + + def _register_with_tm_flows(self): + """Register this dataflow with the TM class.""" + try: + from .tm import TM + + TM._flows.append(self) + except ImportError: + pass + + def display_name(self) -> str: + """Get display name for this dataflow.""" + if self.order == -1: + return self.name + return f"({self.order}) {self.name}" + + def _dfd_template(self) -> str: + """Template for DFD representation.""" + return """{source} -> {sink} [ + color = {color}; + fontcolor = {color}; + dir = {direction}; + label = "{label}"; +] +""" + + def dfd(self, mergeResponses: bool = False, **kwargs) -> str: + """Generate DFD representation of this dataflow.""" + self.is_drawn = True + + levels = kwargs.get("levels", None) + if ( + levels + and not levels & self.levels + and not (levels & self.source.levels and levels & self.sink.levels) + ): + return "" + + color = self._color() + + if kwargs.get("colormap", False): + color = sev_to_color(self.severity) + + direction = "forward" + label = self._label() + if mergeResponses and self.response is not None: + direction = "both" + label += "\n" + self.response._label() + + return self._dfd_template().format( + source=self.source._uniq_name(), + sink=self.sink._uniq_name(), + direction=direction, + label=label, + color=color, + ) + + def hasDataLeaks(self) -> bool: + """Check if this dataflow has data leaks.""" + return any( + d.classification > self.source.maxClassification + or d.classification > 
self.sink.maxClassification + or d.classification > self.maxClassification + for d in self.data + ) diff --git a/pytm/datastore.py b/pytm/datastore.py new file mode 100644 index 00000000..0493406a --- /dev/null +++ b/pytm/datastore.py @@ -0,0 +1,126 @@ +"""Datastore model - represents data storage elements in the threat model.""" + +import os +from typing import TYPE_CHECKING +from pydantic import Field + +from .asset import Asset +from .enums import DatastoreType + +if TYPE_CHECKING: + pass + + +class Datastore(Asset): + """An entity storing data. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this asset on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? + OS (str): Operating system + onRDS (bool): Is this datastore on RDS? + storesLogData (bool): Does this datastore store log data? + storesPII (bool): Personally Identifiable Information is any information relating to an identifiable person + storesSensitiveData (bool): Does this datastore store sensitive data? + isSQL (bool): Is this a SQL datastore? + isShared (bool): Is this datastore shared? + hasWriteAccess (bool): Does this datastore have write access? + type (DatastoreType): The type of Datastore + """ + + onRDS: bool = Field(default=False, description="Is this datastore on RDS?") + storesLogData: bool = Field( + default=False, description="Does this datastore store log data?" + ) + storesPII: bool = Field( + default=False, + description="Personally Identifiable Information is any information relating to an identifiable person", + ) + storesSensitiveData: bool = Field( + default=False, description="Does this datastore store sensitive data?" 
+ ) + isSQL: bool = Field(default=True, description="Is this a SQL datastore?") + isShared: bool = Field(default=False, description="Is this datastore shared?") + hasWriteAccess: bool = Field( + default=False, description="Does this datastore have write access?" + ) + type: DatastoreType = Field( + default=DatastoreType.UNKNOWN, description="The type of Datastore" + ) + + def __init__(self, name: str = None, **data): + """Initialize a Datastore. + + Args: + name (str): Name of the datastore. + **data: Optional datastore properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this asset on AWS? + - handlesResources (bool): Does this asset handle resources? + - usesEnvironmentVariables (bool): Does this asset use environment variables? + - OS (str): Operating system + - onRDS (bool): Is this datastore on RDS? + - storesLogData (bool): Does this datastore store log data? + - storesPII (bool): Personally Identifiable Information is any information relating to an identifiable person + - storesSensitiveData (bool): Does this datastore store sensitive data? + - isSQL (bool): Is this a SQL datastore? + - isShared (bool): Is this datastore shared? + - hasWriteAccess (bool): Does this datastore have write access? 
+ - type (DatastoreType): The type of Datastore + """ + super().__init__(name, **data) + + def _dfd_template(self) -> str: + """Template for DFD representation.""" + return """{uniq_name} [ + shape = {shape}; + fixedsize = shape; + image = "{image}"; + imagescale = true; + color = {color}; + fontcolor = black; + xlabel = "{label}"; + label = ""; +] +""" + + def _shape(self) -> str: + """Get shape for DFD representation.""" + return "none" + + def dfd(self, **kwargs) -> str: + """Generate DFD representation of this element.""" + from .element import sev_to_color + + self.is_drawn = True + + levels = kwargs.get("levels", None) + if levels and not levels & self.levels: + return "" + + color = self._color() + color_file = "black" + + if kwargs.get("colormap", False): + color = sev_to_color(self.severity) + color_file = color.split(";")[0] + + return self._dfd_template().format( + uniq_name=self._uniq_name(), + label=self._label(), + color=color, + shape=self._shape(), + image=os.path.join( + os.path.dirname(__file__), "images", f"datastore_{color_file}.png" + ), + ) diff --git a/pytm/element.py b/pytm/element.py new file mode 100644 index 00000000..6212fdb9 --- /dev/null +++ b/pytm/element.py @@ -0,0 +1,336 @@ +"""Element model - base class for all threat model elements.""" + +import inspect +import random +import uuid as uuid_module +from hashlib import sha224 +from textwrap import wrap +from typing import Any, List, Optional, Set, TYPE_CHECKING + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from .base import Assumption, Controls +from .enums import Classification, TLSVersion + +if TYPE_CHECKING: + from .boundary import Boundary + from .dataflow import Dataflow + from .finding import Finding + + +def sev_to_color(sev: int) -> str: + """Return a Graphviz color declaration based on severity.""" + if sev == 5: + return 'firebrick3; fillcolor="#b2222222"; style=filled ' + if 2 <= sev <= 4: + return 'gold; fillcolor="#ffd80022"; style=filled' + if 
0 <= sev < 2: + return 'darkgreen; fillcolor="#00630022"; style=filled' + return "black" + + +class Element(BaseModel): + """A generic element in the threat model. + + Attributes: + name (str): Name of the element + description (str): Description of the element + inBoundary (Boundary): Trust boundary this element exists in + inScope (bool): Is the element in scope of the threat model? + maxClassification (Classification): Maximum data classification this element can handle + minTLSVersion (TLSVersion): Minimum TLS version required + findings (List[Finding]): Threats that apply to this element + overrides (List[Finding]): Overrides to findings, allowing to set a custom response, CVSS score or override other attributes + assumptions (List[Assumption]): Assumptions about the element. These optionally allow to exclude threats with the given SIDs + levels (Set[int]): List of levels (0, 1, 2, ...) to be drawn in the model + sourceFiles (List[str]): Location of the source code that describes this element relative to the directory of the model script + controls (Controls): Security controls for this element + severity (int): Severity level of threats affecting this element + """ + + model_config = ConfigDict( + extra="allow", + validate_assignment=True, + arbitrary_types_allowed=True, + use_attribute_docstrings=True, + ) + + name: str = Field(description="Name of the element") + description: str = Field(default="", description="Description of the element") + inBoundary: Optional["Boundary"] = Field( + default=None, + description="Trust boundary this element exists in", + ) + inScope: bool = Field( + default=True, description="Is the element in scope of the threat model" + ) + maxClassification: Classification = Field( + default=Classification.UNKNOWN, + description="Maximum data classification this element can handle", + ) + minTLSVersion: TLSVersion = Field( + default=TLSVersion.NONE, + description="Minimum TLS version required", + ) + findings: List["Finding"] = Field( + 
default_factory=list, + description="Threats that apply to this element", + ) + overrides: List["Finding"] = Field( + default_factory=list, + description="Overrides to findings, allowing to set a custom response, CVSS score or override other attributes", + ) + assumptions: List[Assumption] = Field( + default_factory=list, + description="Assumptions about the element. These optionally allow to exclude threats with the given SIDs", + ) + levels: Set[int] = Field( + default_factory=lambda: {0}, + description="List of levels (0, 1, 2, ...) to be drawn in the model", + ) + sourceFiles: List[str] = Field( + default_factory=list, + description="Location of the source code that describes this element relative to the directory of the model script", + ) + controls: Controls = Field( + default_factory=Controls, description="Security controls for this element" + ) + severity: int = Field( + default=0, description="Severity level of threats affecting this element" + ) + + # Internal attributes + uuid: uuid_module.UUID = Field( + default_factory=lambda: uuid_module.UUID(int=random.getrandbits(128)) + ) + is_drawn: bool = Field(default=False, exclude=True) + + _WRITE_ONCE_FIELDS = {"name"} + + @field_validator("levels", mode="before") + @classmethod + def _coerce_levels(cls, value): + """Normalize level inputs to a set of integers.""" + if value is None: + return {0} + if isinstance(value, (set, frozenset)): + return set(value) + if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): + return set(value) + return {value} + + def __setattr__( + self, key: str, value: Any + ) -> None: # noqa: D401 - keep same behaviour + if ( + key in self._WRITE_ONCE_FIELDS + and key in self.__dict__ + and not key.startswith("_") + ): + raise ValueError(f"cannot overwrite {type(self).__name__}.{key} value") + super().__setattr__(key, value) + + def __init__(self, name: Optional[str] = None, **data: Any): + """Initialize an Element. + + Args: + name (str): Name of the element. 
+ **data: Optional element properties: + - description (str): Description of the element + - inBoundary (Boundary): Trust boundary this element exists in + - inScope (bool): Is the element in scope of the threat model? + - maxClassification (Classification): Maximum data classification this element can handle + - minTLSVersion (TLSVersion): Minimum TLS version required + - findings (List[Finding]): Threats that apply to this element + - overrides (List[Finding]): Overrides to findings, allowing to set a custom response, CVSS score or override other attributes + - assumptions (List[Assumption]): Assumptions about the element. These optionally allow to exclude threats with the given SIDs + - levels (Set[int]): List of levels (0, 1, 2, ...) to be drawn in the model + - sourceFiles (List[str]): Location of the source code that describes this element relative to the directory of the model script + - controls (Controls): Security controls for this element + - severity (int): Severity level of threats affecting this element + """ + if name is not None: + data["name"] = name + super().__init__(**data) + self._register_with_tm() + + def _register_with_tm(self) -> None: + """Register this element with the TM class.""" + try: + from .tm import TM + + TM._elements.append(self) + except ImportError: + # TM might not be available yet during initial setup + pass + + def __repr__(self) -> str: + return ( + f"<{self.__module__}.{type(self).__name__}({self.name}) at {hex(id(self))}>" + ) + + def __str__(self) -> str: + return f"{type(self).__name__}({self.name})" + + def __hash__(self) -> int: + """Make Element objects hashable for use in sets and as dict keys.""" + return hash((type(self).__name__, self.name, str(self.uuid))) + + def _uniq_name(self) -> str: + """Transform name and uuid into a unique string.""" + digest = sha224(str(self.uuid).encode("utf-8")).hexdigest() + name = "".join(ch for ch in self.name if ch.isalpha()) + return 
f"{type(self).__name__.lower()}_{name}_{digest[:10]}" + + def check(self) -> bool: + """Check if the element is valid.""" + return True + + def _dfd_template(self) -> str: + """Template for DFD representation.""" + return """{uniq_name} [ + shape = {shape}; + color = {color}; + fontcolor = black; + label = "{label}"; + margin = 0.02; +] +""" + + def dfd(self, **kwargs: Any) -> str: + """Generate DFD representation of this element.""" + self.is_drawn = True + + levels = kwargs.get("levels") + if levels and not levels & self.levels: + return "" + + color = self._color() + if kwargs.get("colormap", False): + color = sev_to_color(self.severity) + + return self._dfd_template().format( + uniq_name=self._uniq_name(), + label=self._label(), + color=color, + shape=self._shape(), + ) + + def _color(self) -> str: + """Get the color for this element.""" + return "black" + + def oneOf(self, *elements: Any) -> bool: + """Return True if the element matches any provided elements or classes.""" + for element in elements: + if inspect.isclass(element): + if isinstance(self, element): + return True + elif self is element: + return True + return False + + def crosses(self, *boundaries: Any) -> bool: + """Return True if the flow crosses any of the provided boundaries.""" + if hasattr(self, "source") and hasattr(self, "sink"): + if self.source.inBoundary is self.sink.inBoundary: + return False + for boundary in boundaries: + if inspect.isclass(boundary): + if ( + ( + isinstance(self.source.inBoundary, boundary) + and not isinstance(self.sink.inBoundary, boundary) + ) + or ( + not isinstance(self.source.inBoundary, boundary) + and isinstance(self.sink.inBoundary, boundary) + ) + or self.source.inBoundary is not self.sink.inBoundary + ): + return True + elif ( + self.source.inside(boundary) and not self.sink.inside(boundary) + ) or (not self.source.inside(boundary) and self.sink.inside(boundary)): + return True + return False + + def enters(self, *boundaries: Any) -> bool: + """Return 
True if the flow enters any of the provided boundaries.""" + if hasattr(self, "source") and hasattr(self, "sink"): + return self.source.inBoundary is None and self.sink.inside(*boundaries) + return False + + def exits(self, *boundaries: Any) -> bool: + """Return True if the flow exits any of the provided boundaries.""" + if hasattr(self, "source") and hasattr(self, "sink"): + return self.source.inside(*boundaries) and self.sink.inBoundary is None + return False + + def inside(self, *boundaries: Any) -> bool: + """Return True if the element resides inside any of the provided boundaries.""" + for boundary in boundaries: + if inspect.isclass(boundary): + if isinstance(self.inBoundary, boundary): + return True + elif self.inBoundary is boundary: + return True + return False + + def display_name(self) -> str: + """Get display name for this element.""" + return self.name + + def _label(self) -> str: + """Get label for DFD representation.""" + return "\\n".join(wrap(self.display_name(), 18)) + + def _shape(self) -> str: + """Get shape for DFD representation.""" + return "square" + + def _safeset(self, attr: str, value: Any) -> None: + """Safely set an attribute value.""" + try: + setattr(self, attr, value) + except (ValueError, TypeError): + pass + + def _attr_values(self) -> dict: + """Return a dictionary of all attribute values.""" + return self.model_dump() + + def checkTLSVersion(self, flows: List["Dataflow"]) -> bool: + """Check if any flows have insufficient TLS version.""" + return any(f.tlsVersion < self.minTLSVersion for f in flows) + + def _set_severity(self, sev: Any) -> None: + """Set the severity based on numeric or textual value.""" + if isinstance(sev, int): + self.severity = max(0, sev) + return + + if isinstance(sev, str): + normalized = sev.strip().lower() + mapping = { + "very high": 5, + "critical": 5, + "high": 4, + "medium": 3, + "low": 2, + "very low": 1, + "info": 0, + } + legacy_mapping = { + "critical": 3, + "high": 2, + "medium": 1, + "low": 0, 
+ } + + value = mapping.get(normalized) + if value is None: + value = legacy_mapping.get(normalized) + + if value is not None and value > self.severity: + self.severity = value diff --git a/pytm/enums.py b/pytm/enums.py new file mode 100644 index 00000000..29c08403 --- /dev/null +++ b/pytm/enums.py @@ -0,0 +1,94 @@ +"""Enums used throughout the pytm package.""" + +from enum import Enum + + +class Action(Enum): + """Action taken when validating a threat model.""" + + NO_ACTION = "NO_ACTION" + RESTRICT = "RESTRICT" + IGNORE = "IGNORE" + + +class OrderedEnum(Enum): + """Base enum class that supports ordering operations.""" + + def __ge__(self, other): + if self.__class__ is other.__class__: + return self.value >= other.value + return NotImplemented + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self.value > other.value + return NotImplemented + + def __le__(self, other): + if self.__class__ is other.__class__: + return self.value <= other.value + return NotImplemented + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented + + +class Classification(OrderedEnum): + """Data classification levels.""" + + UNKNOWN = 0 + PUBLIC = 1 + RESTRICTED = 2 + SENSITIVE = 3 + SECRET = 4 + TOP_SECRET = 5 + + +class Lifetime(Enum): + """Credential lifetime categories.""" + + # not applicable + NONE = "NONE" + # unknown lifetime + UNKNOWN = "UNKNOWN" + # relatively short expiration date (time to live) + SHORT = "SHORT_LIVED" + # long or no expiration date + LONG = "LONG_LIVED" + # no expiration date but revoked/invalidated automatically in some conditions + AUTO = "AUTO_REVOKABLE" + # no expiration date but can be invalidated manually + MANUAL = "MANUALLY_REVOKABLE" + # cannot be invalidated at all + HARDCODED = "HARDCODED" + + def label(self): + return self.value.lower().replace("_", " ") + + +class DatastoreType(Enum): + """Types of datastores.""" + + UNKNOWN = "UNKNOWN" + FILE_SYSTEM 
= "FILE_SYSTEM" + SQL = "SQL" + LDAP = "LDAP" + AWS_S3 = "AWS_S3" + + def label(self): + return self.value.lower().replace("_", " ") + + +class TLSVersion(OrderedEnum): + """TLS/SSL version levels.""" + + NONE = 0 + SSLv1 = 1 + SSLv2 = 2 + SSLv3 = 3 + TLSv10 = 4 + TLSv11 = 5 + TLSv12 = 6 + TLSv13 = 7 diff --git a/pytm/finding.py b/pytm/finding.py new file mode 100644 index 00000000..75d201ff --- /dev/null +++ b/pytm/finding.py @@ -0,0 +1,158 @@ +"""Finding model - represents a finding linking an element to a threat.""" + +from typing import Optional, TYPE_CHECKING +from pydantic import BaseModel, Field, ConfigDict + +from .base import Assumption + +if TYPE_CHECKING: + from .element import Element + + +class Finding(BaseModel): + """Represents a Finding - the element in question and a description of the finding. + + Attributes: + element (Element): Element this finding applies to + target (str): Name of the element this finding applies to + description (str): Threat description + details (str): Threat details + severity (str): Threat severity + mitigations (str): Threat mitigations + example (str): Threat example + id (str): Finding ID + threat_id (str): Threat ID + references (str): Threat references + condition (str): Threat condition + assumption (Assumption): The assumption that caused this finding to be excluded + response (str): Describes how this threat matching this particular asset or dataflow is being handled. 
Can be one of: mitigated, transferred, avoided, accepted + cvss (str): The CVSS score and/or vector + """ + + model_config = ConfigDict( + extra="allow", validate_assignment=True, arbitrary_types_allowed=True + ) + + element: Optional["Element"] = Field( + default=None, description="Element this finding applies to" + ) + target: str = Field( + default="", description="Name of the element this finding applies to" + ) + description: str = Field(description="Threat description") + details: str = Field(description="Threat details") + severity: str = Field(description="Threat severity") + mitigations: str = Field(description="Threat mitigations") + example: str = Field(description="Threat example") + id: str = Field(description="Finding ID") + threat_id: str = Field(description="Threat ID") + references: str = Field(description="Threat references") + condition: str = Field(description="Threat condition") + assumption: Optional[Assumption] = Field( + default=None, + description="The assumption that caused this finding to be excluded", + ) + response: str = Field( + default="", + description="Describes how this threat matching this particular asset or dataflow is being handled. Can be one of: mitigated, transferred, avoided, accepted", + ) + cvss: str = Field(default="", description="The CVSS score and/or vector") + + def __init__(self, *args, **kwargs): + """Initialize a Finding. + + Args: + *args: Optionally pass the element as the first positional argument. 
+ **kwargs: Finding properties: + - element (Element): Element this finding applies to + - target (str): Name of the element this finding applies to + - threat (Threat): Threat object to copy attributes from (description, details, severity, mitigations, example, references, condition) + - description (str): Threat description + - details (str): Threat details + - severity (str): Threat severity + - mitigations (str): Threat mitigations + - example (str): Threat example + - id (str): Finding ID + - threat_id (str): Threat ID + - references (str): Threat references + - condition (str): Threat condition + - assumption (Assumption): The assumption that caused this finding to be excluded + - response (str): Describes how this threat matching this particular asset or dataflow is being handled. Can be one of: mitigated, transferred, avoided, accepted + - cvss (str): The CVSS score and/or vector + """ + # Handle positional element argument + if args: + element = args[0] + kwargs["element"] = element + + # Get element from kwargs + element = kwargs.get("element") + + # Set target from element name if element is provided + if element is not None and "target" not in kwargs: + kwargs["target"] = element.name + + # Handle threat data + threat = kwargs.pop("threat", None) + if threat: + kwargs["threat_id"] = getattr(threat, "id", "") + # Copy threat attributes + threat_attrs = [ + "description", + "details", + "severity", + "mitigations", + "example", + "references", + "condition", + ] + for attr in threat_attrs: + if attr not in kwargs: # Don't override explicit values + kwargs[attr] = getattr(threat, attr, "") + + # Handle overrides from element + threat_id = kwargs.get("threat_id", None) + if hasattr(element, "overrides") and threat_id: + for override in element.overrides: + if getattr(override, "threat_id", None) == threat_id: + # Apply override values + override_dict = ( + override.model_dump() if hasattr(override, "model_dump") else {} + ) + for key, value in 
override_dict.items(): + if key not in ("element", "target") and value is not None: + kwargs[key] = value + break + + # Ensure all required fields have values + required_fields = [ + "description", + "details", + "severity", + "mitigations", + "example", + "id", + "threat_id", + "references", + "condition", + ] + for field in required_fields: + if field not in kwargs: + kwargs[field] = "" + + super().__init__(**kwargs) + + def _safeset(self, attr: str, value) -> None: + """Safely set an attribute value.""" + try: + setattr(self, attr, value) + except (ValueError, TypeError): + pass + + def __repr__(self): + return ( + f"<{self.__module__}.{type(self).__name__}({self.id}) at {hex(id(self))}>" + ) + + def __str__(self): + return f"'{self.target}': {self.description}\n{self.details}\n{self.severity}" diff --git a/pytm/flows.py b/pytm/flows.py index a1e882e4..b05c0cab 100644 --- a/pytm/flows.py +++ b/pytm/flows.py @@ -3,7 +3,7 @@ def req_reply(src: Element, dest: Element, req_name: str, reply_name=None) -> (DF, DF): - ''' + """ This function creates two datflows where one dataflow is a request and the second dataflow is the corresponding reply to the newly created request. @@ -22,9 +22,9 @@ def req_reply(src: Element, dest: Element, req_name: str, reply_name=None) -> (D Returns: a tuple of two dataflows, where the first is the request and the second is the reply. - ''' + """ if not reply_name: - reply_name = f'Reply to {req_name}' + reply_name = f"Reply to {req_name}" req = DF(src, dest, req_name) reply = DF(dest, src, name=reply_name) reply.responseTo = req @@ -32,7 +32,7 @@ def req_reply(src: Element, dest: Element, req_name: str, reply_name=None) -> (D def reply(req: DF, **kwargs) -> DF: - ''' + """ This function takes a dataflow as an argument and returns a new dataflow, which is a response to the given dataflow. 
Args: @@ -45,12 +45,12 @@ def reply(req: DF, **kwargs) -> DF: client_reply = reply(client_query) Returns: a Dataflow which is a reply to the given datadlow req - ''' - if 'name' not in kwargs: - name = f'Reply to {req.name}' + """ + if "name" not in kwargs: + name = f"Reply to {req.name}" else: - name = kwargs['name'] - del kwargs['name'] + name = kwargs["name"] + del kwargs["name"] reply = DF(req.sink, req.source, name, **kwargs) reply.responseTo = req return req, reply diff --git a/pytm/json.py b/pytm/json.py index a69cd719..e9a1debf 100644 --- a/pytm/json.py +++ b/pytm/json.py @@ -1,22 +1,25 @@ import json -import sys - -from .pytm import ( - TM, - Boundary, - Element, - Dataflow, - Server, - ExternalEntity, - Datastore, - Actor, - Process, - SetOfProcesses, - Action, - Lambda, - LLM, - Controls, -) + +from .tm import TM +from .boundary import Boundary +from .dataflow import Dataflow +from .asset import Asset, Server, ExternalEntity, Lambda, LLM +from .datastore import Datastore +from .actor import Actor +from .process import Process, SetOfProcesses +from .enums import Action + +_ELEMENT_CLASSES = { + "Asset": Asset, + "Actor": Actor, + "Server": Server, + "ExternalEntity": ExternalEntity, + "Lambda": Lambda, + "LLM": LLM, + "Datastore": Datastore, + "Process": Process, + "SetOfProcesses": SetOfProcesses, +} def loads(s): @@ -74,7 +77,10 @@ def decode_boundaries(flat): def decode_elements(flat, boundaries): elements = {} for i, e in enumerate(flat): - klass = getattr(sys.modules[__name__], e.pop("__class__", "Asset")) + class_name = e.pop("__class__", "Asset") + klass = _ELEMENT_CLASSES.get(class_name) + if klass is None: + raise ValueError(f"Unknown element class: {class_name}") name = e.pop("name", None) if name is None: raise ValueError(f"name property missing in element {i}") diff --git a/pytm/process.py b/pytm/process.py new file mode 100644 index 00000000..e71cb7a3 --- /dev/null +++ b/pytm/process.py @@ -0,0 +1,127 @@ +"""Process model - represents 
processes that handle data.""" + +from typing import TYPE_CHECKING +from pydantic import Field + +from .asset import Asset + +if TYPE_CHECKING: + pass + + +class Process(Asset): + """An entity processing data. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this asset on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? + OS (str): Operating system + codeType (str): Type of code running in this process + implementsCommunicationProtocol (bool): Does this process implement a communication protocol? + tracksExecutionFlow (bool): Does this process track execution flow? + implementsAPI (bool): Does this process implement an API? + environment (str): Environment for this process + allowsClientSideScripting (bool): Does this process allow client-side scripting? + """ + + codeType: str = Field( + default="Unmanaged", description="Type of code running in this process" + ) + implementsCommunicationProtocol: bool = Field( + default=False, + description="Does this process implement a communication protocol?", + ) + tracksExecutionFlow: bool = Field( + default=False, description="Does this process track execution flow?" + ) + implementsAPI: bool = Field( + default=False, description="Does this process implement an API?" + ) + environment: str = Field(default="", description="Environment for this process") + allowsClientSideScripting: bool = Field( + default=False, description="Does this process allow client-side scripting?" + ) + + def __init__(self, name: str = None, **data): + """Initialize a Process. + + Args: + name (str): Name of the process. 
+ **data: Optional process properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this asset on AWS? + - handlesResources (bool): Does this asset handle resources? + - usesEnvironmentVariables (bool): Does this asset use environment variables? + - OS (str): Operating system + - codeType (str): Type of code running in this process + - implementsCommunicationProtocol (bool): Does this process implement a communication protocol? + - tracksExecutionFlow (bool): Does this process track execution flow? + - implementsAPI (bool): Does this process implement an API? + - environment (str): Environment for this process + - allowsClientSideScripting (bool): Does this process allow client-side scripting? + """ + super().__init__(name, **data) + + def _shape(self) -> str: + """Get shape for DFD representation.""" + return "circle" + + +class SetOfProcesses(Process): + """A set of processes grouped together. + + Attributes: + port (int): Default TCP port for incoming data flows + protocol (str): Default network protocol for incoming data flows + data (DataSet): pytm.Data object(s) in incoming data flows + inputs (List[Dataflow]): Incoming Dataflows + outputs (List[Dataflow]): Outgoing Dataflows + onAWS (bool): Is this asset on AWS? + handlesResources (bool): Does this asset handle resources? + usesEnvironmentVariables (bool): Does this asset use environment variables? + OS (str): Operating system + codeType (str): Type of code running in this process + implementsCommunicationProtocol (bool): Does this process implement a communication protocol? + tracksExecutionFlow (bool): Does this process track execution flow? + implementsAPI (bool): Does this process implement an API? 
+ environment (str): Environment for this process + allowsClientSideScripting (bool): Does this process allow client-side scripting? + """ + + def __init__(self, name: str = None, **data): + """Initialize a SetOfProcesses. + + Args: + name (str): Name of the set of processes. + **data: Optional properties: + - port (int): Default TCP port for incoming data flows + - protocol (str): Default network protocol for incoming data flows + - data (DataSet): pytm.Data object(s) in incoming data flows + - inputs (List[Dataflow]): Incoming Dataflows + - outputs (List[Dataflow]): Outgoing Dataflows + - onAWS (bool): Is this asset on AWS? + - handlesResources (bool): Does this asset handle resources? + - usesEnvironmentVariables (bool): Does this asset use environment variables? + - OS (str): Operating system + - codeType (str): Type of code running in this process + - implementsCommunicationProtocol (bool): Does this process implement a communication protocol? + - tracksExecutionFlow (bool): Does this process track execution flow? + - implementsAPI (bool): Does this process implement an API? + - environment (str): Environment for this process + - allowsClientSideScripting (bool): Does this process allow client-side scripting? 
+ """ + super().__init__(name, **data) + + def _shape(self) -> str: + """Get shape for DFD representation.""" + return "doublecircle" diff --git a/pytm/pytm.py b/pytm/pytm.py index 82b56aa8..d710c8c6 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -1,376 +1,68 @@ import argparse -import errno -import inspect -import json -import logging -import os -import random -import sys -import uuid import html import copy +import logging +import re +import sys -from collections import Counter, defaultdict -from collections.abc import Iterable -from enum import Enum -from functools import lru_cache, singledispatch -from hashlib import sha224 -from itertools import combinations -from shutil import rmtree -from textwrap import indent, wrap -from weakref import WeakKeyDictionary -from datetime import datetime - -from .template_engine import SuperFormatter - -""" Helper functions """ - -""" The base for this (descriptors instead of properties) has been - shamelessly lifted from - https://nbviewer.jupyter.org/urls/gist.github.com/ChrisBeaumont/5758381/raw/descriptor_writeup.ipynb - By Chris Beaumont -""" - - -def sev_to_color(sev): - # calculate the color depending on the severity - if sev == 5: - return 'firebrick3; fillcolor="#b2222222"; style=filled ' - elif sev <= 4 and sev >= 2: - return 'gold; fillcolor="#ffd80022"; style=filled' - elif sev < 2 and sev >= 0: - return 'darkgreen; fillcolor="#00630022"; style=filled' - - return "black" - - -class UIError(Exception): - def __init__(self, e, context): - self.error = e - self.context = context - +from dataclasses import dataclass, field +from typing import ClassVar + +from pydantic import ValidationError +from pydantic.fields import PydanticUndefined + +from collections import defaultdict +from collections.abc import Iterable, Mapping +from functools import singledispatch + +# Import all the new Pydantic models +from .enums import ( + Action, + Classification, + DatastoreType, + Lifetime, + TLSVersion, + OrderedEnum, +) +from .base 
import Assumption, Controls +from .element import Element +from .data import Data +from .threat import Threat +from .finding import Finding +from .asset import Asset, Lambda, LLM, Server, ExternalEntity +from .datastore import Datastore +from .actor import Actor +from .process import Process, SetOfProcesses +from .dataflow import Dataflow +from .boundary import Boundary +from .tm import TM, UIError logger = logging.getLogger(__name__) - -class var(object): - """A descriptor that allows setting a value only once""" - - def __init__(self, default, required=False, doc="", onSet=None): - self.default = default - self.required = required - self.doc = doc - self.data = WeakKeyDictionary() - self.onSet = onSet - - def __get__(self, instance, owner): - # when x.d is called we get here - # instance = x - # owner = type(x) - if instance is None: - return self - return self.data.get(instance, self.default) - - def __set__(self, instance, value): - # called when x.d = val - # instance = x - # value = val - if instance in self.data: - raise ValueError( - "cannot overwrite {}.{} value with {}, already set to {}".format( - instance, self.__class__.__name__, value, self.data[instance] - ) - ) - self.data[instance] = value - if self.onSet is not None: - self.onSet(instance, value) - - -class varString(var): - def __set__(self, instance, value): - if not isinstance(value, str): - raise ValueError("expecting a String value, got a {}".format(type(value))) - super().__set__(instance, value) - - -class varStrings(var): - def __set__(self, instance, value): - if not isinstance(value, Iterable) or isinstance(value, str): - value = [value] - for i, e in enumerate(value): - if not isinstance(e, str): - raise ValueError( - f"expecting a list of str, item number {i} is a {type(e)}" - ) - super().__set__(instance, set(value)) - - -class varBoundary(var): - def __set__(self, instance, value): - if not isinstance(value, Boundary): - raise ValueError("expecting a Boundary value, got a 
{}".format(type(value))) - super().__set__(instance, value) - - -class varBool(var): - def __set__(self, instance, value): - if not isinstance(value, bool): - raise ValueError("expecting a boolean value, got a {}".format(type(value))) - super().__set__(instance, value) - - -class varInt(var): - def __set__(self, instance, value): - if not isinstance(value, int): - raise ValueError("expecting an integer value, got a {}".format(type(value))) - super().__set__(instance, value) - - -class varInts(var): - def __set__(self, instance, value): - if not isinstance(value, Iterable): - value = [value] - for i, e in enumerate(value): - if not isinstance(e, int): - raise ValueError( - f"expecting a list of int, item number {i} is a {type(e)}" - ) - super().__set__(instance, set(value)) - - -class varElement(var): - def __set__(self, instance, value): - if not isinstance(value, Element): - raise ValueError( - "expecting an Element (or inherited) " - "value, got a {}".format(type(value)) - ) - super().__set__(instance, value) - - -class varElements(var): - def __set__(self, instance, value): - for i, e in enumerate(value): - if not isinstance(e, Element): - raise ValueError( - "expecting a list of Elements, item number {} is a {}".format( - i, type(e) - ) - ) - super().__set__(instance, list(value)) - - -class varFindings(var): - def __set__(self, instance, value): - for i, e in enumerate(value): - if not isinstance(e, Finding): - raise ValueError( - "expecting a list of Findings, item number {} is a {}".format( - i, type(e) - ) - ) - super().__set__(instance, list(value)) - - -class varAction(var): - def __set__(self, instance, value): - if not isinstance(value, Action): - raise ValueError("expecting an Action, got a {}".format(type(value))) - super().__set__(instance, value) - - -class varClassification(var): - def __set__(self, instance, value): - if not isinstance(value, Classification): - raise ValueError("expecting a Classification, got a {}".format(type(value))) - 
super().__set__(instance, value) - - -class varLifetime(var): - def __set__(self, instance, value): - if not isinstance(value, Lifetime): - raise ValueError("expecting a Lifetime, got a {}".format(type(value))) - super().__set__(instance, value) - - -class varDatastoreType(var): - def __set__(self, instance, value): - if not isinstance(value, DatastoreType): - raise ValueError("expecting a DatastoreType, got a {}".format(type(value))) - super().__set__(instance, value) - - -class varTLSVersion(var): - def __set__(self, instance, value): - if not isinstance(value, TLSVersion): - raise ValueError("expecting a TLSVersion, got a {}".format(type(value))) - super().__set__(instance, value) - - -class varData(var): - def __set__(self, instance, value): - if isinstance(value, str): - value = [ - Data( - name="undefined", - description=value, - classification=Classification.UNKNOWN, - ) - ] - sys.stderr.write( - "FIXME: a dataflow is using a string as the Data attribute. This has been deprecated and Data objects should be created instead.\n" - ) - - if not isinstance(value, Iterable): - value = [value] - for i, e in enumerate(value): - if not isinstance(e, Data): - raise ValueError( - "expecting a list of pytm.Data, item number {} is a {}".format( - i, type(e) - ) - ) - super().__set__(instance, DataSet(value)) - - -class DataSet(set): - def __contains__(self, item): - if isinstance(item, str): - return item in [d.name for d in self] - if isinstance(item, Data): - return super().__contains__(item) - return NotImplemented - - def __eq__(self, other): - if isinstance(other, set): - return super().__eq__(other) - if isinstance(other, str): - return other in self - return NotImplemented - - def __ne__(self, other): - if isinstance(other, set): - return super().__ne__(other) - if isinstance(other, str): - return other not in self - return NotImplemented - - def __str__(self): - return ", ".join(sorted(set(d.name for d in self))) - - -class varControls(var): - def __set__(self, 
instance, value): - if not isinstance(value, Controls): - raise ValueError( - f"expecting an Controls value, got a {type(value)}" - ) - super().__set__(instance, value) - - -class varAssumptions(var): - def __set__(self, instance, value): - for i, e in enumerate(value): - if isinstance(e, str): - e = value[i] = Assumption(e) - if not isinstance(e, Assumption): - raise ValueError( - f"expecting a list of Assumptions, item number {i} is a {type(e)}" - ) - super().__set__(instance, list(value)) - - -class varAssumption(var): - def __set__(self, instance, value): - if not isinstance(value, Assumption): - raise ValueError( - f"expecting an Assumption value, got a {type(value)}" - ) - super().__set__(instance, value) - - -class Action(Enum): - """Action taken when validating a threat model.""" - - NO_ACTION = "NO_ACTION" - RESTRICT = "RESTRICT" - IGNORE = "IGNORE" - - -class OrderedEnum(Enum): - def __ge__(self, other): - if self.__class__ is other.__class__: - return self.value >= other.value - return NotImplemented - - def __gt__(self, other): - if self.__class__ is other.__class__: - return self.value > other.value - return NotImplemented - - def __le__(self, other): - if self.__class__ is other.__class__: - return self.value <= other.value - return NotImplemented - - def __lt__(self, other): - if self.__class__ is other.__class__: - return self.value < other.value - return NotImplemented - - -class Classification(OrderedEnum): - UNKNOWN = 0 - PUBLIC = 1 - RESTRICTED = 2 - SENSITIVE = 3 - SECRET = 4 - TOP_SECRET = 5 - - -class Lifetime(Enum): - # not applicable - NONE = "NONE" - # unknown lifetime - UNKNOWN = "UNKNOWN" - # relatively short expiration date (time to live) - SHORT = "SHORT_LIVED" - # long or no expiration date - LONG = "LONG_LIVED" - # no expiration date but revoked/invalidated automatically in some conditions - AUTO = "AUTO_REVOKABLE" - # no expiration date but can be invalidated manually - MANUAL = "MANUALLY_REVOKABLE" - # cannot be invalidated at all 
- HARDCODED = "HARDCODED" - - def label(self): - return self.value.lower().replace("_", " ") - - -class DatastoreType(Enum): - UNKNOWN = "UNKNOWN" - FILE_SYSTEM = "FILE_SYSTEM" - SQL = "SQL" - LDAP = "LDAP" - AWS_S3 = "AWS_S3" - - def label(self): - return self.value.lower().replace("_", " ") - - -class TLSVersion(OrderedEnum): - NONE = 0 - SSLv1 = 1 - SSLv2 = 2 - SSLv3 = 3 - TLSv10 = 4 - TLSv11 = 5 - TLSv12 = 6 - TLSv13 = 7 - - +# Legacy aliases for backward compatibility +varString = str +varStrings = list +varBoundary = object +varBool = bool +varInt = int +varInts = set +varElement = object +varElements = list +varFindings = list +varAction = Action +varClassification = Classification +varLifetime = Lifetime +varDatastoreType = DatastoreType +varTLSVersion = TLSVersion +varData = list +varControls = Controls +varAssumptions = list +varAssumption = Assumption + + +# Essential helper functions preserved from original def _sort(flows, addOrder=False): + """Sort flows by order.""" ordered = sorted(flows, key=lambda flow: flow.order) if not addOrder: return ordered @@ -382,6 +74,7 @@ def _sort(flows, addOrder=False): def _sort_elem(elements): + """Sort elements.""" if len(elements) == 0: return elements orders = {} @@ -404,8 +97,148 @@ def _sort_elem(elements): ) +def _iter_subclasses(cls): + """Yield all subclasses of *cls*, recursively.""" + seen = set() + stack = [cls] + + while stack: + current = stack.pop() + for subclass in getattr(current, "__subclasses__", lambda: [])(): + if subclass in seen: + continue + seen.add(subclass) + yield subclass + stack.append(subclass) + + +def _list_elements(): + """List all elements usable in a threat model along with their descriptions.""" + + def _print_components(classes): + entries = sorted(classes, key=lambda cls: cls.__name__) + if not entries: + return + + name_width = max(len(entry.__name__) for entry in entries) + for entry in entries: + doc = entry.__doc__ or "" + print(f"{entry.__name__:<{name_width}} -- {doc}") + 
+ print("Elements:") + _print_components(list(_iter_subclasses(Element))) + + print("\nAtributes:") + enumerated = set(_iter_subclasses(OrderedEnum)) + enumerated.update({Data, Action, Lifetime}) + _print_components(list(enumerated)) + + +_CLASS_REGISTRY = { + "Action": Action, + "Actor": Actor, + "Asset": Asset, + "Boundary": Boundary, + "Classification": Classification, + "Data": Data, + "Dataflow": Dataflow, + "Datastore": Datastore, + "DatastoreType": DatastoreType, + "ExternalEntity": ExternalEntity, + "Finding": Finding, + "Lambda": Lambda, + "Lifetime": Lifetime, + "LLM": LLM, + "Process": Process, + "Server": Server, + "SetOfProcesses": SetOfProcesses, + "Threat": Threat, + "TLSVersion": TLSVersion, + "TM": TM, + "UIError": UIError, +} + + +def _describe_classes(class_names): + """Describe available classes and their attributes for CLI users.""" + + registry = dict(_CLASS_REGISTRY) + + for cls in _iter_subclasses(Element): + registry.setdefault(cls.__name__, cls) + + for name in class_names: + klass = registry.get(name) + if klass is None: + logger.error("No such class to describe: %s", name) + sys.exit(1) + + print(f"{name} class attributes:") + + model_fields = getattr(klass, "model_fields", None) + if model_fields: + field_names = sorted(model_fields.keys()) + if not field_names: + print(" (no attributes)") + else: + longest = len(max(field_names, key=len)) + 2 + lpadding = f'\n{" ":<{longest+2}}' + for field_name in field_names: + field_info = model_fields[field_name] + docs: list[str] = [] + description = field_info.description or "" + if description: + docs.extend(description.split("\n")) + if field_info.is_required(): + docs.append("required") + default = field_info.default + if default is not PydanticUndefined: + docs.append(f"default: {default!r}") + elif field_info.default_factory is not None: + factory = field_info.default_factory + factory_name = getattr(factory, "__name__", repr(factory)) + docs.append(f"default factory: {factory_name}") + + if 
docs: + print(f" {field_name:<{longest}}{lpadding.join(docs)}") + else: + print(f" {field_name}") + elif hasattr(klass, "__members__"): + members = getattr(klass, "__members__", {}) + if not members: + print(" (no members)") + else: + for member in members: + print(f" {member}") + else: + attrs = [ + attr + for attr in dir(klass) + if not attr.startswith("_") and not callable(getattr(klass, attr)) + ] + if not attrs: + print(" (no attributes)") + else: + longest = len(max(attrs, key=len)) + 2 + lpadding = f'\n{" ":<{longest+2}}' + for attr in sorted(attrs): + value = getattr(klass, attr) + docs = [] + doc_attr = getattr(value, "__doc__", None) + if isinstance(doc_attr, str): + stripped = doc_attr.strip() + if stripped: + docs.append(stripped) + if docs: + print(f" {attr:<{longest}}{lpadding.join(docs)}") + else: + print(f" {attr}") + + print() + + def _match_responses(flows): - """Ensure that responses are pointing to requests""" + """Ensure that responses are pointing to requests.""" index = defaultdict(list) for e in flows: key = (e.source, e.sink) @@ -433,1618 +266,257 @@ def _match_responses(flows): return flows -def _apply_defaults(flows, data): - inputs = defaultdict(list) - outputs = defaultdict(list) - carriers = defaultdict(set) - processors = defaultdict(set) - - for d in data: - for e in d.carriedBy: - try: - setattr(e, "data", d) - except ValueError: - e.data.add(d) - - for e in flows: - if e.source.data: - try: - setattr(e, "data", e.source.data.copy()) - except ValueError: - e.data.update(e.source.data) - - for d in e.data: - carriers[d].add(e) - processors[d].add(e.source) - processors[d].add(e.sink) - - e._safeset("levels", e.source.levels & e.sink.levels) +def _add_data(container, value): + """Attach Data objects to a container supporting add/append semantics.""" + if container is None or value is None: + return - try: - e.overrides = e.sink.overrides - e.overrides.extend( - f - for f in e.source.overrides - if f.threat_id not in (f.threat_id for f 
in e.overrides) - ) - except ValueError: - pass + if isinstance(value, Data): + items = [value] + elif hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): + items = list(value) + else: + items = [value] - if e.isResponse: - e._safeset("protocol", e.source.protocol) - e._safeset("srcPort", e.source.port) - e.controls._safeset("isEncrypted", e.source.controls.isEncrypted) + for item in items: + if item is None: continue + if hasattr(container, "add"): + container.add(item) + elif hasattr(container, "append"): + container.append(item) - e._safeset("protocol", e.sink.protocol) - e._safeset("dstPort", e.sink.port) - if hasattr(e.sink.controls, "isEncrypted"): - e.controls._safeset("isEncrypted", e.sink.controls.isEncrypted) - e.controls._safeset( - "authenticatesDestination", e.source.controls.authenticatesDestination - ) - e.controls._safeset( - "checksDestinationRevocation", e.source.controls.checksDestinationRevocation - ) - - for d in e.data: - if d.isStored: - if hasattr(e.sink.controls, "isEncryptedAtRest"): - for d in e.data: - d._safeset( - "isDestEncryptedAtRest", e.sink.controls.isEncryptedAtRest - ) - if hasattr(e.source, "isEncryptedAtRest"): - for d in e.data: - d._safeset( - "isSourceEncryptedAtRest", - e.source.controls.isEncryptedAtRest, - ) - if d.credentialsLife != Lifetime.NONE and not d.isCredentials: - d._safeset("isCredentials", True) - if d.isCredentials and d.credentialsLife == Lifetime.NONE: - d._safeset("credentialsLife", Lifetime.UNKNOWN) - - outputs[e.source].append(e) - inputs[e.sink].append(e) - - for e, flows in inputs.items(): - try: - e.inputs = flows - except (AttributeError, ValueError): - pass - for e, flows in outputs.items(): - try: - e.outputs = flows - except (AttributeError, ValueError): - pass - - for d, flows in carriers.items(): - flows = sorted(flows, key=lambda f: f.name) - try: - setattr(d, "carriedBy", list(flows)) - except ValueError: - for e in flows: - if e not in d.carriedBy: - d.carriedBy.append(e) - 
for d, elements in processors.items(): - elements = sorted(elements, key=lambda e: e.name) - try: - setattr(d, "processedBy", elements) - except ValueError: - for e in elements: - if e not in d.processedBy: - d.processedBy.append(e) - - -def _describe_classes(classes): - for name in classes: - klass = getattr(sys.modules[__name__], name, None) - if klass is None: - logger.error("No such class to describe: %s\n", name) - sys.exit(1) - print("{} class attributes:".format(name)) - attrs = [] - for i in dir(klass): - if i.startswith("_") or callable(getattr(klass, i)): - continue - attrs.append(i) - longest = len(max(attrs, key=len)) + 2 - for i in attrs: - attr = getattr(klass, i, {}) - docs = [] - if isinstance(attr, var): - if attr.doc: - docs.extend(attr.doc.split("\n")) - if attr.required: - docs.append("required") - if attr.default or isinstance(attr.default, bool): - docs.append("default: {}".format(attr.default)) - lpadding = f'\n{" ":<{longest+2}}' - print(f" {i:<{longest}}{lpadding.join(docs)}") - print() - - -def _list_elements(): - """List all elements which can be used in a threat model with the corresponding description""" - - def all_subclasses(cls): - """Get all sub classes of a class""" - subclasses = set(cls.__subclasses__()) - return subclasses.union((s for c in subclasses for s in all_subclasses(c))) - def print_components(cls_list): - elements = sorted(cls_list, key=lambda c: c.__name__) - max_len = max((len(e.__name__) for e in elements)) - for sc in elements: - doc = sc.__doc__ if sc.__doc__ is not None else "" - print(f"{sc.__name__:<{max_len}} -- {doc}") +@dataclass +class _FlowDefaultsBuilder: + """Collect and apply default relationships for data flows.""" - # print all elements - print("Elements:") - print_components(all_subclasses(Element)) - - # Print Attributes - print("\nAtributes:") - print_components(all_subclasses(OrderedEnum) | {Data, Action, Lifetime}) - - -def _get_elements_and_boundaries(flows): - """filter out elements and 
boundaries not used in this TM""" - elements = set() - boundaries = set() - for e in flows: - elements.add(e) - elements.add(e.source) - elements.add(e.sink) - if e.source.inBoundary is not None: - elements.add(e.source.inBoundary) - boundaries.add(e.source.inBoundary) - for b in e.source.inBoundary.parents(): - elements.add(b) - boundaries.add(b) - if e.sink.inBoundary is not None: - elements.add(e.sink.inBoundary) - boundaries.add(e.sink.inBoundary) - for b in e.sink.inBoundary.parents(): - elements.add(b) - boundaries.add(b) - return (list(elements), list(boundaries)) - - -""" End of help functions """ - - -class Threat: - """Represents a possible threat""" - - id = varString("", required=True) - description = varString("") - condition = varString( - "", - doc="""a Python expression that should evaluate -to a boolean True or False""", - ) - details = varString("") - likelihood = varString("") - severity = varString("") - mitigations = varString("") - prerequisites = varString("") - example = varString("") - references = varString("") - target = () - - def __init__(self, **kwargs): - self.id = kwargs["SID"] - self.description = kwargs.get("description", "") - self.likelihood = kwargs.get("Likelihood Of Attack", "") - self.condition = kwargs.get("condition", "True") - target = kwargs.get("target", "Element") - if not isinstance(target, str) and isinstance(target, Iterable): - target = tuple(target) - else: - target = (target,) - self.target = tuple(getattr(sys.modules[__name__], x) for x in target) - self.details = kwargs.get("details", "") - self.severity = kwargs.get("severity", "") - self.mitigations = kwargs.get("mitigations", "") - self.prerequisites = kwargs.get("prerequisites", "") - self.example = kwargs.get("example", "") - self.references = kwargs.get("references", "") - - def _safeset(self, attr, value): - try: - setattr(self, attr, value) - except ValueError: - pass - - def __repr__(self): - return "<{0}.{1}({2}) at {3}>".format( - self.__module__, 
type(self).__name__, self.id, hex(id(self)) - ) - - def __str__(self): - return "{0}({1})".format(type(self).__name__, self.id) - - def apply(self, target): - if not isinstance(target, self.target): - return None - return eval(self.condition) - - -class Finding: - """Represents a Finding - the element in question - and a description of the finding""" - - element = varElement(None, required=True, doc="Element this finding applies to") - target = varString("", doc="Name of the element this finding applies to") - description = varString("", required=True, doc="Threat description") - details = varString("", required=True, doc="Threat details") - severity = varString("", required=True, doc="Threat severity") - mitigations = varString("", required=True, doc="Threat mitigations") - example = varString("", required=True, doc="Threat example") - id = varString("", required=True, doc="Finding ID") - threat_id = varString("", required=True, doc="Threat ID") - references = varString("", required=True, doc="Threat references") - condition = varString("", required=True, doc="Threat condition") - assumption = varAssumption(None, required=False, doc="The assumption, that caused this finding to be excluded") - response = varString( - "", - required=False, - doc="""Describes how this threat matching this particular asset or dataflow is being handled. 
-Can be one of: -* mitigated - there were changes made in the modeled system to reduce the probability of this threat occurring or the impact when it does, -* transferred - users of the system are required to mitigate this threat, -* avoided - this asset or dataflow is removed from the system, -* accepted - no action is taken as the probability and/or impact is very low -""", - ) - cvss = varString("", required=False, doc="The CVSS score and/or vector") - likelihood = varString("", required=False, doc="Likelihood of the threat") - - def __init__( - self, - *args, - **kwargs, - ): - if args: - element = args[0] - else: - element = kwargs.pop("element", Element("invalid")) - - self.target = element.name - self.element = element - attrs = [ - "description", - "details", - "severity", - "mitigations", - "example", - "references", - "condition", - "likelihood", - ] - threat = kwargs.pop("threat", None) - if threat: - kwargs["threat_id"] = getattr(threat, "id") - for a in attrs: - # copy threat attrs into kwargs to allow to override them in next step - kwargs[a] = getattr(threat, a) - - threat_id = kwargs.get("threat_id", None) - for f in element.overrides: - if f.threat_id != threat_id: - continue - for i in dir(f.__class__): - attr = getattr(f.__class__, i) - if ( - i in ("element", "target") - or i.startswith("_") - or callable(attr) - or not isinstance(attr, var) - ): - continue - if f in attr.data: - kwargs[i] = attr.data[f] - break - - for k, v in kwargs.items(): - setattr(self, k, v) - - def _safeset(self, attr, value): - try: - setattr(self, attr, value) - except ValueError: - pass - - def __repr__(self): - return "<{0}.{1}({2}) at {3}>".format( - self.__module__, type(self).__name__, self.id, hex(id(self)) - ) - - def __str__(self): - return f"'{self.target}': {self.description}\n{self.details}\n{self.severity}" - - -class TM: - """Describes the threat model administratively, - and holds all details during a run""" - - _flows = [] - _elements = [] - _actors = [] 
- _assets = [] - _threats = [] - _boundaries = [] - _data = [] - _threatsExcluded = [] - _sf = None - _duplicate_ignored_attrs = ( - "name", - "note", - "order", - "response", - "responseTo", - "controls", - ) - name = varString("", required=True, doc="Model name") - description = varString("", required=True, doc="Model description") - threatsFile = varString( - os.path.dirname(__file__) + "/threatlib/threats.json", - onSet=lambda i, v: i._init_threats(), - doc="JSON file with custom threats", + inputs: defaultdict[Element, list[Dataflow]] = field( + default_factory=lambda: defaultdict(list) ) - isOrdered = varBool(False, doc="Automatically order all Dataflows") - mergeResponses = varBool(False, doc="Merge response edges in DFDs") - ignoreUnused = varBool(False, doc="Ignore elements not used in any Dataflow") - findings = varFindings([], doc="Threats found for elements of this model") - excluded_findings = varFindings( - [], - doc="Threats found for elements of this model, " - "that were excluded on a per-element basis, using the Assumptions class" + outputs: defaultdict[Element, list[Dataflow]] = field( + default_factory=lambda: defaultdict(list) ) - onDuplicates = varAction( - Action.NO_ACTION, - doc="""How to handle duplicate Dataflow -with same properties, except name and notes""", + carriers: defaultdict[Data, set[Dataflow]] = field( + default_factory=lambda: defaultdict(set) ) - assumptions = varAssumptions( - [], - required=False, - doc="A list of assumptions about the design/model.", + processors: defaultdict[Data, set[Element]] = field( + default_factory=lambda: defaultdict(set) ) - _colormap = False - - def __init__(self, name, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - self.name = name - self._sf = SuperFormatter() - self._add_threats() - # make sure generated diagrams do not change, makes sense if they're commited - random.seed(0) - - @classmethod - def reset(cls): - cls._flows = [] - cls._elements = [] - cls._actors = 
[] - cls._assets = [] - cls._threats = [] - cls._boundaries = [] - cls._data = [] - cls._threatsExcluded = [] - - def _init_threats(self): - TM._threats = [] - self._add_threats() - - def _add_threats(self): - try: - with open(self.threatsFile, "r", encoding="utf8") as threat_file: - threats_json = json.load(threat_file) - except (FileNotFoundError, PermissionError, IsADirectoryError) as e: - raise UIError( - e, f"while trying to open the the threat file ({self.threatsFile})." - ) - active_threats = (threat for threat in threats_json if "DEPRECATED" not in threat) - for threat in active_threats: - TM._threats.append(Threat(**threat)) - - def resolve(self): - finding_count = 0 - excluded_finding_count = 0 - findings = [] - excluded_findings = [] - # We just need the assumptions with SIDs to exclude - global_assumptions = [a for a in self.assumptions if len(a.exclude) > 0] - elements = defaultdict(list) - for e in TM._elements: - if not e.inScope: - e.findings = findings - continue - - override_ids = set(f.threat_id for f in e.overrides) - # if element is a dataflow filter out overrides from source and sink - # because they will be always applied there anyway - try: - override_ids -= set( - f.threat_id for f in e.source.overrides + e.sink.overrides - ) - except AttributeError: - pass - - for t in TM._threats: - if not t.apply(e) and t.id not in override_ids: - continue - - if t.id in TM._threatsExcluded: - continue - - _continue = False - for assumption in e.assumptions + global_assumptions: # type: Assumption - if t.id in assumption.exclude: - excluded_finding_count += 1 - f = Finding(e, id=str(excluded_finding_count), threat=t, assumption=assumption) - excluded_findings.append(f) - _continue = True - break - if _continue: - continue - - finding_count += 1 - f = Finding(e, id=str(finding_count), threat=t) - findings.append(f) - elements[e].append(f) - e._set_severity(f.severity) - self.findings = findings - self.excluded_findings = excluded_findings - for e, 
findings in elements.items(): - e.findings = findings - - def check(self): - if self.description is None: - raise ValueError( - """Every threat model should have at least -a brief description of the system being modeled.""" - ) - TM._flows = _match_responses(_sort(TM._flows, self.isOrdered)) - - self._check_duplicates(TM._flows) - - _apply_defaults(TM._flows, TM._data) - - for e in TM._elements: - top = Counter(f.threat_id for f in e.overrides).most_common(1) - if not top: - continue - threat_id, count = top[0] - if count != 1: - raise ValueError( - f"Finding {threat_id} have more than one override in {e}" - ) - - if self.ignoreUnused: - TM._elements, TM._boundaries = _get_elements_and_boundaries(TM._flows) - - result = True - for e in TM._elements: - if not e.check(): - result = False + assignment_errors: ClassVar[tuple[type[Exception], ...]] = ( + ValueError, + AttributeError, + TypeError, + ValidationError, + ) - if self.ignoreUnused: - # cannot rely on user defined order if assets are re-used in multiple models - TM._elements = _sort_elem(TM._elements) + def seed_data_relationships(self, data_items: Iterable[Data]) -> None: + """Ensure data instances are referenced by existing carriers.""" + for datum in data_items: + for flow in getattr(datum, "carriedBy", []): + _add_data(getattr(flow, "data", None), datum) - return result + def process_flow(self, flow: Dataflow) -> None: + """Apply defaults and collect relationships for a single flow.""" + self._inherit_source_data(flow) + self._index_flow_relationships(flow) + self._sync_levels(flow) + self._merge_overrides(flow) - def _check_duplicates(self, flows): - if self.onDuplicates == Action.NO_ACTION: + if getattr(flow, "isResponse", False): + self._apply_response_defaults(flow) return - index = defaultdict(list) - for e in flows: - key = (e.source, e.sink) - index[key].append(e) - - for flows in index.values(): - for left, right in combinations(flows, 2): - left_attrs = left._attr_values() - right_attrs = 
right._attr_values() - for a in self._duplicate_ignored_attrs: - del left_attrs[a], right_attrs[a] - if left_attrs != right_attrs: - continue - if self.onDuplicates == Action.IGNORE: - right._is_drawn = True - continue - - left_controls_attrs = left.controls._attr_values() - right_controls_attrs = right.controls._attr_values() - # for a in self._duplicate_ignored_attrs: - # del left_controls_attrs[a], right_controls_attrs[a] - if left_controls_attrs != right_controls_attrs: - continue - if self.onDuplicates == Action.IGNORE: - right._is_drawn = True - continue - - raise ValueError( - "Duplicate Dataflow found between {} and {}: " - "{} is same as {}".format( - left.source, - left.sink, - left, - right, - ) - ) - - def _dfd_template(self): - return """digraph tm {{ - graph [ - fontname = Arial; - fontsize = 14; - ] - node [ - fontname = Arial; - fontsize = 14; - rankdir = lr; - ] - edge [ - shape = none; - arrowtail = onormal; - fontname = Arial; - fontsize = 12; - ] - labelloc = "t"; - fontsize = 20; - nodesep = 1; - -{edges} -}}""" - - def dfd(self, **kwargs): - if "levels" in kwargs: - levels = kwargs["levels"] - if not isinstance(kwargs["levels"], Iterable): - kwargs["levels"] = [levels] - kwargs["levels"] = set(levels) - - edges = [] - # since boundaries can be nested sort them by level and start from top - parents = set(b.inBoundary for b in TM._boundaries if b.inBoundary) - - # TODO boundaries should not be drawn if they don't contain elements matching requested levels - # or contain only empty boundaries - boundary_levels = defaultdict(set) - max_level = 0 - for b in TM._boundaries: - if b in parents: - continue - boundary_levels[0].add(b) - for i, p in enumerate(b.parents()): - i = i + 1 - boundary_levels[i].add(p) - if i > max_level: - max_level = i - - for i in range(max_level, -1, -1): - for b in sorted(boundary_levels[i], key=lambda b: b.name): - edges.append(b.dfd(**kwargs)) - - if self.mergeResponses: - for e in TM._flows: - if e.response is not None: 
- e.response._is_drawn = True - kwargs["mergeResponses"] = self.mergeResponses - for e in TM._elements: - if not e._is_drawn and not isinstance(e, Boundary) and e.inBoundary is None: - edges.append(e.dfd(**kwargs)) - - return self._dfd_template().format( - edges=indent("\n".join(filter(len, edges)), " ") - ) - - def _seq_template(self): - return """@startuml -{participants} - -{messages} -@enduml""" - - def seq(self): - participants = [] - for e in TM._elements: - if isinstance(e, Actor): - participants.append( - 'actor {0} as "{1}"'.format(e._uniq_name(), e.display_name()) - ) - elif isinstance(e, Datastore): - participants.append( - 'database {0} as "{1}"'.format(e._uniq_name(), e.display_name()) - ) - elif not isinstance(e, Dataflow) and not isinstance(e, Boundary): - participants.append( - 'entity {0} as "{1}"'.format(e._uniq_name(), e.display_name()) - ) - - messages = [] - for e in TM._flows: - message = "{0} -> {1}: {2}".format( - e.source._uniq_name(), e.sink._uniq_name(), e.display_name() - ) - note = "" - if e.note != "": - note = "\nnote left\n{}\nend note".format(e.note) - messages.append("{}{}".format(message, note)) + self._apply_forward_defaults(flow) + self._enrich_data_attributes(flow) + self.inputs[flow.sink].append(flow) + self.outputs[flow.source].append(flow) - return self._seq_template().format( - participants="\n".join(participants), messages="\n".join(messages) - ) + def finalize_assets(self) -> None: + """Populate inputs/outputs on elements once all flows are processed.""" + for asset, flow_list in self.inputs.items(): + self._set_sequence(asset, "inputs", flow_list) - def report(self, template_path): - try: - with open(template_path) as file: - template = file.read() - except (FileNotFoundError, PermissionError, IsADirectoryError) as e: - raise UIError( - e, f"while trying to open the report template file ({template_path})." 
- ) + for asset, flow_list in self.outputs.items(): + self._set_sequence(asset, "outputs", flow_list) - threats = encode_threat_data(TM._threats) - findings = encode_threat_data(self.findings) - - elements = encode_element_threat_data(TM._elements) - assets = encode_element_threat_data(TM._assets) - actors = encode_element_threat_data(TM._actors) - boundaries = encode_element_threat_data(TM._boundaries) - flows = encode_element_threat_data(TM._flows) - - data = { - "tm": self, - "dataflows": flows, - "threats": threats, - "findings": findings, - "elements": elements, - "assets": assets, - "actors": actors, - "boundaries": boundaries, - "data": TM._data, - } - - return self._sf.format(template, **data) - - def process(self): - try: - self._process() - except UIError as e: - erromsg = f"""Failed to excecute - {e.context} - {e.error} -""" - sys.stderr.write(erromsg) - sys.exit(127) - - def _process(self): - self.check() - result = get_args() - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - if result.debug: - logger.setLevel(logging.DEBUG) - - if result.exclude is not None: - TM._threatsExcluded = result.exclude.split(",") - - if result.seq is True: - print(self.seq()) - - if result.dfd is True: - if result.colormap is True: - self.resolve() - print(self.dfd(colormap=result.colormap, levels=(result.levels or set()))) - - if ( - result.report is not None - or result.json is not None - or result.sqldump is not None - or result.stale_days is not None - ): - self.resolve() - - if result.sqldump is not None: - self.sqlDump(result.sqldump) - - if result.json: + def finalize_data_relationships(self) -> None: + """Attach carrier and processor metadata to data objects.""" + for datum, flow_list in self.carriers.items(): + ordered = sorted(flow_list, key=lambda f: f.name) try: - with open(result.json, "w", encoding="utf8") as f: - json.dump(self, f, default=to_serializable) - except (FileExistsError, PermissionError, IsADirectoryError) as e: - 
raise UIError( - e, f"while trying to write to the result file ({result.json})" - ) - - if result.report is not None: - print(self.report(result.report)) - - if result.describe is not None: - _describe_classes(result.describe.split()) - - if result.list_elements: - _list_elements() - - if result.list is True: - [print("{} - {}".format(t.id, t.description)) for t in TM._threats] - - if result.stale_days is not None: - print(self._stale(result.stale_days)) - - def _stale(self, days): - try: - base_path = os.path.dirname(sys.argv[0]) - tm_mtime = datetime.fromtimestamp( - os.stat(base_path + f"/{sys.argv[0]}").st_mtime - ) - except os.error as err: - sys.stderr.write(f"{sys.argv[0]} - {err}\n") - sys.stderr.flush() - return "[ERROR]" - - print(f"Checking for code {days} days older than this model.") - - for e in TM._elements: - for src in e.sourceFiles: - try: - src_mtime = datetime.fromtimestamp( - os.stat(base_path + f"/{src}").st_mtime - ) - except os.error as err: - sys.stderr.write(f"{sys.argv[0]} - {err}\n") - sys.stderr.flush() - continue - - age = (src_mtime - tm_mtime).days - - # source code is older than model by more than the speficied delta - if (age) >= days: - print(f"This model is {age} days older than {base_path}/{src}.") - elif age <= -days: - print( - f"Model script {sys.argv[0]}" - + " is only " - + str(-1 * age) - + " days newer than source code file " - + f"{base_path}/{src}" - ) - - return "" - - def sqlDump(self, filename): - try: - from pydal import DAL, Field - except ImportError as e: - raise UIError( - e, """This feature requires the pyDAL package, - Please install the package via pip or your packagemanger of choice. 
- """ - ) - - @lru_cache(maxsize=None) - def get_table(db, klass): - name = klass.__name__ - fields = [ - Field("SID" if i == "id" else i) - for i in dir(klass) - if not i.startswith("_") and not callable(getattr(klass, i)) - ] - return db.define_table(name, fields) - + setattr(datum, "carriedBy", list(ordered)) + except self.assignment_errors: + for flow in ordered: + existing = getattr(datum, "carriedBy", []) + if flow not in existing: + existing.append(flow) + + for datum, elements in self.processors.items(): + ordered = sorted(elements, key=lambda el: el.name) + try: + setattr(datum, "processedBy", list(ordered)) + except self.assignment_errors: + for element in ordered: + existing = getattr(datum, "processedBy", []) + if element not in existing: + existing.append(element) + + def _inherit_source_data(self, flow: Dataflow) -> None: + source_data = getattr(flow.source, "data", None) + if source_data: + _add_data(getattr(flow, "data", None), source_data) + + def _index_flow_relationships(self, flow: Dataflow) -> None: + for datum in list(getattr(flow, "data", [])): + self.carriers[datum].add(flow) + self.processors[datum].add(flow.source) + self.processors[datum].add(flow.sink) + + @staticmethod + def _sync_levels(flow: Dataflow) -> None: try: - rmtree("./sqldump") - os.mkdir("./sqldump") - except OSError as e: - if e.errno != errno.ENOENT: - raise - else: - os.mkdir("./sqldump") - - db = DAL("sqlite://" + filename, folder="sqldump") - - for klass in ( - Server, - ExternalEntity, - Dataflow, - Datastore, - Actor, - Process, - SetOfProcesses, - Boundary, - TM, - Threat, - Lambda, - Data, - Finding, - ): - get_table(db, klass) - - for e in TM._threats + TM._data + TM._elements + self.findings + [self]: - table = get_table(db, e.__class__) - row = {} - for k, v in serialize(e).items(): - if k == "id": - k = "SID" - row[k] = ", ".join(str(i) for i in v) if isinstance(v, list) else v - db[table].bulk_insert([row]) - - db.close() - - - -class Controls: - """Controls 
implemented by/on and Element""" - - authenticatesDestination = varBool( - False, - doc="""Verifies the identity of the destination, -for example by verifying the authenticity of a digital certificate.""", - ) - authenticatesSource = varBool(False) - authenticationScheme = varString("") - authorizesSource = varBool(False) - checksDestinationRevocation = varBool( - False, - doc="""Correctly checks the revocation status -of credentials used to authenticate the destination""", - ) - checksInputBounds = varBool(False) - definesConnectionTimeout = varBool(False) - disablesDTD = varBool(False) - disablesiFrames = varBool(False) - encodesHeaders = varBool(False) - encodesOutput = varBool(False) - encryptsCookies = varBool(False) - encryptsSessionData = varBool(False) - handlesCrashes = varBool(False) - handlesInterruptions = varBool(False) - handlesResourceConsumption = varBool(False) - hasAccessControl = varBool(False) - implementsAuthenticationScheme = varBool(False) - implementsCSRFToken = varBool(False) - implementsNonce = varBool( - False, - doc="""Nonce is an arbitrary number -that can be used just once in a cryptographic communication. -It is often a random or pseudo-random number issued in an authentication protocol -to ensure that old communications cannot be reused in replay attacks. 
-They can also be useful as initialization vectors and in cryptographic -hash functions.""", - ) - implementsPOLP = varBool( - False, - doc="""The principle of least privilege (PoLP), -also known as the principle of minimal privilege or the principle of least authority, -requires that in a particular abstraction layer of a computing environment, -every module (such as a process, a user, or a program, depending on the subject) -must be able to access only the information and resources -that are necessary for its legitimate purpose.""", - ) - implementsServerSideValidation = varBool(False) - implementsStrictHTTPValidation = varBool(False) - invokesScriptFilters = varBool(False) - isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted") - isEncryptedAtRest = varBool(False, doc="Stored data is encrypted at rest") - isHardened = varBool(False) - isResilient = varBool(False) - providesConfidentiality = varBool(False) - providesIntegrity = varBool(False) - sanitizesInput = varBool(False) - tracksExecutionFlow = varBool(False) - usesCodeSigning = varBool(False) - usesEncryptionAlgorithm = varString("") - usesMFA = varBool( - False, - doc="""Multi-factor authentication is an authentication method -in which a computer user is granted access only after successfully presenting two -or more pieces of evidence (or factors) to an authentication mechanism: knowledge -(something the user and only the user knows), possession (something the user -and only the user has), and inherence (something the user and only the user is).""", - ) - usesParameterizedInput = varBool(False) - usesSecureFunctions = varBool(False) - usesStrongSessionIdentifiers = varBool(False) - usesVPN = varBool(False) - validatesContentType = varBool(False) - validatesHeaders = varBool(False) - validatesInput = varBool(False) - verifySessionIdentifiers = varBool(False) - - def _attr_values(self): - klass = self.__class__ - result = {} - for i in dir(klass): - if i.startswith("_") or 
callable(getattr(klass, i)): - continue - attr = getattr(klass, i, {}) - if isinstance(attr, var): - value = attr.data.get(self, attr.default) - else: - value = getattr(self, i) - result[i] = value - return result + level_intersection = flow.source.levels & flow.sink.levels + except TypeError: + level_intersection = set() + if level_intersection: + flow._safeset("levels", level_intersection) - def _safeset(self, attr, value): + def _merge_overrides(self, flow: Dataflow) -> None: try: - setattr(self, attr, value) - except ValueError: + sink_overrides = list(getattr(flow.sink, "overrides", [])) + source_overrides = list(getattr(flow.source, "overrides", [])) + combined = list(sink_overrides) + existing_ids = {getattr(finding, "threat_id", None) for finding in combined} + for finding in source_overrides: + sid = getattr(finding, "threat_id", None) + if sid not in existing_ids: + combined.append(finding) + existing_ids.add(sid) + flow.overrides = combined + except self.assignment_errors: pass - -class Assumption: - """ - Assumption used by an Element. - Used to exclude threats on a per-element basis. - """ - name = varString("", required=True) - exclude = varStrings([], doc="A list of threat SIDs to exclude for this assumption. 
For example: INP01") - description = varString("", doc="An additional description of the assumption") - - def __init__(self, name, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - self.name = name - - def __str__(self): - return self.name - - -class Element: - """A generic element""" - - name = varString("", required=True) - description = varString("") - inBoundary = varBoundary(None, doc="Trust boundary this element exists in") - inScope = varBool(True, doc="Is the element in scope of the threat model") - maxClassification = varClassification( - Classification.UNKNOWN, - required=False, - doc="Maximum data classification this element can handle.", - ) - minTLSVersion = varTLSVersion( - TLSVersion.NONE, - required=False, - doc="""Minimum TLS version required.""", - ) - findings = varFindings([], doc="Threats that apply to this element") - overrides = varFindings( - [], - doc="""Overrides to findings, allowing to set -a custom response, CVSS score or override other attributes.""", - ) - assumptions = varAssumptions( - [], - doc="Assumptions about the element. These optionally allow to exclude threats with the given SIDs.", - ) - levels = varInts({0}, doc="List of levels (0, 1, 2, ...) 
to be drawn in the model.") - sourceFiles = varStrings( - [], - required=False, - doc="Location of the source code that describes this element relative to the directory of the model script.", - ) - controls = varControls(None) - severity = 0 - - def __init__(self, name, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - self.name = name - self.controls = Controls() - self.uuid = uuid.UUID(int=random.getrandbits(128)) - self._is_drawn = False - TM._elements.append(self) - - def __repr__(self): - return "<{0}.{1}({2}) at {3}>".format( - self.__module__, type(self).__name__, self.name, hex(id(self)) + def _apply_response_defaults(self, flow: Dataflow) -> None: + flow._safeset( + "protocol", getattr(flow.source, "protocol", getattr(flow, "protocol", "")) ) - - def __str__(self): - return "{0}({1})".format(type(self).__name__, self.name) - - def _uniq_name(self): - """transform name and uuid into a unique string""" - h = sha224(str(self.uuid).encode("utf-8")).hexdigest() - name = "".join(x for x in self.name if x.isalpha()) - return "{0}_{1}_{2}".format(type(self).__name__.lower(), name, h[:10]) - - def check(self): - return True - - def _dfd_template(self): - return """{uniq_name} [ - shape = {shape}; - color = {color}; - fontcolor = black; - label = "{label}"; - margin = 0.02; -] -""" - - def dfd(self, **kwargs): - self._is_drawn = True - - levels = kwargs.get("levels", None) - if levels and not levels & self.levels: - return "" - - color = self._color() - if kwargs.get("colormap", False): - color = sev_to_color(self.severity) - - return self._dfd_template().format( - uniq_name=self._uniq_name(), - label=self._label(), - color=color, - shape=self._shape(), + flow._safeset( + "srcPort", getattr(flow.source, "port", getattr(flow, "srcPort", -1)) ) + if hasattr(flow.source, "controls"): + flow.controls._safeset( + "isEncrypted", getattr(flow.source.controls, "isEncrypted", False) + ) - def _color(self): - if self.inScope is True: - return 
"black" - else: - return "grey69" - - def display_name(self): - return self.name - - def _label(self): - return "\\n".join(wrap(self.display_name(), 18)) - - def _shape(self): - return "square" - - def _safeset(self, attr, value): - try: - setattr(self, attr, value) - except ValueError: - pass + def _apply_forward_defaults(self, flow: Dataflow) -> None: + flow._safeset( + "protocol", getattr(flow.sink, "protocol", getattr(flow, "protocol", "")) + ) + flow._safeset( + "dstPort", getattr(flow.sink, "port", getattr(flow, "dstPort", -1)) + ) + if hasattr(flow.sink, "controls"): + flow.controls._safeset( + "isEncrypted", getattr(flow.sink.controls, "isEncrypted", False) + ) + if hasattr(flow.source, "controls"): + flow.controls._safeset( + "authenticatesDestination", + getattr(flow.source.controls, "authenticatesDestination", False), + ) + flow.controls._safeset( + "checksDestinationRevocation", + getattr(flow.source.controls, "checksDestinationRevocation", False), + ) - def oneOf(self, *elements): - """Is self one of a list of Elements""" - for element in elements: - if inspect.isclass(element): - if isinstance(self, element): - return True - elif self is element: - return True - return False - - def crosses(self, *boundaries): - """Does self (dataflow) cross any of the list of boundaries""" - if self.source.inBoundary is self.sink.inBoundary: - return False - for boundary in boundaries: - if inspect.isclass(boundary): - if ( - ( - isinstance(self.source.inBoundary, boundary) - and not isinstance(self.sink.inBoundary, boundary) - ) - or ( - not isinstance(self.source.inBoundary, boundary) - and isinstance(self.sink.inBoundary, boundary) + def _enrich_data_attributes(self, flow: Dataflow) -> None: + for datum in list(getattr(flow, "data", [])): + if getattr(datum, "isStored", False): + if hasattr(flow.sink, "controls") and hasattr( + flow.sink.controls, "isEncryptedAtRest" + ): + datum._safeset( + "isDestEncryptedAtRest", flow.sink.controls.isEncryptedAtRest ) - or 
self.source.inBoundary is not self.sink.inBoundary + if hasattr(flow.source, "controls") and hasattr( + flow.source.controls, "isEncryptedAtRest" ): - return True - elif (self.source.inside(boundary) and not self.sink.inside(boundary)) or ( - not self.source.inside(boundary) and self.sink.inside(boundary) - ): - return True - return False - - def enters(self, *boundaries): - """does self (dataflow) enter into one of the list of boundaries""" - return self.source.inBoundary is None and self.sink.inside(*boundaries) - - def exits(self, *boundaries): - """does self (dataflow) exit one of the list of boundaries""" - return self.source.inside(*boundaries) and self.sink.inBoundary is None - - def inside(self, *boundaries): - """is self inside of one of the list of boundaries""" - for boundary in boundaries: - if inspect.isclass(boundary): - if isinstance(self.inBoundary, boundary): - return True - elif self.inBoundary is boundary: - return True - return False - - def _attr_values(self): - klass = self.__class__ - result = {} - for i in dir(klass): - if i.startswith("_") or callable(getattr(klass, i)): - continue - attr = getattr(klass, i, {}) - if isinstance(attr, var): - value = attr.data.get(self, attr.default) - else: - value = getattr(self, i) - result[i] = value - return result - - def checkTLSVersion(self, flows): - return any(f.tlsVersion < self.minTLSVersion for f in flows) - - def _set_severity(self, sev): - sevs = { - "very high": 5, - "high": 4, - "medium": 3, - "low": 2, - "very low": 1, - "info": 0, - } - - if sev.lower() not in sevs.keys(): - return - - if self.severity < sevs[sev.lower()]: - self.severity = sevs[sev.lower()] - return - - -class Data: - """Represents a single piece of data that traverses the system""" - - name = varString("", required=True) - description = varString("") - format = varString("") - classification = varClassification( - Classification.UNKNOWN, - required=True, - doc="Level of classification for this piece of data", - ) - isPII 
= varBool( - False, - doc="""Does the data contain personally identifyable information. -Should always be encrypted both in transmission and at rest.""", - ) - isCredentials = varBool( - False, - doc="""Does the data contain authentication information, -like passwords or cryptographic keys, with or without expiration date. -Should always be encrypted in transmission. If stored, they should be hashed -using a cryptographic hash function.""", - ) - credentialsLife = varLifetime( - Lifetime.NONE, - doc="""Credentials lifetime, describing if and how -credentials can be revoked. One of: -* NONE - not applicable -* UNKNOWN - unknown lifetime -* SHORT - relatively short expiration date, with an allowed maximum -* LONG - long or no expiration date -* AUTO - no expiration date but can be revoked/invalidated automatically - in some conditions -* MANUAL - no expiration date but can be revoked/invalidated manually -* HARDCODED - cannot be invalidated at all""", - ) - isStored = varBool( - False, - doc="""Is the data going to be stored by the target or only processed. 
-If only derivative data is stored (a hash) it can be set to False.""", - ) - isDestEncryptedAtRest = varBool(False, doc="Is data encrypted at rest at dest") - isSourceEncryptedAtRest = varBool(False, doc="Is data encrypted at rest at source") - carriedBy = varElements([], doc="Dataflows that carries this piece of data") - processedBy = varElements([], doc="Elements that store/process this piece of data") - - def __init__(self, name, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - self.name = name - TM._data.append(self) - - def __repr__(self): - return "<{0}.{1}({2}) at {3}>".format( - self.__module__, type(self).__name__, self.name, hex(id(self)) - ) + datum._safeset( + "isSourceEncryptedAtRest", + flow.source.controls.isEncryptedAtRest, + ) - def __str__(self): - return "{0}({1})".format(type(self).__name__, self.name) + if getattr( + datum, "credentialsLife", Lifetime.NONE + ) != Lifetime.NONE and not getattr(datum, "isCredentials", False): + datum._safeset("isCredentials", True) + if ( + getattr(datum, "isCredentials", False) + and getattr(datum, "credentialsLife", Lifetime.NONE) == Lifetime.NONE + ): + datum._safeset("credentialsLife", Lifetime.UNKNOWN) - def _safeset(self, attr, value): + def _set_sequence( + self, obj: Element, attr: str, values: Iterable[Dataflow] + ) -> None: + if not hasattr(obj, attr): + return + ordered = list(values) try: - setattr(self, attr, value) - except ValueError: - pass - - -class Asset(Element): - """An asset with outgoing or incoming dataflows""" - - port = varInt(-1, doc="Default TCP port for incoming data flows") - protocol = varString("", doc="Default network protocol for incoming data flows") - data = varData([], doc="pytm.Data object(s) in incoming data flows") - inputs = varElements([], doc="incoming Dataflows") - outputs = varElements([], doc="outgoing Dataflows") - onAWS = varBool(False) - handlesResources = varBool(False) - usesEnvironmentVariables = varBool(False) - OS = varString("") - 
- def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - TM._assets.append(self) - - -class Lambda(Asset): - """A lambda function running in a Function-as-a-Service (FaaS) environment""" - - onAWS = varBool(True) - environment = varString("") - implementsAPI = varBool(False) - - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - - def _dfd_template(self): - return """{uniq_name} [ - shape = {shape}; - - color = {color}; - fontcolor = "black"; - label = < - - -
{label}
- >; -] -""" - - def dfd(self, **kwargs): - self._is_drawn = True - - levels = kwargs.get("levels", None) - if levels and not levels & self.levels: - return "" - - color = self._color() - - if kwargs.get("colormap", False): - color = sev_to_color(self.severity) - - return self._dfd_template().format( - uniq_name=self._uniq_name(), - label=self._label(), - color=color, - shape=self._shape(), - ) - - def _shape(self): - return "rectangle; style=rounded" - - -class LLM(Asset): - """A Large Language Model element, either third-party or self-hosted""" - - isThirdParty = varBool(True) - isSelfHosted = varBool(False) - processesPersonalData = varBool(False) - retainsUserData = varBool(False) - hasAgentCapabilities = varBool(False) - hasAccessToSensitiveSystems = varBool(False) - executesCode = varBool(False) - hasContentFiltering = varBool(False) - hasSystemPrompt = varBool(True) - processesUntrustedInput = varBool(True) - hasRAG = varBool(False) - hasFineTuning = varBool(False) - - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - - def _shape(self): - return "hexagon" - - -class Server(Asset): - """An entity processing data""" - - usesSessionTokens = varBool(False) - usesCache = varBool(False) - usesVPN = varBool(False) - usesXMLParser = varBool(False) - - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - - def _shape(self): - return "circle" - - -class ExternalEntity(Asset): - hasPhysicalAccess = varBool(False) - - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - - -class Datastore(Asset): - """An entity storing data""" - - onRDS = varBool(False) - storesLogData = varBool(False) - storesPII = varBool( - False, - doc="""Personally Identifiable Information -is any information relating to an identifiable person.""", - ) - storesSensitiveData = varBool(False) - isSQL = varBool(True) - isShared = varBool(False) - hasWriteAccess = varBool(False) - type = varDatastoreType( - DatastoreType.UNKNOWN, - 
doc="""The type of Datastore, values may be one of: -* UNKNOWN - unknown applicable -* FILE_SYSTEM - files on a file system -* SQL - A SQL Database -* LDAP - An LDAP Server -* AWS_S3 - An S3 Bucket within AWS""", - ) - - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - - def _dfd_template(self): - return """{uniq_name} [ - shape = {shape}; - fixedsize = shape; - image = "{image}"; - imagescale = true; - color = {color}; - fontcolor = black; - xlabel = "{label}"; - label = ""; -] -""" - - def _shape(self): - return "none" - - def dfd(self, **kwargs): - self._is_drawn = True - - levels = kwargs.get("levels", None) - if levels and not levels & self.levels: - return "" - - color = self._color() - color_file = "black" - - if kwargs.get("colormap", False): - color = sev_to_color(self.severity) - color_file = color.split(";")[0] - - return self._dfd_template().format( - uniq_name=self._uniq_name(), - label=self._label(), - color=color, - shape=self._shape(), - image=os.path.join( - os.path.dirname(__file__), "images", f"datastore_{color_file}.png" - ), - ) - - -class Actor(Element): - """An entity usually initiating actions. - - Actors represent users or external systems that initiate - interactions with the system being modeled. - - Attributes: - port (int): Default TCP port for outgoing data flows. - protocol (str): Default network protocol for outgoing data flows. - data (list): pytm.Data objects carried in outgoing data flows. - inputs (list): Incoming Dataflows. - outputs (list): Outgoing Dataflows. - isAdmin (bool): Indicates whether the actor has administrative privileges. 
- """ - - port = varInt(-1, doc="Default TCP port for outgoing data flows") - protocol = varString("", doc="Default network protocol for outgoing data flows") - data = varData([], doc="pytm.Data object(s) in outgoing data flows") - inputs = varElements([], doc="incoming Dataflows") - outputs = varElements([], doc="outgoing Dataflows") - isAdmin = varBool(False) - - def __init__(self, name, **kwargs): - """ - Initialize an Actor. - - Args: - name (str): Name of the actor. - **kwargs: Optional actor properties. - port (int): Default TCP port for outgoing data flows. - protocol (str): Default network protocol for outgoing data flows. - data (list): pytm.Data objects in outgoing data flows. - inputs (list): Incoming Dataflows. - outputs (list): Outgoing Dataflows. - isAdmin (bool): Indicates administrative privileges. - """ - super().__init__(name, **kwargs) - TM._actors.append(self) - - - -class Process(Asset): - """An entity processing data""" - - codeType = varString("Unmanaged") - implementsCommunicationProtocol = varBool(False) - tracksExecutionFlow = varBool(False) - implementsAPI = varBool(False) - environment = varString("") - allowsClientSideScripting = varBool(False) - - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - - def _shape(self): - return "circle" - - -class SetOfProcesses(Process): - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - - def _shape(self): - return "doublecircle" - - -class Dataflow(Element): - """A data flow from a source to a sink""" - - source = varElement(None, required=True) - sink = varElement(None, required=True) - isResponse = varBool(False, doc="Is a response to another data flow") - response = varElement(None, doc="Another data flow that is a response to this one") - responseTo = varElement(None, doc="Is a response to this data flow") - srcPort = varInt(-1, doc="Source TCP port") - dstPort = varInt(-1, doc="Destination TCP port") - tlsVersion = varTLSVersion( - TLSVersion.NONE, - 
required=True, - doc="TLS version used.", - ) - protocol = varString("", doc="Protocol used in this data flow") - data = varData([], doc="pytm.Data object(s) in incoming data flows") - order = varInt(-1, doc="Number of this data flow in the threat model") - implementsCommunicationProtocol = varBool(False) - note = varString("") - usesVPN = varBool(False) - usesSessionTokens = varBool(False) - severity = 0 - - def __init__(self, source, sink, name, **kwargs): - self.source = source - self.sink = sink - super().__init__(name, **kwargs) - TM._flows.append(self) - - def display_name(self): - if self.order == -1: - return self.name - return "({}) {}".format(self.order, self.name) - - def _dfd_template(self): - return """{source} -> {sink} [ - color = {color}; - fontcolor = {color}; - dir = {direction}; - label = "{label}"; -] -""" - - def dfd(self, mergeResponses=False, **kwargs): - self._is_drawn = True - - levels = kwargs.get("levels", None) - if ( - levels - and not levels & self.levels - and not (levels & self.source.levels and levels & self.sink.levels) - ): - return "" - - color = self._color() - - if kwargs.get("colormap", False): - color = sev_to_color(self.severity) - - direction = "forward" - label = self._label() - if mergeResponses and self.response is not None: - direction = "both" - label += "\n" + self.response._label() - - return self._dfd_template().format( - source=self.source._uniq_name(), - sink=self.sink._uniq_name(), - direction=direction, - label=label, - color=color, - ) - - def hasDataLeaks(self): - return any( - d.classification > self.source.maxClassification - or d.classification > self.sink.maxClassification - or d.classification > self.maxClassification - for d in self.data - ) + setattr(obj, attr, ordered) + except self.assignment_errors: + existing = getattr(obj, attr) + try: + existing[:] = ordered + except TypeError: + if hasattr(existing, "clear") and hasattr(existing, "extend"): + existing.clear() + existing.extend(ordered) + else: + 
for item in ordered: + if item not in existing: + existing.append(item) -class Boundary(Element): - """Trust boundary groups elements and data with the same trust level.""" +def _apply_defaults(flows, data): + """Apply default values to flows and data.""" + builder = _FlowDefaultsBuilder() + builder.seed_data_relationships(data) - def __init__(self, name, **kwargs): - super().__init__(name, **kwargs) - if name not in TM._boundaries: - TM._boundaries.append(self) + for flow in flows: + builder.process_flow(flow) - def _dfd_template(self): - return """subgraph cluster_{uniq_name} {{ - graph [ - fontsize = 10; - fontcolor = black; - style = dashed; - color = {color}; - label = <{label}>; - ] + builder.finalize_assets() + builder.finalize_data_relationships() -{edges} -}} -""" - def dfd(self, **kwargs): - if self._is_drawn: - return "" +def _get_elements_and_boundaries(flows): + """Get elements and boundaries used in flows.""" + elements = set() + boundaries = set() - self._is_drawn = True + for flow in flows: + elements.add(flow) + elements.add(flow.source) + elements.add(flow.sink) - edges = [] - for e in TM._elements: - if e.inBoundary != self or e._is_drawn: - continue - # The content to draw can include Boundary objects - edges.append(e.dfd(**kwargs)) - - return self._dfd_template().format( - uniq_name=self._uniq_name(), - label=self._label(), - color=self._color(**kwargs), - edges=indent("\n".join(edges), " "), - ) + source_boundary = getattr(flow.source, "inBoundary", None) + if source_boundary is not None: + boundaries.add(source_boundary) + for parent in getattr(source_boundary, "parents", lambda: [])(): + elements.add(parent) + boundaries.add(parent) - def _color(self, **kwargs): - if kwargs.get("colormap", False): - return "black" - else: - return "firebrick2" + sink_boundary = getattr(flow.sink, "inBoundary", None) + if sink_boundary is not None: + boundaries.add(sink_boundary) + for parent in getattr(sink_boundary, "parents", lambda: [])(): + 
elements.add(parent) + boundaries.add(parent) - def parents(self): - result = [] - parent = self.inBoundary - while parent is not None: - result.append(parent) - parent = parent.inBoundary - return result + return list(elements), list(boundaries) @singledispatch @@ -2055,6 +527,7 @@ def to_serializable(val): @to_serializable.register(TM) def ts_tm(obj): + """Serialize TM object.""" return serialize(obj, nested=True) @@ -2064,73 +537,154 @@ def ts_tm(obj): @to_serializable.register(Element) @to_serializable.register(Finding) def ts_element(obj): + """Serialize element objects.""" return serialize(obj, nested=False) def serialize(obj, nested=False): """Used if *obj* is an instance of TM, Element, Threat or Finding.""" - klass = obj.__class__ + result = {} + klass = obj.__class__ + if isinstance(obj, (Actor, Asset)): result["__class__"] = klass.__name__ - for i in dir(obj): - if ( - i.startswith("__") - or callable(getattr(klass, i, {})) - or ( - isinstance(obj, TM) - and i in ("_sf", "_duplicate_ignored_attrs", "_threats") - ) - or (isinstance(obj, Element) and i in ("_is_drawn", "uuid")) - or (isinstance(obj, Finding) and i == "element") - ): + + attribute_names = set() + + if hasattr(obj, "__dict__"): + attribute_names.update( + name for name in obj.__dict__.keys() if not name.startswith("__") + ) + + model_fields = getattr(klass, "model_fields", {}) + attribute_names.update(model_fields.keys()) + + computed_fields = getattr(klass, "model_computed_fields", {}) + attribute_names.update(computed_fields.keys()) + + if isinstance(obj, TM): + attribute_names.update( + { + "_actors", + "_assets", + "_elements", + "_flows", + "_data", + "_boundaries", + "assumptions", + "findings", + "excluded_findings", + "_threatsExcluded", + } + ) + + skip_attrs = { + "_sf", + "_duplicate_ignored_attrs", + "_threats", + "model_fields", + "model_computed_fields", + "model_config", + "model_post_init", + "model_extra", + "model_json_schema", + "schema", + "copy", + "dict", + "json", + 
"parse_file", + "parse_obj", + "parse_raw", + "construct", + "model_copy", + "validate", + "abc_impl", + "register", + } + + for attr_name in sorted(attribute_names): + if attr_name in skip_attrs: + continue + + if isinstance(obj, Element) and attr_name in {"uuid", "_is_drawn", "is_drawn"}: continue - value = getattr(obj, i) - if isinstance(obj, TM) and i == "_elements": + if isinstance(obj, Finding) and attr_name == "element": + continue + + try: + value = getattr(obj, attr_name) + except AttributeError: + continue + + key = attr_name.lstrip("_") + + if isinstance(obj, TM) and attr_name == "_elements": value = [e for e in value if isinstance(e, (Actor, Asset))] - if value is not None: - if isinstance(value, (Element, Data)): - value = value.name - elif isinstance(obj, Threat) and i == "target": - value = [v.__name__ for v in value] - elif i in ("levels", "sourceFiles", "assumptions"): - value = list(value) - elif ( - not nested - and not isinstance(value, str) - and isinstance(value, Iterable) - ): - value = [v.id if isinstance(v, Finding) else v.name for v in value] - result[i.lstrip("_")] = value + + if value is None: + result[key] = None + continue + + if isinstance(value, (Element, Data)): + value = value.name + elif hasattr(value, "model_dump") and not isinstance(value, TM): + value = value.model_dump() + elif isinstance(obj, Threat) and attr_name == "target": + coerced_targets = [] + for target in value: + if hasattr(target, "__name__"): + coerced_targets.append(target.__name__) + else: + coerced_targets.append(str(target)) + value = coerced_targets + elif attr_name in {"levels", "sourceFiles", "assumptions"}: + value = list(value) + elif ( + not nested + and not isinstance(value, (str, bytes)) + and isinstance(value, Iterable) + and not isinstance(value, Mapping) + ): + coerced = [] + for item in value: + if isinstance(item, Finding): + coerced.append(item.id) + elif isinstance(item, (Element, Data)): + coerced.append(item.name) + else: + 
coerced.append(item) + value = coerced + + result[key] = value + return result def encode_element_threat_data(obj): - """Used to html encode threat data from a list of Elements""" - encoded_elements = [] - if type(obj) is not list: - raise ValueError("expecting a list value, got a {}".format(type(obj))) - - for o in obj: - c = copy.deepcopy(o) - for a in o._attr_values(): - if a == "findings": - encoded_findings = encode_threat_data(o.findings) - c._safeset("findings", encoded_findings) + """Encode element threat data.""" + result = [] + if hasattr(obj, "__iter__"): + for item in obj: + if hasattr(item, "model_dump"): + result.append(item.model_dump()) else: - v = getattr(o, a) - if type(v) is not list or (type(v) is list and len(v) != 0): - c._safeset(a, v) - - encoded_elements.append(c) - - return encoded_elements + result.append(serialize(item)) + return result def encode_threat_data(obj): - """Used to html encode threat data from a list of threats or findings""" + """HTML-encode threat data while preserving attribute access.""" encoded_threat_data = [] + if obj is None: + return encoded_threat_data + + if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)): + candidates = list(obj) + else: + candidates = [obj] + attrs = [ "description", "details", @@ -2143,73 +697,109 @@ def encode_threat_data(obj): "condition", "cvss", "response", - "likelihood", ] - if type(obj) is Finding or (len(obj) != 0 and type(obj[0]) is Finding): + from .finding import Finding + + if candidates and isinstance(candidates[0], Finding): attrs.append("target") - for e in obj: - t = copy.deepcopy(e) + def _escape_markdown(text: str) -> str: + return re.sub(r"(? 
str: + """Return the parent boundary name for *element* or an empty string.""" + from pytm import Boundary # Local import to avoid circular dependency + + if not isinstance(element, Boundary): + return ( + f"ERROR: getParentName method is not valid for {type(element).__name__}" + ) + parent = element.inBoundary + return parent.name if parent is not None else "" @staticmethod - def getNamesOfParents(element): + def getNamesOfParents(element: Any) -> List[str] | str: + """Return a list of parent boundary names for *element*.""" from pytm import Boundary - if (isinstance(element, Boundary)): - parents = [p.name for p in element.parents()] - return parents - else: - return "ERROR: getNamesOfParents method is not valid for " + element.__class__.__name__ - + + if not isinstance(element, Boundary): + return f"ERROR: getNamesOfParents method is not valid for {type(element).__name__}" + + return [parent.name for parent in element.parents()] @staticmethod - def getInScopeFindings(element): - """ - Return only findings that: - 1. Belong to an in-scope element - 2. 
Target an in-scope element - """ + def getInScopeFindings(element: Any) -> list: + """Return only findings that belong to an in-scope element and target an in-scope element.""" from pytm import Element if not isinstance(element, Element): @@ -39,7 +43,6 @@ def getInScopeFindings(element): return [] in_scope_findings = [] - for finding in element.findings: target = getattr(finding, "target", None) if target is not None and getattr(target, "inScope", False): @@ -47,69 +50,78 @@ def getInScopeFindings(element): return in_scope_findings - @staticmethod - def getFindingCount(element): + def getFindingCount(element: Any) -> str: + """Return the count of findings for *element* as a string.""" from pytm import Element + if not isinstance(element, Element): - return "ERROR: getFindingCount method is not valid for " + element.__class__.__name__ - return str(len(ReportUtils.getInScopeFindings(element))) + return f"ERROR: getFindingCount method is not valid for {type(element).__name__}" + return str(len(list(element.findings))) @staticmethod - def getElementType(element): + def getElementType(element: Any) -> str: + """Return the class name for *element*.""" from pytm import Element - if (isinstance(element, Element)): - return str(element.__class__.__name__) - else: - return "ERROR: getElementType method is not valid for " + element.__class__.__name__ + if not isinstance(element, Element): + return f"ERROR: getElementType method is not valid for {type(element).__name__}" + + return element.__class__.__name__ @staticmethod - def getThreatId(obj): + def getThreatId(obj: Any) -> str: from pytm import Finding + if isinstance(obj, Finding): return obj.threat_id return "" @staticmethod - def getFindingDescription(obj): + def getFindingDescription(obj: Any) -> str: from pytm import Finding + if isinstance(obj, Finding): return obj.description return "" @staticmethod - def getFindingTarget(obj): + def getFindingTarget(obj: Any) -> Any: from pytm import Finding + if isinstance(obj, 
Finding): return obj.target return "" @staticmethod - def getFindingSeverity(obj): + def getFindingSeverity(obj: Any) -> str: from pytm import Finding + if isinstance(obj, Finding): return obj.severity return "" @staticmethod - def getFindingMitigations(obj): + def getFindingMitigations(obj: Any) -> str: from pytm import Finding + if isinstance(obj, Finding): return obj.mitigations return "" @staticmethod - def getFindingReferences(obj): + def getFindingReferences(obj: Any) -> str: from pytm import Finding + if isinstance(obj, Finding): return obj.references return "" - + @staticmethod - def getFindingExample(obj): + def getFindingExample(obj: Any) -> str: from pytm import Finding + if isinstance(obj, Finding): return obj.example return "" diff --git a/pytm/template_engine.py b/pytm/template_engine.py index 9118a5d4..172d1521 100644 --- a/pytm/template_engine.py +++ b/pytm/template_engine.py @@ -2,84 +2,76 @@ # but modified to include support to call methods which return lists, to call external utility methods, use # if operator with methods and added a not operator. 
+from __future__ import annotations + import string +from collections.abc import Iterable +from functools import lru_cache +from typing import Any, Callable class SuperFormatter(string.Formatter): - """World's simplest Template engine.""" + """Lightweight formatter with helpers for reports and templates.""" - def format_field(self, value, spec): + def format_field( + self, value: Any, spec: str + ) -> Any: # noqa: D401 - same semantics as base + if not spec: + return super().format_field(value, spec) - spec_parts = spec.split(":") if spec.startswith("repeat"): - # Example usage, format, count of spec_parts, exampple format - # object:repeat:template 2 {item.findings:repeat:{{item.id}}, } - - template = spec.partition(":")[-1] - if type(value) is dict: - value = value.items() - return "".join([self.format(template, item=item) for item in value]) - - elif spec.startswith("call:") and hasattr(value, "__call__"): - # Example usage, format, exampple format - # methood:call {item.display_name:call:} - # methood:call:template {item.parents:call:{{item.name}}, } - result = value() - - if type(result) is list: - template = spec.partition(":")[-1] - return "".join([self.format(template, item=item) for item in result]) + return self._format_repeat(value, spec) - return result + if spec.startswith("call:"): + return self._format_call(value, spec) - elif spec.startswith("call:"): - # Example usage, format, exampple format - # object:call:method_name {item:call:getFindingCount} - # object:call:method_name:template {item:call:getNamesOfParents: - # {{item}} - # } + if spec.startswith("if") or spec.startswith("not"): + return self._format_conditional(value, spec) - method_name = spec_parts[1] + return super().format_field(value, spec) - result = self.call_util_method(method_name, value) + def _format_repeat(self, value: Any, spec: str) -> str: + """Handle the custom repeat operator.""" + template = spec.partition(":")[2] + if isinstance(value, dict): + iterable: Iterable[Any] = 
value.items() + elif isinstance(value, Iterable) and not isinstance(value, (str, bytes)): + iterable = value + else: + iterable = [] + return "".join(self.format(template, item=item) for item in iterable) - if type(result) is list: - template = spec.partition(":")[-1] - template = template.partition(":")[-1] - return "".join([self.format(template, item=item) for item in result]) - - return result - - elif (spec.startswith("if") or spec.startswith("not")): - # Example usage, format, exampple format - # object.bool:if:template {item.inScope:if:Is in scope} - # object:if:template {item.findings:if:Has Findings} - # object.method:if:template {item.parents:if:Has Parents} - # - # object.bool:not:template {item.inScope:not:Is not in scope} - # object:not:template {item.findings:not:Has No Findings} - # object.method:not:template {item.parents:not:Has No Parents} - - template = spec.partition(":")[-1] - if (hasattr(value, "__call__")): - result = value() - else: - result = value - - if (spec.startswith("if")): - return (result and template or "") - else: - return (not result and template or "") + def _format_call(self, value: Any, spec: str) -> Any: + """Evaluate callable values or report utility helpers.""" + _, _, remainder = spec.partition(":") + if callable(value): + result = value() + template = remainder else: - return super(SuperFormatter, self).format_field(value, spec) - - def call_util_method(self, method_name, object): - module_name = "pytm.report_util" - klass_name = "ReportUtils" - module = __import__(module_name, fromlist=['ReportUtils']) - klass = getattr(module, klass_name) - method = getattr(klass, method_name) + method_name, _, template = remainder.partition(":") + result = self.call_util_method(method_name, value) - result = method(object) + if isinstance(result, list) and template: + return "".join(self.format(template, item=item) for item in result) return result + + def _format_conditional(self, value: Any, spec: str) -> str: + """Render content 
conditionally based on truthiness of *value*.""" + _, _, template = spec.partition(":") + result = value() if callable(value) else value + if spec.startswith("if"): + return template if result else "" + return template if not result else "" + + def call_util_method(self, method_name: str, obj: Any) -> Any: + """Invoke a helper method from :mod:`pytm.report_util`.""" + method = self._resolve_report_method(method_name) + return method(obj) + + @staticmethod + @lru_cache(maxsize=None) + def _resolve_report_method(method_name: str) -> Callable[[Any], Any]: + from pytm.report_util import ReportUtils + + return getattr(ReportUtils, method_name) diff --git a/pytm/threat.py b/pytm/threat.py new file mode 100644 index 00000000..e972b8ea --- /dev/null +++ b/pytm/threat.py @@ -0,0 +1,331 @@ +"""Threat model - represents possible threats in the system.""" + +from __future__ import annotations + +import ast +import sys +from types import CodeType +from typing import Any, ClassVar, Tuple, List +from collections.abc import Iterable + +import builtins + +from pydantic import ( + BaseModel, + Field, + ConfigDict, + model_validator, + PrivateAttr, +) + + +class _ConditionValidator(ast.NodeVisitor): + """Validate threat conditions to ensure they only use safe constructs.""" + + SAFE_CALL_NAMES: ClassVar[set[str]] = {"any", "all", "len", "min", "max", "sum"} + ALLOWED_TARGET_METHODS: ClassVar[set[str]] = { + "oneOf", + "crosses", + "enters", + "exits", + "inside", + "checkTLSVersion", + "hasDataLeaks", + } + _ALLOWED_NODES: ClassVar[tuple[type[ast.AST], ...]] = ( + ast.Expression, + ast.BoolOp, + ast.BinOp, + ast.UnaryOp, + ast.Compare, + ast.Name, + ast.Load, + ast.Constant, + ast.Attribute, + ast.Call, + ast.Subscript, + ast.List, + ast.Tuple, + ast.Set, + ast.Dict, + ast.ListComp, + ast.GeneratorExp, + ast.comprehension, + ast.IfExp, + ast.And, + ast.Or, + ast.Not, + ast.Eq, + ast.NotEq, + ast.Lt, + ast.LtE, + ast.Gt, + ast.GtE, + ast.Is, + ast.IsNot, + ast.In, + ast.NotIn, + 
ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.Mod, + ast.Pow, + ast.USub, + ast.UAdd, + ast.BitAnd, + ast.BitOr, + ast.BitXor, + ast.FloorDiv, + ast.Slice, + ) + + def __init__(self, allowed_names: set[str]) -> None: + super().__init__() + self.allowed_names = allowed_names | {"target", "True", "False", "None"} + + def visit(self, node: ast.AST) -> Any: # type: ignore[override] + if not isinstance(node, self._ALLOWED_NODES): + raise ValueError( + f"Unsupported syntax in threat condition: {type(node).__name__}" + ) + return super().visit(node) + + def visit_Attribute(self, node: ast.Attribute) -> Any: # noqa: D401 + if isinstance(node.attr, str) and node.attr.startswith("__"): + raise ValueError( + "Access to dunder attributes is not permitted in threat conditions" + ) + return self.generic_visit(node) + + def visit_Call(self, node: ast.Call) -> Any: # noqa: D401 + if node.keywords: + raise ValueError("Keyword arguments are not permitted in threat conditions") + + func = node.func + if isinstance(func, ast.Name): + if func.id not in self.SAFE_CALL_NAMES: + raise ValueError( + f"Call to '{func.id}' is not permitted in threat conditions" + ) + elif isinstance(func, ast.Attribute): + chain = self._attribute_chain(func) + if chain[-1] not in self.ALLOWED_TARGET_METHODS: + raise ValueError( + f"Call to target method '{chain[-1]}' is not permitted" + ) + else: + raise ValueError("Unsupported call target in threat condition") + + return self.generic_visit(node) + + def visit_Name(self, node: ast.Name) -> Any: # noqa: D401 + if ( + isinstance(node.ctx, ast.Load) + and node.id not in self.allowed_names + and node.id not in self.SAFE_CALL_NAMES + ): + # Allow names introduced by comprehensions; they will fail at runtime if undefined. 
+ return + return None + + @staticmethod + def _attribute_chain(node: ast.Attribute) -> List[str]: + chain: List[str] = [node.attr] + current = node.value + while isinstance(current, ast.Attribute): + if isinstance(current.attr, str) and current.attr.startswith("__"): + raise ValueError( + "Access to dunder attributes is not permitted in threat conditions" + ) + chain.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + chain.append(current.id) + else: + raise ValueError( + "Only attribute access on names is permitted in threat conditions" + ) + chain.reverse() + return chain + + +class Threat(BaseModel): + """Represents a possible threat. + + Attributes: + id (str): Threat identifier (SID) + description (str): Description of the threat + condition (str): A Python expression that should evaluate to a boolean True or False + details (str): Detailed information about the threat + likelihood (str): Likelihood of the threat occurring + severity (str): Severity level of the threat + mitigations (str): Possible mitigations for the threat + prerequisites (str): Prerequisites for the threat + example (str): Example of the threat + references (str): References for the threat + target (Tuple): Target classes for this threat + """ + + model_config = ConfigDict( + extra="allow", validate_assignment=True, arbitrary_types_allowed=True + ) + + id: str = Field(description="Threat identifier (SID)") + description: str = Field(default="", description="Description of the threat") + condition: str = Field( + default="True", + description="A Python expression that should evaluate to a boolean True or False", + ) + details: str = Field( + default="", description="Detailed information about the threat" + ) + likelihood: str = Field( + default="", description="Likelihood of the threat occurring" + ) + severity: str = Field(default="", description="Severity level of the threat") + mitigations: str = Field( + default="", description="Possible mitigations for 
the threat" + ) + prerequisites: str = Field(default="", description="Prerequisites for the threat") + example: str = Field(default="", description="Example of the threat") + references: str = Field(default="", description="References for the threat") + target: Tuple = Field(default=(), description="Target classes for this threat") + + _compiled_condition: CodeType | None = PrivateAttr(default=None) + _eval_globals: ClassVar[dict[str, Any] | None] = None + _SAFE_BUILTINS: ClassVar[dict[str, Any]] = { + name: getattr(builtins, name) for name in _ConditionValidator.SAFE_CALL_NAMES + } + + @model_validator(mode="before") + @classmethod + def _normalize_input(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + # Map legacy field names from the threats.json format + if "SID" in data: + data.setdefault("id", data.pop("SID")) + if "Likelihood Of Attack" in data: + data.setdefault("likelihood", data.pop("Likelihood Of Attack")) + + # Normalise target to a tuple + target = data.get("target", "Element") + if isinstance(target, str) or not isinstance(target, Iterable): + target = (target,) + else: + target = tuple(target) + + # Resolve target name strings to actual Python classes + resolved = [] + for name in target: + if isinstance(name, type): + resolved.append(name) + else: + klass = getattr(sys.modules.get("pytm"), name, None) + resolved.append(klass if klass is not None else name) + data["target"] = tuple(resolved) + + return data + + def model_post_init(self, __context: Any) -> None: # noqa: D401 + if not self.condition: + self._compiled_condition = None + return + + try: + tree = ast.parse(self.condition, mode="eval") + validator = _ConditionValidator(self._allowed_global_names()) + validator.visit(tree) + self._compiled_condition = compile( + tree, filename=f"", mode="eval" + ) + except ValueError as exc: # pragma: no cover - defensive, surfaced via tests + raise ValueError(f"Invalid condition for threat {self.id}: {exc}") from exc + except 
SyntaxError as exc: # noqa: D401 + raise ValueError( + f"Invalid syntax in condition for threat {self.id}: {exc}" + ) from exc + + def _safeset(self, attr: str, value) -> None: + """Safely set an attribute value.""" + try: + setattr(self, attr, value) + except (ValueError, TypeError): + pass + + def __repr__(self): + return ( + f"<{self.__module__}.{type(self).__name__}({self.id}) at {hex(id(self))}>" + ) + + def __str__(self): + return f"{type(self).__name__}({self.id})" + + @classmethod + def _build_eval_globals(cls) -> dict[str, Any]: + if cls._eval_globals is None: + import pytm + + globals_dict: dict[str, Any] = { + "__builtins__": cls._SAFE_BUILTINS, + "Actor": pytm.Actor, + "Asset": pytm.Asset, + "Boundary": pytm.Boundary, + "Dataflow": pytm.Dataflow, + "Datastore": pytm.Datastore, + "DatastoreType": pytm.DatastoreType, + "Element": pytm.Element, + "ExternalEntity": pytm.ExternalEntity, + "Lambda": pytm.Lambda, + "Process": pytm.Process, + "Server": pytm.Server, + "SetOfProcesses": pytm.SetOfProcesses, + "TM": pytm.TM, + "TLSVersion": pytm.TLSVersion, + "Classification": pytm.Classification, + "Action": pytm.Action, + "Lifetime": pytm.Lifetime, + } + + # Expose safe builtins as globals as well for convenience + globals_dict.update(cls._SAFE_BUILTINS) + cls._eval_globals = globals_dict + + return cls._eval_globals + + @classmethod + def _allowed_global_names(cls) -> set[str]: + globals_dict = cls._build_eval_globals() + return {key for key in globals_dict.keys() if key != "__builtins__"} + + def apply(self, target): + """Apply the threat condition to a target.""" + # Check if target matches any of the target types + if self.target: + target_matches = False + for target_type in self.target: + if isinstance(target_type, str): + # String comparison for backward compatibility + if target_type == type(target).__name__: + target_matches = True + break + elif isinstance(target_type, type): + # Class type comparison + if isinstance(target, target_type): + 
target_matches = True + break + + if not target_matches: + return False + + if self._compiled_condition is None: + return False + + try: + globals_dict = dict(self._build_eval_globals()) + locals_dict = {"target": target} + return bool(eval(self._compiled_condition, globals_dict, locals_dict)) + except Exception: + return False diff --git a/pytm/tm.py b/pytm/tm.py new file mode 100644 index 00000000..fb75e0ca --- /dev/null +++ b/pytm/tm.py @@ -0,0 +1,813 @@ +"""TM (Threat Model) - the main container for all threat model elements.""" + +from __future__ import annotations + +import copy +import json +import logging +import os +import random +import re +import sys +from collections import defaultdict, Counter +from dataclasses import dataclass, field +from datetime import datetime +from itertools import combinations +from textwrap import indent +from typing import ClassVar, Dict, Iterable, List, TYPE_CHECKING +from html import escape as html_escape + +from pydantic import BaseModel, Field, ConfigDict, field_validator + +from .enums import Action +from .base import Assumption +from .template_engine import SuperFormatter + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from .element import Element + from .asset import Asset + from .actor import Actor + from .dataflow import Dataflow + from .boundary import Boundary + from .data import Data + from .threat import Threat + from .finding import Finding + + +class UIError(Exception): + """Exception for UI-related errors.""" + + def __init__(self, e, context): + self.error = e + self.context = context + + +@dataclass +class TMState: + """Mutable registry for TM-owned collections.""" + + flows: List["Dataflow"] = field(default_factory=list) + elements: List["Element"] = field(default_factory=list) + actors: List["Actor"] = field(default_factory=list) + assets: List["Asset"] = field(default_factory=list) + threats: List["Threat"] = field(default_factory=list) + boundaries: List["Boundary"] = 
field(default_factory=list) + data: List["Data"] = field(default_factory=list) + threats_excluded: List[str] = field(default_factory=list) + + +class _StateAttribute: + """Descriptor that proxies attribute access to the shared TM state.""" + + def __init__(self, field_name: str): + self.field_name = field_name + self.owner: type["TM"] | None = None + + def __set_name__(self, owner, name): + self.owner = owner + register = getattr(owner, "_register_state_attribute", None) + if callable(register): + register(name, self) + + def __get__(self, instance, owner=None): + owner = owner or self.owner + if owner is None: + raise AttributeError("State attribute descriptor is unbound") + return getattr(owner._state, self.field_name) + + def __set__(self, instance, value): + owner = self.owner if instance is None else type(instance) + if owner is None: + raise AttributeError("State attribute descriptor is unbound") + setattr(owner._state, self.field_name, value) + + +class TMModelMetaclass(type(BaseModel)): + """Metaclass that keeps TM state descriptors intact on class assignment.""" + + def __setattr__(cls, name, value): + state_attrs = getattr(cls, "_state_attributes", None) + if state_attrs and name in state_attrs: + descriptor = state_attrs[name] + descriptor.__set__(None, value) + return + super().__setattr__(name, value) + + +class TM(BaseModel, metaclass=TMModelMetaclass): + """Describes the threat model administratively, and holds all details during a run.""" + + model_config = ConfigDict( + extra="allow", validate_assignment=True, arbitrary_types_allowed=True + ) + + _state: ClassVar[TMState] = TMState() + _state_attributes: ClassVar[Dict[str, _StateAttribute]] = {} + + @classmethod + def _register_state_attribute(cls, name: str, descriptor: _StateAttribute) -> None: + cls._state_attributes[name] = descriptor + + _flows: ClassVar[_StateAttribute] = _StateAttribute("flows") + _elements: ClassVar[_StateAttribute] = _StateAttribute("elements") + _actors: 
ClassVar[_StateAttribute] = _StateAttribute("actors") + _assets: ClassVar[_StateAttribute] = _StateAttribute("assets") + _threats: ClassVar[_StateAttribute] = _StateAttribute("threats") + _boundaries: ClassVar[_StateAttribute] = _StateAttribute("boundaries") + _data: ClassVar[_StateAttribute] = _StateAttribute("data") + _threatsExcluded: ClassVar[_StateAttribute] = _StateAttribute("threats_excluded") + + @classmethod + def _get_state(cls) -> TMState: + """Return the mutable shared state for this TM class.""" + return cls._state + + name: str = Field(description="Model name") + description: str = Field(description="Model description") + threatsFile: str = Field( + default_factory=lambda: os.path.dirname(__file__) + "/threatlib/threats.json", + description="JSON file with custom threats", + ) + isOrdered: bool = Field( + default=False, description="Automatically order all Dataflows" + ) + mergeResponses: bool = Field( + default=False, description="Merge response edges in DFDs" + ) + ignoreUnused: bool = Field( + default=False, description="Ignore elements not used in any Dataflow" + ) + findings: List["Finding"] = Field( + default_factory=list, description="Threats found for elements of this model" + ) + excluded_findings: List["Finding"] = Field( + default_factory=list, + description="Threats found for elements of this model, that were excluded on a per-element basis, using the Assumptions class", + ) + onDuplicates: Action = Field( + default=Action.NO_ACTION, + description="How to handle duplicate Dataflow with same properties, except name and notes", + ) + assumptions: List[Assumption] = Field( + default_factory=list, description="A list of assumptions about the design/model" + ) + colormap: bool = Field(default=False, exclude=True) + + @field_validator("assumptions", mode="before") + @classmethod + def _normalize_assumptions(cls, value): + """Allow string assumptions to be converted into Assumption models.""" + if value is None or value == []: + return [] + + def 
convert(item): + if isinstance(item, Assumption): + return item + if isinstance(item, dict): + return item + if isinstance(item, str): + return {"name": item} + raise TypeError( + "assumptions must be strings, dicts, or Assumption instances" + ) + + if isinstance(value, (str, dict, Assumption)): + return [convert(value)] + + if isinstance(value, Iterable) and not isinstance(value, (bytes, str)): + return [convert(item) for item in value] + + raise TypeError( + "assumptions must be provided as an iterable of supported types" + ) + + def __init__(self, name: str, description: str = "", **data): + """Initialize the threat model.""" + data.update({"name": name, "description": description}) + + object.__setattr__(self, "_initializing_tm", True) + super().__init__(**data) + object.__setattr__(self, "_initializing_tm", False) + + self._sf = SuperFormatter() + random.seed(0) + + try: + self._init_threats() + except UIError as e: + raise e + finally: + if hasattr(self, "_initializing_tm"): + object.__delattr__(self, "_initializing_tm") + + def __setattr__(self, name, value): + if name == "threatsFile" and not getattr(self, "_initializing_tm", False): + current_value = getattr(self, "threatsFile", None) + if current_value == value: + return super().__setattr__(name, value) + + super().__setattr__(name, value) + try: + self._init_threats() + except UIError as e: + object.__setattr__(self, "_initializing_tm", True) + try: + super().__setattr__(name, current_value) + finally: + object.__setattr__(self, "_initializing_tm", False) + + if current_value is not None: + try: + self._init_threats() + except UIError: + TM._get_state().threats.clear() + raise e + return + + super().__setattr__(name, value) + + @classmethod + def reset(cls): + """Reset all class variables.""" + cls._state = TMState() + + def _init_threats(self): + """Initialize threats from file.""" + TM._get_state().threats.clear() + self._add_threats() + + def _add_threats(self): + """Add threats from the threats 
file.""" + try: + with open(self.threatsFile, "r", encoding="utf8") as threat_file: + threats_json = json.load(threat_file) + except (FileNotFoundError, PermissionError, IsADirectoryError) as e: + raise UIError( + e, f"while trying to open the threat file ({self.threatsFile})." + ) + + from .threat import Threat + + active_threats = ( + threat for threat in threats_json if "DEPRECATED" not in threat + ) + for threat in active_threats: + TM._threats.append(Threat(**threat)) + + def check(self): + """Check the threat model for consistency and completeness.""" + if self.description is None: + raise ValueError( + """Every threat model should have at least +a brief description of the system being modeled.""" + ) + + from . import pytm as pytm_module + + state = TM._get_state() + state.flows = pytm_module._match_responses( + pytm_module._sort(state.flows, getattr(self, "isOrdered", False)) + ) + + self._check_duplicates(state.flows) + + pytm_module._apply_defaults(state.flows, state.data) + + for element in state.elements: + top = Counter( + getattr(f, "threat_id", None) for f in getattr(element, "overrides", []) + ).most_common(1) + if not top: + continue + threat_id, count = top[0] + if count != 1: + raise ValueError( + f"Finding {threat_id} have more than one override in {element}" + ) + + if getattr(self, "ignoreUnused", False): + elements, boundaries = pytm_module._get_elements_and_boundaries(state.flows) + state.elements = elements + state.boundaries = boundaries + + result = True + for element in state.elements: + if not element.check(): + result = False + + if getattr(self, "ignoreUnused", False): + state.elements = pytm_module._sort_elem(state.elements) + + return result + + def resolve(self): + """Resolve threats and generate findings.""" + from .finding import Finding + from collections import defaultdict + + finding_count = 0 + excluded_finding_count = 0 + findings = [] + excluded_findings = [] + + # Get global assumptions with exclusions + global_assumptions 
= [a for a in self.assumptions if len(a.exclude) > 0] + elements = defaultdict(list) + + for e in TM._elements: + if not getattr(e, "inScope", True): + e.findings = findings + continue + + override_ids = set(f.threat_id for f in getattr(e, "overrides", [])) + + # Filter out overrides from source and sink for dataflows + try: + source_overrides = set( + f.threat_id for f in getattr(e.source, "overrides", []) + ) + sink_overrides = set( + f.threat_id for f in getattr(e.sink, "overrides", []) + ) + override_ids -= source_overrides | sink_overrides + except AttributeError: + pass + + for t in TM._threats: + if not t.apply(e) and t.id not in override_ids: + continue + + if t.id in TM._threatsExcluded: + continue + + _continue = False + element_assumptions = getattr(e, "assumptions", []) + for assumption in element_assumptions + global_assumptions: + if hasattr(assumption, "exclude") and t.id in assumption.exclude: + excluded_finding_count += 1 + f = Finding( + e, + id=str(excluded_finding_count), + threat=t, + assumption=assumption, + ) + excluded_findings.append(f) + _continue = True + break + if _continue: + continue + + finding_count += 1 + f = Finding(e, id=str(finding_count), threat=t) + findings.append(f) + elements[e].append(f) + + # Set severity on element + if hasattr(e, "_set_severity"): + e._set_severity(getattr(f, "severity", 0)) + + self.findings = findings + self.excluded_findings = excluded_findings + + for e, findings in elements.items(): + e.findings = findings + + def process(self): + """Entry point mirroring the legacy CLI workflow.""" + try: + self._process() + except UIError as e: # pragma: no cover - mirrors historical behaviour + message = "Failed to execute\n" f" {e.context}\n" f" {e.error}\n" + sys.stderr.write(message) + raise SystemExit(127) from e + + def _process(self): + """Execute the CLI workflow (check, resolve, render outputs).""" + from . 
def _process(self):
    """Execute the CLI workflow (check, resolve, render outputs).

    Validates the model, reads the parsed command-line options from
    ``pytm.get_args()`` and performs every action that was requested:
    sequence diagram, data-flow diagram, JSON dump, report, class
    description, element/threat listings and the stale-source check.

    Raises:
        UIError: when the JSON result file cannot be opened for writing.
    """
    from . import pytm as pytm_module

    self.check()

    result = pytm_module.get_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    if getattr(result, "debug", False):
        logger.setLevel(logging.DEBUG)

    exclude_raw = getattr(result, "exclude", None)
    if exclude_raw:
        # Accept a single string, a list of strings or any other value; each
        # token may itself be a comma-separated list of threat ids.
        if isinstance(exclude_raw, str):
            tokens = [exclude_raw]
        elif isinstance(exclude_raw, list):
            tokens = [str(item) for item in exclude_raw]
        else:
            tokens = [str(exclude_raw)]

        exclusions: list[str] = []
        for token in tokens:
            exclusions.extend(
                sid.strip() for sid in token.split(",") if sid and sid.strip()
            )
        TM._threatsExcluded = exclusions

    if getattr(result, "seq", False):
        print(self.seq())

    if getattr(result, "dfd", False):
        if getattr(result, "colormap", False):
            # resolve() is run first so that findings exist before drawing.
            self.resolve()
        levels = set(getattr(result, "levels", []) or [])
        print(self.dfd(colormap=getattr(result, "colormap", False), levels=levels))

    needs_resolution = any(
        getattr(result, attr, None) for attr in ("report", "json", "stale_days")
    )

    if needs_resolution:
        self.resolve()

    if getattr(result, "json", None):
        try:
            with open(result.json, "w", encoding="utf8") as f:
                json.dump(self, f, default=pytm_module.to_serializable)
        except (
            FileNotFoundError,  # e.g. the parent directory does not exist
            FileExistsError,
            PermissionError,
            IsADirectoryError,
        ) as exc:
            raise UIError(
                exc, f"while trying to write to the result file ({result.json})"
            ) from exc

    if getattr(result, "report", None):
        print(self.report(result.report))

    describe_targets = getattr(result, "describe", None)
    if describe_targets:
        names = describe_targets.split()
        pytm_module._describe_classes(names)

    if getattr(result, "list_elements", False):
        pytm_module._list_elements()

    if getattr(result, "list", False):
        for threat in TM._threats:
            print(
                f"{getattr(threat, 'id', '')} - {getattr(threat, 'description', '')}"
            )

    if getattr(result, "stale_days", None) is not None:
        print(self._stale(result.stale_days))


def _stale(self, days: int) -> str:
    """Report source files whose age diverges from the model script.

    The modification time of the running script (``sys.argv[0]``) is compared
    with the modification time of every path in each element's
    ``sourceFiles``; source paths are resolved relative to the script's
    directory. A line is printed for each file whose whole-day difference
    reaches ``days`` in either direction. Files that cannot be stat'ed are
    reported on stderr and skipped.

    Args:
        days: threshold, in whole days.

    Returns:
        ``""`` after the scan, or ``"[ERROR]"`` when the script itself cannot
        be stat'ed.
    """
    script = sys.argv[0]
    base_path = os.path.dirname(script)
    try:
        # Stat the script path as given. Joining it onto its own dirname
        # duplicates the directory part of a relative path
        # ("models/tm.py" -> "models/models/tm.py").
        tm_mtime = datetime.fromtimestamp(os.stat(script).st_mtime)
    except OSError as err:
        sys.stderr.write(f"{sys.argv[0]} - {err}\n")
        sys.stderr.flush()
        return "[ERROR]"

    print(f"Checking for code {days} days older than this model.")

    for element in TM._elements:
        source_files = getattr(element, "sourceFiles", [])
        for src in source_files:
            try:
                src_path = os.path.join(base_path, src)
                src_mtime = datetime.fromtimestamp(os.stat(src_path).st_mtime)
            except OSError as err:
                sys.stderr.write(f"{sys.argv[0]} - {err}\n")
                sys.stderr.flush()
                continue

            # NOTE(review): timedelta.days floors, so a source file that is a
            # few hours older than the model already counts as -1 day.
            age = (src_mtime - tm_mtime).days
            if age >= days:
                print(f"This model is {age} days older than {src_path}.")
            elif age <= -days:
                print(
                    f"Model script {sys.argv[0]} is only {-age} days newer than source code file {src_path}"
                )

    return ""


def _dfd_template(self):
    """Return the Graphviz ``digraph`` skeleton used by ``dfd()``.

    The result is a ``str.format`` template: ``{edges}`` is the only
    placeholder, and the literal braces of the graph are doubled.
    """
    return (
        "digraph tm {{\n"
        " graph [\n"
        " fontname = Arial;\n"
        " fontsize = 14;\n"
        " ]\n"
        " node [\n"
        " fontname = Arial;\n"
        " fontsize = 14;\n"
        " rankdir = lr;\n"
        " ]\n"
        " edge [\n"
        " shape = none;\n"
        " arrowtail = onormal;\n"
        " fontname = Arial;\n"
        " fontsize = 12;\n"
        " ]\n"
        ' labelloc = "t";\n'
        " fontsize = 20;\n"
        " nodesep = 1;\n"
        "\n"
        "{edges}\n"
        "\n"
        "}}"
    )
def dfd(self, **kwargs):
    """Generate the Data Flow Diagram as Graphviz source.

    Boundaries are emitted first, from the outermost nesting level down to
    the innermost, then every element that is not a boundary, is not inside a
    boundary and has not been drawn yet. Empty fragments are dropped and the
    rest is inserted into ``self._dfd_template()``.

    Args:
        **kwargs: forwarded to each element's ``dfd()``. ``levels`` is
            normalized to a ``set`` (a scalar or a string becomes a
            one-element set) and ``mergeResponses`` is always set from
            ``self.mergeResponses``.

    Returns:
        The rendered ``digraph`` text.
    """
    from collections import defaultdict
    from .boundary import Boundary

    if "levels" in kwargs:
        levels = kwargs["levels"]
        if not hasattr(levels, "__iter__") or isinstance(levels, str):
            kwargs["levels"] = [levels]
        kwargs["levels"] = set(kwargs["levels"])

    edges = []
    # Since boundaries can be nested sort them by level and start from top
    parents = set(b.inBoundary for b in TM._boundaries if b.inBoundary)

    # Collect boundary levels: level 0 holds the boundaries that are nobody's
    # parent, level i holds the i-th entry of their parents() chain.
    boundary_levels = defaultdict(set)
    max_level = 0
    for b in TM._boundaries:
        if b in parents:
            continue
        boundary_levels[0].add(b)
        for i, p in enumerate(getattr(b, "parents", lambda: [])(), 1):
            boundary_levels[i].add(p)
            if i > max_level:
                max_level = i

    # Draw boundaries from highest level to lowest; sorting by name keeps the
    # output order independent of set iteration order.
    for i in range(max_level, -1, -1):
        for b in sorted(boundary_levels[i], key=lambda b: b.name):
            edges.append(b.dfd(**kwargs))

    # Handle response merging: responses are flagged as drawn so that they
    # are not emitted by the element loop below.
    if getattr(self, "mergeResponses", False):
        for e in TM._flows:
            if getattr(e, "response", None) is not None:
                e.response.is_drawn = True
    kwargs["mergeResponses"] = getattr(self, "mergeResponses", False)

    # Draw elements that are not boundaries and not inside boundaries
    for e in TM._elements:
        if (
            not getattr(e, "is_drawn", False)
            and not isinstance(e, Boundary)
            and getattr(e, "inBoundary", None) is None
        ):
            edges.append(e.dfd(**kwargs))

    return self._dfd_template().format(
        edges=indent("\n".join(filter(len, edges)), " ").rstrip("\n")
    )


def _seq_template(self):
    """Return the PlantUML skeleton used by ``seq()``.

    A ``str.format`` template with the placeholders ``{participants}`` and
    ``{messages}``.
    """
    return """@startuml
{participants}

{messages}
@enduml"""


def seq(self):
    """Generate the sequence diagram as PlantUML source.

    Every element in ``TM._elements`` that is neither a ``Dataflow`` nor a
    ``Boundary`` becomes a participant (``actor`` for ``Actor``, ``database``
    for ``Datastore``, ``entity`` otherwise). Every flow in ``TM._flows``
    becomes a message from its source to its sink, followed by a ``note
    left`` block when the flow has a note.

    Returns:
        The rendered ``@startuml … @enduml`` text.
    """
    from .actor import Actor
    from .datastore import Datastore
    from .boundary import Boundary
    from .dataflow import Dataflow

    participants = []
    for e in TM._elements:
        if isinstance(e, Actor):
            keyword = "actor"
        elif isinstance(e, Datastore):
            keyword = "database"
        elif not isinstance(e, (Dataflow, Boundary)):
            keyword = "entity"
        else:
            # Dataflows and boundaries are not participants.
            continue
        uniq_name = e._uniq_name()
        label = getattr(e, "display_name", lambda: e.name)()
        participants.append('{0} {1} as "{2}"'.format(keyword, uniq_name, label))

    messages = []
    for e in TM._flows:
        message = "{0} -> {1}: {2}".format(
            e.source._uniq_name(),
            e.sink._uniq_name(),
            getattr(e, "display_name", lambda: e.name)(),
        )
        note = ""
        note_text = getattr(e, "note", "")
        # A missing note may be "" or None; neither must render a note block
        # (None used to be printed literally as "None").
        if note_text is not None and note_text != "":
            note = "\nnote left\n{}\nend note".format(note_text)
        messages.append("{}{}".format(message, note))

    return self._seq_template().format(
        participants="\n".join(participants), messages="\n".join(messages)
    )
sqlite3.Connection: - db_path = tmp_path / "sqldump" / "test.db" - return sqlite3.connect(db_path) - - -def test_sql_dump_creates_serialized_columns(sample_tm, tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - - sample_tm.sqlDump("test.db") - - with _open_connection(tmp_path) as conn: - column_names = { - column_info[1].lower() - for column_info in conn.execute("PRAGMA table_info(Boundary)") - } - - assert {"name", "inscope", "inboundary"}.issubset(column_names) - - -def test_sql_dump_persists_element_and_finding_data(sample_tm, tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - - sample_tm.sqlDump("test.db") - - with _open_connection(tmp_path) as conn: - boundary_rows = conn.execute( - "SELECT name, inBoundary FROM Boundary ORDER BY id" - ).fetchall() - server_rows = conn.execute( - "SELECT name, inBoundary FROM Server ORDER BY id" - ).fetchall() - finding_rows = conn.execute( - "SELECT threat_id FROM Finding ORDER BY id" - ).fetchall() - - assert ("Internet", None) in boundary_rows - assert ("Server/DB", "Internet") in boundary_rows - assert ("Web Server", "Server/DB") in server_rows - assert [row[0] for row in finding_rows] == ["SRV001"] \ No newline at end of file diff --git a/tm.py b/tm.py index 9f0cb464..cd6fdc4c 100755 --- a/tm.py +++ b/tm.py @@ -134,4 +134,3 @@ if __name__ == "__main__": tm.process() -