diff --git a/poetry.lock b/poetry.lock
index 43b834b3..b69674ff 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,16 @@
-# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand.
+
+[[package]]
+name = "annotated-types"
+version = "0.7.0"
+description = "Reusable constraint types to use with typing.Annotated"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
+ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
+]
[[package]]
name = "black"
@@ -136,19 +148,6 @@ files = [
{file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
]
-[[package]]
-name = "legacy-cgi"
-version = "2.6.4"
-description = "Fork of the standard library cgi and cgitb modules removed in Python 3.13"
-optional = false
-python-versions = ">=3.8"
-groups = ["main"]
-markers = "python_version >= \"3.13\""
-files = [
- {file = "legacy_cgi-2.6.4-py3-none-any.whl", hash = "sha256:7e235ce58bf1e25d1fc9b2d299015e4e2cd37305eccafec1e6bac3fc04b878cd"},
- {file = "legacy_cgi-2.6.4.tar.gz", hash = "sha256:abb9dfc7835772f7c9317977c63253fd22a7484b5c9bbcdca60a29dcce97c577"},
-]
-
[[package]]
name = "mako"
version = "1.3.10"
@@ -373,16 +372,161 @@ dev = ["pre-commit", "tox"]
testing = ["coverage", "pytest", "pytest-benchmark"]
[[package]]
-name = "pydal"
-version = "20200714.1"
-description = "a pure Python Database Abstraction Layer (for python version 2.7 and 3.x)"
+name = "pydantic"
+version = "2.12.5"
+description = "Data validation using Python type hints"
optional = false
-python-versions = "*"
+python-versions = ">=3.9"
groups = ["main"]
files = [
- {file = "pydal-20200714.1.tar.gz", hash = "sha256:dd35b8ecb009099cce7efa72a40707d2e9bdcdf85924f30683a52d5172d1242f"},
+ {file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"},
+ {file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"},
]
+[package.dependencies]
+annotated-types = ">=0.6.0"
+pydantic-core = "2.41.5"
+typing-extensions = ">=4.14.1"
+typing-inspection = ">=0.4.2"
+
+[package.extras]
+email = ["email-validator (>=2.0.0)"]
+timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""]
+
+[[package]]
+name = "pydantic-core"
+version = "2.41.5"
+description = "Core functionality for Pydantic validation and serialization"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pydantic_core-2.41.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:77b63866ca88d804225eaa4af3e664c5faf3568cea95360d21f4725ab6e07146"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dfa8a0c812ac681395907e71e1274819dec685fec28273a28905df579ef137e2"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5921a4d3ca3aee735d9fd163808f5e8dd6c6972101e4adbda9a4667908849b97"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25c479382d26a2a41b7ebea1043564a937db462816ea07afa8a44c0866d52f9"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f547144f2966e1e16ae626d8ce72b4cfa0caedc7fa28052001c94fb2fcaa1c52"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f52298fbd394f9ed112d56f3d11aabd0d5bd27beb3084cc3d8ad069483b8941"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:100baa204bb412b74fe285fb0f3a385256dad1d1879f0a5cb1499ed2e83d132a"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:05a2c8852530ad2812cb7914dc61a1125dc4e06252ee98e5638a12da6cc6fb6c"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:29452c56df2ed968d18d7e21f4ab0ac55e71dc59524872f6fc57dcf4a3249ed2"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:d5160812ea7a8a2ffbe233d8da666880cad0cbaf5d4de74ae15c313213d62556"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:df3959765b553b9440adfd3c795617c352154e497a4eaf3752555cfb5da8fc49"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-win32.whl", hash = "sha256:1f8d33a7f4d5a7889e60dc39856d76d09333d8a6ed0f5f1190635cbec70ec4ba"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-win_amd64.whl", hash = "sha256:62de39db01b8d593e45871af2af9e497295db8d73b085f6bfd0b18c83c70a8f9"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a3a52f6156e73e7ccb0f8cced536adccb7042be67cb45f9562e12b319c119da6"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7f3bf998340c6d4b0c9a2f02d6a400e51f123b59565d74dc60d252ce888c260b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:378bec5c66998815d224c9ca994f1e14c0c21cb95d2f52b6021cc0b2a58f2a5a"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7b576130c69225432866fe2f4a469a85a54ade141d96fd396dffcf607b558f8"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cb58b9c66f7e4179a2d5e0f849c48eff5c1fca560994d6eb6543abf955a149e"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88942d3a3dff3afc8288c21e565e476fc278902ae4d6d134f1eeda118cc830b1"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f31d95a179f8d64d90f6831d71fa93290893a33148d890ba15de25642c5d075b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1df3d34aced70add6f867a8cf413e299177e0c22660cc767218373d0779487b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4009935984bd36bd2c774e13f9a09563ce8de4abaa7226f5108262fa3e637284"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:34a64bc3441dc1213096a20fe27e8e128bd3ff89921706e83c0b1ac971276594"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c9e19dd6e28fdcaa5a1de679aec4141f691023916427ef9bae8584f9c2fb3b0e"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-win32.whl", hash = "sha256:2c010c6ded393148374c0f6f0bf89d206bf3217f201faa0635dcd56bd1520f6b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-win_amd64.whl", hash = "sha256:76ee27c6e9c7f16f47db7a94157112a2f3a00e958bc626e2f4ee8bec5c328fbe"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-win_arm64.whl", hash = "sha256:4bc36bbc0b7584de96561184ad7f012478987882ebf9f9c389b23f432ea3d90f"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:941103c9be18ac8daf7b7adca8228f8ed6bb7a1849020f643b3a14d15b1924d9"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:112e305c3314f40c93998e567879e887a3160bb8689ef3d2c04b6cc62c33ac34"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbaad15cb0c90aa221d43c00e77bb33c93e8d36e0bf74760cd00e732d10a6a0"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:03ca43e12fab6023fc79d28ca6b39b05f794ad08ec2feccc59a339b02f2b3d33"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc799088c08fa04e43144b164feb0c13f9a0bc40503f8df3e9fde58a3c0c101e"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97aeba56665b4c3235a0e52b2c2f5ae9cd071b8a8310ad27bddb3f7fb30e9aa2"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:406bf18d345822d6c21366031003612b9c77b3e29ffdb0f612367352aab7d586"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b93590ae81f7010dbe380cdeab6f515902ebcbefe0b9327cc4804d74e93ae69d"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:01a3d0ab748ee531f4ea6c3e48ad9dac84ddba4b0d82291f87248f2f9de8d740"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:6561e94ba9dacc9c61bce40e2d6bdc3bfaa0259d3ff36ace3b1e6901936d2e3e"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:915c3d10f81bec3a74fbd4faebe8391013ba61e5a1a8d48c4455b923bdda7858"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-win32.whl", hash = "sha256:650ae77860b45cfa6e2cdafc42618ceafab3a2d9a3811fcfbd3bbf8ac3c40d36"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-win_amd64.whl", hash = "sha256:79ec52ec461e99e13791ec6508c722742ad745571f234ea6255bed38c6480f11"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-win_arm64.whl", hash = "sha256:3f84d5c1b4ab906093bdc1ff10484838aca54ef08de4afa9de0f5f14d69639cd"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:3f37a19d7ebcdd20b96485056ba9e8b304e27d9904d233d7b1015db320e51f0a"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1d1d9764366c73f996edd17abb6d9d7649a7eb690006ab6adbda117717099b14"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e1c2af0fce638d5f1988b686f3b3ea8cd7de5f244ca147c777769e798a9cd1"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:506d766a8727beef16b7adaeb8ee6217c64fc813646b424d0804d67c16eddb66"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4819fa52133c9aa3c387b3328f25c1facc356491e6135b459f1de698ff64d869"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b761d210c9ea91feda40d25b4efe82a1707da2ef62901466a42492c028553a2"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22f0fb8c1c583a3b6f24df2470833b40207e907b90c928cc8d3594b76f874375"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2782c870e99878c634505236d81e5443092fba820f0373997ff75f90f68cd553"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:0177272f88ab8312479336e1d777f6b124537d47f2123f89cb37e0accea97f90"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:63510af5e38f8955b8ee5687740d6ebf7c2a0886d15a6d65c32814613681bc07"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:e56ba91f47764cc14f1daacd723e3e82d1a89d783f0f5afe9c364b8bb491ccdb"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-win32.whl", hash = "sha256:aec5cf2fd867b4ff45b9959f8b20ea3993fc93e63c7363fe6851424c8a7e7c23"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-win_amd64.whl", hash = "sha256:8e7c86f27c585ef37c35e56a96363ab8de4e549a95512445b85c96d3e2f7c1bf"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-win_arm64.whl", hash = "sha256:e672ba74fbc2dc8eea59fb6d4aed6845e6905fc2a8afe93175d94a83ba2a01a0"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:8566def80554c3faa0e65ac30ab0932b9e3a5cd7f8323764303d468e5c37595a"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b80aa5095cd3109962a298ce14110ae16b8c1aece8b72f9dafe81cf597ad80b3"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3006c3dd9ba34b0c094c544c6006cc79e87d8612999f1a5d43b769b89181f23c"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:72f6c8b11857a856bcfa48c86f5368439f74453563f951e473514579d44aa612"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cb1b2f9742240e4bb26b652a5aeb840aa4b417c7748b6f8387927bc6e45e40d"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd3d54f38609ff308209bd43acea66061494157703364ae40c951f83ba99a1a9"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff4321e56e879ee8d2a879501c8e469414d948f4aba74a2d4593184eb326660"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0d2568a8c11bf8225044aa94409e21da0cb09dcdafe9ecd10250b2baad531a9"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:a39455728aabd58ceabb03c90e12f71fd30fa69615760a075b9fec596456ccc3"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:239edca560d05757817c13dc17c50766136d21f7cd0fac50295499ae24f90fdf"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = "sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:8bfeaf8735be79f225f3fefab7f941c712aaca36f1128c9d7e2352ee1aa87bdf"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:346285d28e4c8017da95144c7f3acd42740d637ff41946af5ce6e5e420502dd5"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a75dafbf87d6276ddc5b2bf6fae5254e3d0876b626eb24969a574fff9149ee5d"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7b93a4d08587e2b7e7882de461e82b6ed76d9026ce91ca7915e740ecc7855f60"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8465ab91a4bd96d36dde3263f06caa6a8a6019e4113f24dc753d79a8b3a3f82"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:299e0a22e7ae2b85c1a57f104538b2656e8ab1873511fd718a1c1c6f149b77b5"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:707625ef0983fcfb461acfaf14de2067c5942c6bb0f3b4c99158bed6fedd3cf3"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f41eb9797986d6ebac5e8edff36d5cef9de40def462311b3eb3eeded1431e425"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0384e2e1021894b1ff5a786dbf94771e2986ebe2869533874d7e43bc79c6f504"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:f0cd744688278965817fd0839c4a4116add48d23890d468bc436f78beb28abf5"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:753e230374206729bf0a807954bcc6c150d3743928a73faffee51ac6557a03c3"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-win32.whl", hash = "sha256:873e0d5b4fb9b89ef7c2d2a963ea7d02879d9da0da8d9d4933dee8ee86a8b460"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-win_amd64.whl", hash = "sha256:e4f4a984405e91527a0d62649ee21138f8e3d0ef103be488c1dc11a80d7f184b"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:b96d5f26b05d03cc60f11a7761a5ded1741da411e7fe0909e27a5e6a0cb7b034"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:634e8609e89ceecea15e2d61bc9ac3718caaaa71963717bf3c8f38bfde64242c"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93e8740d7503eb008aa2df04d3b9735f845d43ae845e6dcd2be0b55a2da43cd2"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f15489ba13d61f670dcc96772e733aad1a6f9c429cc27574c6cdaed82d0146ad"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:7da7087d756b19037bc2c06edc6c170eeef3c3bafcb8f532ff17d64dc427adfd"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:aabf5777b5c8ca26f7824cb4a120a740c9588ed58df9b2d196ce92fba42ff8dc"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b5819cd790dbf0c5eb9f82c73c16b39a65dd6dd4d1439dcdea7816ec9adddab8"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5a4e67afbc95fa5c34cf27d9089bca7fcab4e51e57278d710320a70b956d1b9a"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ece5c59f0ce7d001e017643d8d24da587ea1f74f6993467d85ae8a5ef9d4f42b"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:16f80f7abe3351f8ea6858914ddc8c77e02578544a0ebc15b4c2e1a0e813b0b2"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:33cb885e759a705b426baada1fe68cbb0a2e68e34c5d0d0289a364cf01709093"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:c8d8b4eb992936023be7dee581270af5c6e0697a8559895f527f5b7105ecd36a"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:242a206cd0318f95cd21bdacff3fcc3aab23e79bba5cac3db5a841c9ef9c6963"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d3a978c4f57a597908b7e697229d996d77a6d3c94901e9edee593adada95ce1a"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2379fa7ed44ddecb5bfe4e48577d752db9fc10be00a6b7446e9663ba143de26"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:266fb4cbf5e3cbd0b53669a6d1b039c45e3ce651fd5442eff4d07c2cc8d66808"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58133647260ea01e4d0500089a8c4f07bd7aa6ce109682b1426394988d8aaacc"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:287dad91cfb551c363dc62899a80e9e14da1f0e2b6ebde82c806612ca2a13ef1"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:03b77d184b9eb40240ae9fd676ca364ce1085f203e1b1256f8ab9984dca80a84"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:a668ce24de96165bb239160b3d854943128f4334822900534f2fe947930e5770"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f14f8f046c14563f8eb3f45f499cc658ab8d10072961e07225e507adb700e93f"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51"},
+ {file = "pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.14.1"
+
[[package]]
name = "pygments"
version = "2.19.2"
@@ -496,12 +640,27 @@ version = "4.15.0"
description = "Backported and Experimental Type Hints for Python 3.9+"
optional = false
python-versions = ">=3.9"
-groups = ["dev"]
-markers = "python_version < \"3.11\""
+groups = ["main", "dev"]
files = [
{file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"},
{file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"},
]
+markers = {dev = "python_version < \"3.11\""}
+
+[[package]]
+name = "typing-inspection"
+version = "0.4.2"
+description = "Runtime typing introspection tools"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"},
+ {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.12.0"
[[package]]
name = "zipp"
@@ -526,5 +685,5 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
-python-versions = "^3.9 || ^3.10 || ^3.11"
-content-hash = "d13ccd9a0de456c987bd6c6f20034c2f2a71279f65ddd4b2a0d597ef5ca5fd86"
+python-versions = "^3.9 || ^3.10 || ^3.11 || ^3.12 || ^3.13 || ^3.14"
+content-hash = "f29cad7d87838e9587a1f1747bd5cae2aba51f0c0abac77a653f0ed1b93b0b8b"
diff --git a/pyproject.toml b/pyproject.toml
index 483f31c9..1fd3a275 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,9 +6,8 @@ authors = ["pytm Team"]
license = "MIT License"
[tool.poetry.dependencies]
-python = "^3.9 || ^3.10 || ^3.11"
-pydal = "~20200714.1"
-legacy-cgi = { version = "^2.0", markers = "python_version >= '3.13'" }
+python = "^3.9 || ^3.10 || ^3.11 || ^3.12 || ^3.13 || ^3.14"
+pydantic = "^2.10.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.5"
diff --git a/pytm/__init__.py b/pytm/__init__.py
index eb15b7fa..c8279547 100644
--- a/pytm/__init__.py
+++ b/pytm/__init__.py
@@ -22,35 +22,46 @@
"SetOfProcesses",
"Threat",
"TM",
+ "Controls",
+ "var",
]
import sys
from .json import load, loads
-from .pytm import (
- TM,
- Action,
- Actor,
- Assumption,
- Boundary,
- Classification,
- Data,
- Dataflow,
- Datastore,
- DatastoreType,
- Element,
- ExternalEntity,
- Finding,
- Lambda,
- LLM,
- Lifetime,
- Process,
- Server,
- SetOfProcesses,
- Threat,
- TLSVersion,
- var,
-)
+from .pytm import var
+
+# Import from new Pydantic models
+from .enums import Action, Classification, DatastoreType, Lifetime, TLSVersion
+from .base import Assumption, Controls
+from .element import Element
+from .data import Data
+from .threat import Threat
+from .finding import Finding
+from .asset import Asset, Lambda, LLM, Server, ExternalEntity
+from .datastore import Datastore
+from .actor import Actor
+from .process import Process, SetOfProcesses
+from .dataflow import Dataflow
+from .boundary import Boundary
+from .tm import TM
+
+# Rebuild models now that every module is imported, so string forward references (e.g. "Dataflow") resolve
+Element.model_rebuild()
+Data.model_rebuild()
+Finding.model_rebuild()
+Asset.model_rebuild()
+Lambda.model_rebuild()
+LLM.model_rebuild()
+Server.model_rebuild()
+ExternalEntity.model_rebuild()
+Datastore.model_rebuild()
+Actor.model_rebuild()
+Process.model_rebuild()
+SetOfProcesses.model_rebuild()
+Dataflow.model_rebuild()
+Boundary.model_rebuild()
+TM.model_rebuild()
def pdoc_overrides():
@@ -62,9 +73,11 @@ def pdoc_overrides():
for i in dir(klass):
if i in ("check", "dfd", "seq"):
result[f"{name}.{i}"] = False
- attr = getattr(klass, i, {})
- if isinstance(attr, var) and attr.doc != "":
- result[f"{name}.{i}"] = attr.doc
+ model_fields = getattr(klass, "model_fields", {})
+ if i in model_fields:
+ description = model_fields[i].description
+ if description:
+ result[f"{name}.{i}"] = description
return result
diff --git a/pytm/actor.py b/pytm/actor.py
new file mode 100644
index 00000000..c46c72e5
--- /dev/null
+++ b/pytm/actor.py
@@ -0,0 +1,103 @@
+"""Actor model - represents entities that initiate actions."""
+
+from typing import TYPE_CHECKING, List, Optional
+from pydantic import Field, field_validator
+
+from .element import Element
+from .base import DataSet
+
+if TYPE_CHECKING:
+    from .dataflow import Dataflow
+
+
+class Actor(Element):
+    """An entity usually initiating actions.
+
+    Actors represent users or external systems that initiate
+    interactions with the system being modeled.
+
+    Attributes:
+        port (int): Default TCP port for outgoing data flows
+        protocol (str): Default network protocol for outgoing data flows
+        data (DataSet): pytm.Data object(s) in outgoing data flows
+        inputs (List[Dataflow]): Incoming Dataflows
+        outputs (List[Dataflow]): Outgoing Dataflows
+        isAdmin (bool): Indicates whether the actor has administrative privileges
+    """
+
+    port: int = Field(
+        default=-1, description="Default TCP port for outgoing data flows"
+    )
+    protocol: str = Field(
+        default="", description="Default network protocol for outgoing data flows"
+    )
+    data: DataSet = Field(
+        default_factory=DataSet,
+        description="pytm.Data object(s) in outgoing data flows",
+    )
+    inputs: List["Dataflow"] = Field(
+        default_factory=list, description="Incoming Dataflows"
+    )
+    outputs: List["Dataflow"] = Field(
+        default_factory=list, description="Outgoing Dataflows"
+    )
+    isAdmin: bool = Field(
+        default=False,
+        description="Indicates whether the actor has administrative privileges",
+    )
+
+    @field_validator("data", mode="before")
+    @classmethod
+    def _coerce_dataset(cls, v):
+        """Coerce the raw ``data`` value into a DataSet.
+
+        An existing DataSet is returned unchanged and None yields an empty
+        DataSet. A single Data object is wrapped; any other non-string
+        iterable contributes its items, skipping None entries. Every other
+        value is added to a fresh DataSet as-is.
+        """
+        from .data import Data  # Local import to avoid circular dependency
+
+        if isinstance(v, DataSet):
+            return v
+
+        dataset = DataSet()
+        if v is None:
+            return dataset
+
+        # Data is excluded explicitly so a single Data object is added whole.
+        is_collection = hasattr(v, "__iter__") and not isinstance(
+            v, (Data, str, bytes)
+        )
+        if is_collection:
+            for item in v:
+                if item is not None:
+                    dataset.add(item)
+        else:
+            dataset.add(v)
+        return dataset
+
+    def __init__(self, name: Optional[str] = None, **data):
+        """Initialize an Actor and register it with the TM class.
+
+        Args:
+            name (Optional[str]): Name of the actor.
+            **data: Optional actor properties:
+                - port (int): Default TCP port for outgoing data flows
+                - protocol (str): Default network protocol for outgoing data flows
+                - data (DataSet): pytm.Data object(s) in outgoing data flows
+                - inputs (List[Dataflow]): Incoming Dataflows
+                - outputs (List[Dataflow]): Outgoing Dataflows
+                - isAdmin (bool): Indicates whether the actor has administrative privileges
+        """
+        super().__init__(name, **data)
+        self._register_with_tm_actors()
+
+    def _register_with_tm_actors(self):
+        """Append this actor to TM._actors, skipping if TM cannot be imported."""
+        try:
+            from .tm import TM
+        except ImportError:
+            # Registration is best-effort: construction still succeeds without TM.
+            return
+        TM._actors.append(self)
diff --git a/pytm/asset.py b/pytm/asset.py
new file mode 100644
index 00000000..06089d16
--- /dev/null
+++ b/pytm/asset.py
@@ -0,0 +1,382 @@
+"""Asset models - base Asset class and specific asset implementations."""
+
+from typing import List, TYPE_CHECKING
+
+from pydantic import Field, field_validator
+
+from .element import Element, sev_to_color
+from .base import DataSet
+
+if TYPE_CHECKING:
+ from .dataflow import Dataflow
+
+
+class Asset(Element):
+    """An element that sends and/or receives dataflows.
+
+    Attributes:
+        port (int): Default TCP port for incoming data flows
+        protocol (str): Default network protocol for incoming data flows
+        data (DataSet): pytm.Data object(s) in incoming data flows
+        inputs (List[Dataflow]): Incoming Dataflows
+        outputs (List[Dataflow]): Outgoing Dataflows
+        onAWS (bool): Is this asset on AWS?
+        handlesResources (bool): Does this asset handle resources?
+        usesEnvironmentVariables (bool): Does this asset use environment variables?
+        OS (str): Operating system
+    """
+
+    port: int = Field(
+        default=-1, description="Default TCP port for incoming data flows"
+    )
+    protocol: str = Field(
+        default="", description="Default network protocol for incoming data flows"
+    )
+    data: DataSet = Field(
+        default_factory=DataSet,
+        description="pytm.Data object(s) in incoming data flows",
+    )
+    inputs: List["Dataflow"] = Field(
+        default_factory=list, description="incoming Dataflows"
+    )
+    outputs: List["Dataflow"] = Field(
+        default_factory=list, description="outgoing Dataflows"
+    )
+    onAWS: bool = Field(default=False, description="Is this asset on AWS?")
+    handlesResources: bool = Field(
+        default=False, description="Does this asset handle resources?"
+    )
+    usesEnvironmentVariables: bool = Field(
+        default=False, description="Does this asset use environment variables?"
+    )
+    OS: str = Field(default="", description="Operating system")
+
+    @field_validator("data", mode="before")
+    @classmethod
+    def validate_data(cls, v):
+        """Turn the value supplied for ``data`` into a DataSet.
+
+        A DataSet is returned untouched and ``None`` yields an empty one. A
+        single value is wrapped in a one-member DataSet; any other iterable
+        contributes each of its members except ``None``.
+        """
+        from .data import Data
+
+        if isinstance(v, DataSet):
+            return v
+        if v is None:
+            return DataSet()
+
+        # Data, str and bytes are tested explicitly so that they are stored
+        # whole rather than being walked like a generic iterable.
+        stored_whole = isinstance(v, (Data, str, bytes))
+        if stored_whole or not hasattr(v, "__iter__"):
+            members = [v]
+        else:
+            members = [item for item in v if item is not None]
+
+        collected = DataSet()
+        collected.update(members)
+        return collected
+
+    def __init__(self, name: str = None, **data):
+        """Build the asset, then record it in the threat model's asset list.
+
+        Args:
+            name (str): Name of the asset.
+            **data: Optional asset properties:
+                - port (int): Default TCP port for incoming data flows
+                - protocol (str): Default network protocol for incoming data flows
+                - data (DataSet): pytm.Data object(s) in incoming data flows
+                - inputs (List[Dataflow]): Incoming Dataflows
+                - outputs (List[Dataflow]): Outgoing Dataflows
+                - onAWS (bool): Is this asset on AWS?
+                - handlesResources (bool): Does this asset handle resources?
+                - usesEnvironmentVariables (bool): Does this asset use environment variables?
+                - OS (str): Operating system
+        """
+        super().__init__(name, **data)
+        # Registration happens only once the parent constructor has finished.
+        self._register_with_tm_assets()
+
+    def _register_with_tm_assets(self):
+        """Append this asset to ``TM._assets``, skipping quietly if ``.tm`` is unimportable."""
+        try:
+            from .tm import TM
+
+            TM._assets.append(self)
+        except ImportError:
+            pass
+
+
+class Lambda(Asset):
+    """A lambda function running in a Function-as-a-Service (FaaS) environment.
+
+    Attributes:
+        port (int): Default TCP port for incoming data flows
+        protocol (str): Default network protocol for incoming data flows
+        data (DataSet): pytm.Data object(s) in incoming data flows
+        inputs (List[Dataflow]): Incoming Dataflows
+        outputs (List[Dataflow]): Outgoing Dataflows
+        onAWS (bool): Is this lambda on AWS? (default True)
+        handlesResources (bool): Does this asset handle resources?
+        usesEnvironmentVariables (bool): Does this asset use environment variables?
+        OS (str): Operating system
+        environment (str): Environment for the lambda
+        implementsAPI (bool): Does this lambda implement an API?
+    """
+
+    onAWS: bool = Field(default=True, description="Is this lambda on AWS?")
+    environment: str = Field(default="", description="Environment for the lambda")
+    implementsAPI: bool = Field(
+        default=False, description="Does this lambda implement an API?"
+    )
+
+    def __init__(self, name: str = None, **data):
+        """Initialize a Lambda; all work is delegated to the Asset constructor.
+
+        Args:
+            name (str): Name of the lambda.
+            **data: Optional lambda properties:
+                - port (int): Default TCP port for incoming data flows
+                - protocol (str): Default network protocol for incoming data flows
+                - data (DataSet): pytm.Data object(s) in incoming data flows
+                - inputs (List[Dataflow]): Incoming Dataflows
+                - outputs (List[Dataflow]): Outgoing Dataflows
+                - onAWS (bool): Is this lambda on AWS?
+                - handlesResources (bool): Does this asset handle resources?
+                - usesEnvironmentVariables (bool): Does this asset use environment variables?
+                - OS (str): Operating system
+                - environment (str): Environment for the lambda
+                - implementsAPI (bool): Does this lambda implement an API?
+        """
+        super().__init__(name, **data)
+
+    def _dfd_template(self) -> str:
+        """Return the DOT node template; ``{label}`` is shown in bold inside an HTML-like table."""
+        return """{uniq_name} [
+    shape = {shape};
+
+    color = {color};
+    fontcolor = "black";
+    label = <
+        <table border="0" cellborder="0" cellpadding="2">
+            <tr><td><b>{label}</b></td></tr>
+        </table>
+    >;
+]
+"""
+
+    def dfd(self, **kwargs) -> str:
+        """Return the DOT text for this node, or "" when filtered out by ``levels``."""
+        self.is_drawn = True  # marked drawn even when the level filter skips it
+
+        levels = kwargs.get("levels", None)
+        if levels and not levels & self.levels:  # no level in common
+            return ""
+
+        color = self._color()
+
+        if kwargs.get("colormap", False):
+            color = sev_to_color(self.severity)  # colormap overrides the default colour
+
+        return self._dfd_template().format(
+            uniq_name=self._uniq_name(),
+            label=self._label(),
+            color=color,
+            shape=self._shape(),
+        )
+
+    def _shape(self) -> str:
+        """Get shape for DFD representation."""
+        return "rectangle; style=rounded"
+
+
+class Server(Asset):
+    """An entity processing data.
+
+    Attributes:
+        port (int): Default TCP port for incoming data flows
+        protocol (str): Default network protocol for incoming data flows
+        data (DataSet): pytm.Data object(s) in incoming data flows
+        inputs (List[Dataflow]): Incoming Dataflows
+        outputs (List[Dataflow]): Outgoing Dataflows
+        onAWS (bool): Is this asset on AWS?
+        handlesResources (bool): Does this asset handle resources?
+        usesEnvironmentVariables (bool): Does this asset use environment variables?
+        OS (str): Operating system
+        usesSessionTokens (bool): Does this server use session tokens? (default False)
+        usesCache (bool): Does this server use cache? (default False)
+        usesVPN (bool): Does this server use VPN? (default False)
+        usesXMLParser (bool): Does this server use XML parser? (default False)
+    """
+
+    usesSessionTokens: bool = Field(
+        default=False, description="Does this server use session tokens?"
+    )
+    usesCache: bool = Field(default=False, description="Does this server use cache?")
+    usesVPN: bool = Field(default=False, description="Does this server use VPN?")
+    usesXMLParser: bool = Field(
+        default=False, description="Does this server use XML parser?"
+    )
+
+    def __init__(self, name: str = None, **data):
+        """Initialize a Server; all work is delegated to the Asset constructor.
+
+        Args:
+            name (str): Name of the server.
+            **data: Optional server properties:
+                - port (int): Default TCP port for incoming data flows
+                - protocol (str): Default network protocol for incoming data flows
+                - data (DataSet): pytm.Data object(s) in incoming data flows
+                - inputs (List[Dataflow]): Incoming Dataflows
+                - outputs (List[Dataflow]): Outgoing Dataflows
+                - onAWS (bool): Is this asset on AWS?
+                - handlesResources (bool): Does this asset handle resources?
+                - usesEnvironmentVariables (bool): Does this asset use environment variables?
+                - OS (str): Operating system
+                - usesSessionTokens (bool): Does this server use session tokens?
+                - usesCache (bool): Does this server use cache?
+                - usesVPN (bool): Does this server use VPN?
+                - usesXMLParser (bool): Does this server use XML parser?
+        """
+        super().__init__(name, **data)
+
+    def _shape(self) -> str:
+        """Return the shape name used when a server is drawn in the DFD."""
+        return "circle"
+
+
+class ExternalEntity(Asset):
+    """An external entity that interacts with the system.
+
+    Attributes:
+        port (int): Default TCP port for incoming data flows
+        protocol (str): Default network protocol for incoming data flows
+        data (DataSet): pytm.Data object(s) in incoming data flows
+        inputs (List[Dataflow]): Incoming Dataflows
+        outputs (List[Dataflow]): Outgoing Dataflows
+        onAWS (bool): Is this asset on AWS?
+        handlesResources (bool): Does this asset handle resources?
+        usesEnvironmentVariables (bool): Does this asset use environment variables?
+        OS (str): Operating system
+        hasPhysicalAccess (bool): Does this external entity have physical access? (default False)
+    """
+
+    hasPhysicalAccess: bool = Field(
+        default=False, description="Does this external entity have physical access?"
+    )
+
+    def __init__(self, name: str = None, **data):
+        """Initialize an ExternalEntity; all work is delegated to the Asset constructor.
+
+        Args:
+            name (str): Name of the external entity.
+            **data: Optional external entity properties:
+                - port (int): Default TCP port for incoming data flows
+                - protocol (str): Default network protocol for incoming data flows
+                - data (DataSet): pytm.Data object(s) in incoming data flows
+                - inputs (List[Dataflow]): Incoming Dataflows
+                - outputs (List[Dataflow]): Outgoing Dataflows
+                - onAWS (bool): Is this asset on AWS?
+                - handlesResources (bool): Does this asset handle resources?
+                - usesEnvironmentVariables (bool): Does this asset use environment variables?
+                - OS (str): Operating system
+                - hasPhysicalAccess (bool): Does this external entity have physical access?
+        """
+        super().__init__(name, **data)
+
+
+class LLM(Asset):
+    """A Large Language Model element, either third-party or self-hosted.
+
+    Attributes:
+        port (int): Default TCP port for incoming data flows
+        protocol (str): Default network protocol for incoming data flows
+        data (DataSet): pytm.Data object(s) in incoming data flows
+        inputs (List[Dataflow]): Incoming Dataflows
+        outputs (List[Dataflow]): Outgoing Dataflows
+        onAWS (bool): Is this asset on AWS?
+        handlesResources (bool): Does this asset handle resources?
+        usesEnvironmentVariables (bool): Does this asset use environment variables?
+        OS (str): Operating system
+        isThirdParty (bool): Is this LLM a third-party service? (default True)
+        isSelfHosted (bool): Is this LLM self-hosted?
+        processesPersonalData (bool): Does this LLM process personal data?
+        retainsUserData (bool): Does this LLM retain user data?
+        hasAgentCapabilities (bool): Does this LLM have agent capabilities?
+        hasAccessToSensitiveSystems (bool): Does this LLM have access to sensitive systems?
+        executesCode (bool): Does this LLM execute code?
+        hasContentFiltering (bool): Does this LLM have content filtering?
+        hasSystemPrompt (bool): Does this LLM have a system prompt? (default True)
+        processesUntrustedInput (bool): Does this LLM process untrusted input? (default True)
+        hasRAG (bool): Does this LLM use retrieval-augmented generation?
+        hasFineTuning (bool): Has this LLM been fine-tuned?
+    """
+
+    isThirdParty: bool = Field(
+        default=True, description="Is this LLM a third-party service?"
+    )
+    isSelfHosted: bool = Field(default=False, description="Is this LLM self-hosted?")
+    processesPersonalData: bool = Field(
+        default=False, description="Does this LLM process personal data?"
+    )
+    retainsUserData: bool = Field(
+        default=False, description="Does this LLM retain user data?"
+    )
+    hasAgentCapabilities: bool = Field(
+        default=False, description="Does this LLM have agent capabilities?"
+    )
+    hasAccessToSensitiveSystems: bool = Field(
+        default=False, description="Does this LLM have access to sensitive systems?"
+    )
+    executesCode: bool = Field(default=False, description="Does this LLM execute code?")
+    hasContentFiltering: bool = Field(
+        default=False, description="Does this LLM have content filtering?"
+    )
+    hasSystemPrompt: bool = Field(
+        default=True, description="Does this LLM have a system prompt?"
+    )
+    processesUntrustedInput: bool = Field(
+        default=True, description="Does this LLM process untrusted input?"
+    )
+    hasRAG: bool = Field(
+        default=False, description="Does this LLM use retrieval-augmented generation?"
+    )
+    hasFineTuning: bool = Field(
+        default=False, description="Has this LLM been fine-tuned?"
+    )
+
+    def __init__(self, name: str = None, **data):
+        """Initialize an LLM; all work is delegated to the Asset constructor.
+
+        Args:
+            name (str): Name of the LLM.
+            **data: Optional LLM properties:
+                - port (int): Default TCP port for incoming data flows
+                - protocol (str): Default network protocol for incoming data flows
+                - data (DataSet): pytm.Data object(s) in incoming data flows
+                - inputs (List[Dataflow]): Incoming Dataflows
+                - outputs (List[Dataflow]): Outgoing Dataflows
+                - onAWS (bool): Is this asset on AWS?
+                - handlesResources (bool): Does this asset handle resources?
+                - usesEnvironmentVariables (bool): Does this asset use environment variables?
+                - OS (str): Operating system
+                - isThirdParty (bool): Is this LLM a third-party service?
+                - isSelfHosted (bool): Is this LLM self-hosted?
+                - processesPersonalData (bool): Does this LLM process personal data?
+                - retainsUserData (bool): Does this LLM retain user data?
+                - hasAgentCapabilities (bool): Does this LLM have agent capabilities?
+                - hasAccessToSensitiveSystems (bool): Does this LLM have access to sensitive systems?
+                - executesCode (bool): Does this LLM execute code?
+                - hasContentFiltering (bool): Does this LLM have content filtering?
+                - hasSystemPrompt (bool): Does this LLM have a system prompt?
+                - processesUntrustedInput (bool): Does this LLM process untrusted input?
+                - hasRAG (bool): Does this LLM use retrieval-augmented generation?
+                - hasFineTuning (bool): Has this LLM been fine-tuned?
+        """
+        super().__init__(name, **data)
+
+    def _shape(self) -> str:
+        """Return the shape name used when an LLM is drawn in the DFD."""
+        return "hexagon"
diff --git a/pytm/base.py b/pytm/base.py
new file mode 100644
index 00000000..d72214d7
--- /dev/null
+++ b/pytm/base.py
@@ -0,0 +1,216 @@
+"""Base models and utilities for pytm Pydantic models."""
+
+from __future__ import annotations
+
+from typing import Any, Iterable, List, Set, Union, TYPE_CHECKING
+
+from pydantic import BaseModel, ConfigDict, Field
+
+if TYPE_CHECKING:
+ from .element import Element
+ from .data import Data
+ from .threat import Threat
+ from .finding import Finding
+
+
+class DataSet(set):
+    """Set of Data objects that also answers ``"name" in dataset`` by member name."""
+
+    __slots__ = ("_names",)
+
+    def __init__(self, values: Iterable["Data"] | None = None):
+        super().__init__()
+        self._names: Set[str] = set()  # names of current members, for str lookups
+        if values is not None:
+            self.update(values)
+
+    def __contains__(self, item: object) -> bool:
+        if isinstance(item, str):
+            return item in self._names
+        return super().__contains__(item)
+
+    def __eq__(self, other: object) -> bool:
+        # A str compares equal when it names a member; sets compare as usual.
+        if isinstance(other, set):
+            return super().__eq__(other)
+        if isinstance(other, str):
+            return other in self
+        return NotImplemented
+
+    def __ne__(self, other: object) -> bool:
+        # Exact negation of __eq__, kept explicit because set defines its own __ne__.
+        equal = self.__eq__(other)
+        return equal if equal is NotImplemented else not equal
+
+    def __str__(self) -> str:
+        return ", ".join(sorted(self._names))
+
+    def add(self, element: Any) -> None:  # type: ignore[override]
+        super().add(element)
+        self._register(element)
+
+    def update(self, *others: Iterable[Any]) -> None:  # type: ignore[override]
+        for iterable in others:
+            for element in iterable:
+                self.add(element)
+
+    def discard(self, element: Any) -> None:  # type: ignore[override]
+        if super().__contains__(element):
+            super().discard(element)
+            self._unregister(element)
+
+    def remove(self, element: Any) -> None:  # type: ignore[override]
+        super().remove(element)
+        self._unregister(element)
+
+    def pop(self) -> Any:  # type: ignore[override]
+        element = super().pop()
+        self._unregister(element)
+        return element
+
+    def clear(self) -> None:  # type: ignore[override]
+        super().clear()
+        self._names.clear()
+
+    def _register(self, element: "Data") -> None:
+        name = getattr(element, "name", None)
+        if isinstance(name, str):
+            self._names.add(name)
+
+    def _unregister(self, element: "Data") -> None:
+        """Forget ``element``'s name unless another remaining member shares it."""
+        name = getattr(element, "name", None)
+        still_used = any(getattr(m, "name", None) == name for m in self)
+        if isinstance(name, str) and not still_used:
+            self._names.discard(name)
+
+
+class Controls(BaseModel):
+    """Controls implemented by/on an Element; every boolean flag defaults to False."""
+
+    model_config = ConfigDict(extra="allow", validate_assignment=True)
+
+    authenticatesDestination: bool = Field(
+        default=False,
+        description="Verifies the identity of the destination, for example by verifying the authenticity of a digital certificate.",
+    )
+    authenticatesSource: bool = Field(default=False)
+    authenticationScheme: str = Field(default="")
+    authorizesSource: bool = Field(default=False)
+    checksDestinationRevocation: bool = Field(
+        default=False,
+        description="Correctly checks the revocation status of credentials used to authenticate the destination",
+    )
+    checksInputBounds: bool = Field(default=False)
+    definesConnectionTimeout: bool = Field(default=False)
+    disablesDTD: bool = Field(default=False)
+    disablesiFrames: bool = Field(default=False)
+    encodesHeaders: bool = Field(default=False)
+    encodesOutput: bool = Field(default=False)
+    encryptsCookies: bool = Field(default=False)
+    encryptsSessionData: bool = Field(default=False)
+    handlesCrashes: bool = Field(default=False)
+    handlesInterruptions: bool = Field(default=False)
+    handlesResourceConsumption: bool = Field(default=False)
+    hasAccessControl: bool = Field(default=False)
+    implementsAuthenticationScheme: bool = Field(default=False)
+    implementsCSRFToken: bool = Field(default=False)
+    implementsNonce: bool = Field(
+        default=False,
+        description="Nonce is an arbitrary number that can be used just once in a cryptographic communication.",
+    )
+    implementsPOLP: bool = Field(
+        default=False,
+        description="The principle of least privilege (PoLP) requires that every module must be able to access only the information and resources that are necessary for its legitimate purpose.",
+    )
+    implementsServerSideValidation: bool = Field(default=False)
+    implementsStrictHTTPValidation: bool = Field(default=False)
+    invokesScriptFilters: bool = Field(default=False)
+    isEncrypted: bool = Field(
+        default=False, description="Requires incoming data flow to be encrypted"
+    )
+    isEncryptedAtRest: bool = Field(
+        default=False, description="Stored data is encrypted at rest"
+    )
+    isHardened: bool = Field(default=False)
+    isResilient: bool = Field(default=False)
+    providesConfidentiality: bool = Field(default=False)
+    providesIntegrity: bool = Field(default=False)
+    sanitizesInput: bool = Field(default=False)
+    tracksExecutionFlow: bool = Field(default=False)
+    usesCodeSigning: bool = Field(default=False)
+    usesEncryptionAlgorithm: str = Field(default="")
+    usesMFA: bool = Field(
+        default=False,
+        description="Multi-factor authentication is an authentication method in which a computer user is granted access only after successfully presenting two or more pieces of evidence.",
+    )
+    usesParameterizedInput: bool = Field(default=False)
+    usesSecureFunctions: bool = Field(default=False)
+    usesStrongSessionIdentifiers: bool = Field(default=False)
+    usesVPN: bool = Field(default=False)
+    validatesContentType: bool = Field(default=False)
+    validatesHeaders: bool = Field(default=False)
+    validatesInput: bool = Field(default=False)
+    verifySessionIdentifiers: bool = Field(default=False)
+
+    def _attr_values(self) -> dict:
+        """Return the result of ``model_dump()`` for this instance."""
+        return self.model_dump()
+
+    def _safeset(self, attr: str, value: Any) -> None:
+        """Assign ``value`` to ``attr``, silently ignoring ValueError and TypeError."""
+        try:
+            setattr(self, attr, value)
+        except (ValueError, TypeError):
+            pass  # a rejected assignment leaves the attribute as it was
+
+
+class Assumption(BaseModel):
+    """Assumption attached to an Element, used to exclude threats for that element only.
+
+    Attributes:
+        name (str): Name of the assumption
+        exclude (Set[str]): A set of threat SIDs to exclude for this assumption. For example: INP01
+        description (str): An additional description of the assumption
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+    name: str = Field(description="Name of the assumption")
+    exclude: Set[str] = Field(
+        default_factory=set,
+        description="A set of threat SIDs to exclude for this assumption. For example: INP01",
+    )
+    description: str = Field(
+        default="", description="An additional description of the assumption"
+    )
+
+    def __init__(
+        self, name: str = None, exclude: Union[List[str], Set[str]] = None, **kwargs
+    ):
+        """Create an Assumption from positional or keyword arguments.
+
+        Args:
+            name (str): Name of the assumption.
+            exclude (Set[str]): A set of threat SIDs to exclude for this assumption. For example: INP01
+            **kwargs: Optional properties:
+                - description (str): An additional description of the assumption
+        """
+        if isinstance(exclude, list):
+            exclude = set(exclude)  # only lists are converted here
+        given = {"name": name, "exclude": exclude}
+        # Arguments left at None are not forwarded to the model constructor.
+        kwargs.update({key: val for key, val in given.items() if val is not None})
+        super().__init__(**kwargs)
+
+    def __str__(self):
+        return self.name
+
+
+# Type aliases for field types; the quoted names are forward references.
+ElementList = List["Element"]
+DataList = List["Data"]
+ThreatList = List["Threat"]
+FindingList = List["Finding"]
+ControlsType = Controls  # alias for the Controls class itself, not a list
+AssumptionList = List[Assumption]
diff --git a/pytm/boundary.py b/pytm/boundary.py
new file mode 100644
index 00000000..5a870061
--- /dev/null
+++ b/pytm/boundary.py
@@ -0,0 +1,85 @@
+"""Boundary model - represents trust boundaries in the threat model."""
+
+from typing import List, TYPE_CHECKING
+from textwrap import indent
+
+from .element import Element
+
+if TYPE_CHECKING:
+ pass
+
+
+class Boundary(Element):
+    """Trust boundary groups elements and data with the same trust level."""
+
+    def __init__(self, name: str = None, **data):
+        super().__init__(name, **data)
+        # Record the new boundary in TM._boundaries.
+        self._register_with_tm_boundaries()
+
+    def _register_with_tm_boundaries(self):
+        """Append this boundary to ``TM._boundaries`` (best effort)."""
+        try:
+            from .tm import TM
+            # NOTE(review): str name tested against a list of Boundary objects -- confirm.
+            if self.name not in TM._boundaries:
+                TM._boundaries.append(self)
+        except ImportError:
+            pass
+
+    def _dfd_template(self) -> str:
+        """Return the DOT ``subgraph cluster`` template used by ``dfd``."""
+        return """subgraph cluster_{uniq_name} {{
+    graph [
+        fontsize = 10;
+        fontcolor = black;
+        style = dashed;
+        color = {color};
+        label = <{label}>;
+    ]
+
+{edges}
+}}
+"""
+
+    def dfd(self, **kwargs) -> str:
+        """Render this boundary as a DOT cluster of its members; "" if already drawn."""
+        if self.is_drawn:
+            return ""
+        self.is_drawn = True
+
+        rendered = []
+        try:
+            from .tm import TM
+
+            # Members may themselves be boundaries; drawn ones are skipped.
+            rendered.extend(
+                member.dfd(**kwargs)
+                for member in TM._elements
+                if not (member.inBoundary != self or member.is_drawn)
+            )
+        except ImportError:
+            pass
+
+        return self._dfd_template().format(
+            uniq_name=self._uniq_name(),
+            label=self._label(),
+            color=self._color(**kwargs),
+            edges=indent("\n".join(rendered), "    "),
+        )
+
+    def _color(self, **kwargs) -> str:
+        """Pick the outline colour for the cluster.
+
+        "black" when the ``colormap`` option is truthy, "firebrick2" otherwise.
+        """
+        return "black" if kwargs.get("colormap", False) else "firebrick2"
+
+    def parents(self) -> List["Boundary"]:
+        """List enclosing boundaries, nearest first, by following ``inBoundary``."""
+        chain = []
+        enclosing = self.inBoundary
+        while enclosing is not None:
+            chain.append(enclosing)
+            enclosing = enclosing.inBoundary
+        return chain
diff --git a/pytm/data.py b/pytm/data.py
new file mode 100644
index 00000000..c11f24ae
--- /dev/null
+++ b/pytm/data.py
@@ -0,0 +1,133 @@
+"""Data model - represents data that traverses the threat model."""
+
+from typing import List, TYPE_CHECKING
+from pydantic import BaseModel, Field, ConfigDict
+
+from .enums import Classification, Lifetime
+
+if TYPE_CHECKING:
+ from .element import Element
+ from .dataflow import Dataflow
+
+
+class Data(BaseModel):
+    """A single piece of data that traverses the system.
+
+    Attributes:
+        name (str): Name of the data
+        description (str): Description of the data
+        format (str): Format of the data
+        classification (Classification): Level of classification for this piece of data
+        isPII (bool): Does the data contain personally identifiable information. Should always be encrypted both in transmission and at rest.
+        isCredentials (bool): Does the data contain authentication information, like passwords or cryptographic keys, with or without expiration date. Should always be encrypted in transmission. If stored, they should be hashed using a cryptographic hash function.
+        credentialsLife (Lifetime): Credentials lifetime, describing if and how credentials can be revoked
+        isStored (bool): Is the data going to be stored by the target or only processed. If only derivative data is stored (a hash) it can be set to False.
+        isDestEncryptedAtRest (bool): Is data encrypted at rest at dest?
+        isSourceEncryptedAtRest (bool): Is data encrypted at rest at source?
+        carriedBy (List[Dataflow]): Dataflows that carry this piece of data
+        processedBy (List[Element]): Elements that store/process this piece of data
+    """
+
+    model_config = ConfigDict(extra="allow", validate_assignment=True)
+
+    name: str = Field(description="Name of the data")
+    description: str = Field(default="", description="Description of the data")
+    format: str = Field(default="", description="Format of the data")
+    classification: Classification = Field(
+        default=Classification.UNKNOWN,
+        description="Level of classification for this piece of data",
+    )
+    isPII: bool = Field(
+        default=False,
+        description="Does the data contain personally identifiable information. Should always be encrypted both in transmission and at rest.",
+    )
+    isCredentials: bool = Field(
+        default=False,
+        description="Does the data contain authentication information, like passwords or cryptographic keys, with or without expiration date. Should always be encrypted in transmission. If stored, they should be hashed using a cryptographic hash function.",
+    )
+    credentialsLife: Lifetime = Field(
+        default=Lifetime.NONE,
+        description="Credentials lifetime, describing if and how credentials can be revoked",
+    )
+    isStored: bool = Field(
+        default=False,
+        description="Is the data going to be stored by the target or only processed. If only derivative data is stored (a hash) it can be set to False.",
+    )
+    isDestEncryptedAtRest: bool = Field(
+        default=False, description="Is data encrypted at rest at dest?"
+    )
+    isSourceEncryptedAtRest: bool = Field(
+        default=False, description="Is data encrypted at rest at source?"
+    )
+    carriedBy: List["Dataflow"] = Field(
+        default_factory=list, description="Dataflows that carries this piece of data"
+    )
+    processedBy: List["Element"] = Field(
+        default_factory=list,
+        description="Elements that store/process this piece of data",
+    )
+
+    def __init__(self, name: str = None, **data):
+        """Create a Data object and add it to the threat model's data list.
+
+        Args:
+            name (str): Name of the data.
+            **data: Optional data properties:
+                - description (str): Description of the data
+                - format (str): Format of the data
+                - classification (Classification): Level of classification for this piece of data
+                - isPII (bool): Does the data contain personally identifiable information. Should always be encrypted both in transmission and at rest.
+                - isCredentials (bool): Does the data contain authentication information, like passwords or cryptographic keys, with or without expiration date. Should always be encrypted in transmission. If stored, they should be hashed using a cryptographic hash function.
+                - credentialsLife (Lifetime): Credentials lifetime, describing if and how credentials can be revoked
+                - isStored (bool): Is the data going to be stored by the target or only processed. If only derivative data is stored (a hash) it can be set to False.
+                - isDestEncryptedAtRest (bool): Is data encrypted at rest at dest?
+                - isSourceEncryptedAtRest (bool): Is data encrypted at rest at source?
+                - carriedBy (List[Dataflow]): Dataflows that carry this piece of data
+                - processedBy (List[Element]): Elements that store/process this piece of data
+        """
+        # A name passed positionally takes precedence over any ``name`` key
+        # that is also present in ``data``.
+        fields = data if name is None else {**data, "name": name}
+        super().__init__(**fields)
+
+        # Add the new object to TM._data.
+        self._register_with_tm()
+
+    def _register_with_tm(self):
+        """Append this object to ``TM._data`` (best effort)."""
+        try:
+            from .tm import TM
+
+            TM._data.append(self)
+        except ImportError:
+            # TM might not be available yet during initial setup
+            pass
+
+    def __repr__(self):
+        location = hex(id(self))
+        kind = f"{self.__module__}.{type(self).__name__}"
+        return f"<{kind}({self.name}) at {location}>"
+
+    def __str__(self):
+        return f"Data({self.name})"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Data):
+            return NotImplemented
+
+        # Equality looks at exactly the four fields that feed __hash__, so
+        # equal objects always hash alike.
+        mine = (self.name, self.description, self.format, self.classification)
+        theirs = (other.name, other.description, other.format, other.classification)
+        return mine == theirs
+
+    def __hash__(self):
+        """Hash on name, description, format and classification (the __eq__ fields)."""
+        return hash((self.name, self.description, self.format, self.classification))
+
+    def _safeset(self, attr: str, value) -> None:
+        """Assign ``value`` to ``attr``, silently ignoring ValueError and TypeError."""
+        try:
+            setattr(self, attr, value)
+        except (ValueError, TypeError):
+            pass
diff --git a/pytm/dataflow.py b/pytm/dataflow.py
new file mode 100644
index 00000000..3d7315a3
--- /dev/null
+++ b/pytm/dataflow.py
@@ -0,0 +1,218 @@
+"""Dataflow model - represents data flows between elements."""
+
+from typing import Optional
+from pydantic import Field, field_validator, model_validator
+
+from .element import Element, sev_to_color
+from .enums import Classification, TLSVersion
+from .base import DataSet
+
+
+class Dataflow(Element):
+ """A data flow from a source to a sink."""
+
+ source: Element = Field(description="Source element of the data flow")
+ sink: Element = Field(description="Sink element of the data flow")
+ isResponse: bool = Field(
+ default=False, description="Is a response to another data flow"
+ )
+ response: Optional["Dataflow"] = Field(
+ default=None, description="Another data flow that is a response to this one"
+ )
+ responseTo: Optional["Dataflow"] = Field(
+ default=None, description="Is a response to this data flow"
+ )
+ srcPort: int = Field(default=-1, description="Source TCP port")
+ dstPort: int = Field(default=-1, description="Destination TCP port")
+ tlsVersion: TLSVersion = Field(
+ default=TLSVersion.NONE, description="TLS version used"
+ )
+ protocol: str = Field(default="", description="Protocol used in this data flow")
+ data: DataSet = Field(
+ default_factory=DataSet,
+ description="pytm.Data object(s) in incoming data flows",
+ )
+ order: int = Field(
+ default=-1, description="Number of this data flow in the threat model"
+ )
+ implementsCommunicationProtocol: bool = Field(
+ default=False, description="Does this flow implement a communication protocol"
+ )
+ note: str = Field(default="", description="Note about this data flow")
+ usesVPN: bool = Field(default=False, description="Does this flow use VPN")
+ usesSessionTokens: bool = Field(
+ default=False, description="Does this flow use session tokens"
+ )
+
+    @field_validator("data", mode="before")
+    @classmethod
+    def validate_data(cls, v):
+        """Normalize whatever was assigned to ``data`` into a DataSet.
+
+        Accepted inputs:
+          - DataSet: returned unchanged.
+          - str: legacy form; wrapped in a single Data named "undefined" whose
+            description is the string and whose classification is UNKNOWN.
+          - Data: wrapped in a one-element DataSet.
+          - any other iterable (except bytes): its items become the DataSet.
+          - anything else: wrapped in a one-element DataSet.
+        """
+        from .data import Data
+
+        # Test for DataSet before the generic iterable branch. A DataSet is
+        # itself iterable, so checking it afterwards could never match and the
+        # set would be rebuilt instead of passed through.
+        if isinstance(v, DataSet):
+            return v
+
+        if isinstance(v, str):
+            # Handle legacy string assignment
+            return DataSet(
+                [
+                    Data(
+                        name="undefined",
+                        description=v,
+                        classification=Classification.UNKNOWN,
+                    )
+                ]
+            )
+
+        if isinstance(v, Data):
+            # Single Data object
+            return DataSet([v])
+
+        if hasattr(v, "__iter__") and not isinstance(v, (str, bytes)):
+            # Iterable of Data objects
+            return DataSet(v)
+
+        return DataSet([v])
+
+ def __setattr__(self, name, value):
+ """Assign an attribute and keep response/responseTo links bidirectional.
+
+ The value is always stored first through the parent ``__setattr__``.
+ Afterwards, a non-None assignment to ``responseTo`` or ``response``
+ calls the matching ``_link_response_*`` helper, unless the
+ ``_updating_relationships`` flag is set on this instance.
+ """
+ # Set the attribute first
+ super().__setattr__(name, value)
+
+ # Skip relationship setup if we're already in the process of linking
+ if getattr(self, "_updating_relationships", False):
+ return
+
+ # Set up bidirectional relationships for response/responseTo
+ if name == "responseTo" and value is not None:
+ self._link_response_to(value)
+ elif name == "response" and value is not None:
+ self._link_response_from(value)
+
+ def _link_response_to(self, target: "Dataflow") -> None:
+ """Link this dataflow as a response to the target dataflow.
+
+ Sets ``isResponse`` on this flow and, only when the target has no
+ ``response`` yet, points ``target.response`` at this flow. All writes go
+ through ``object.__setattr__``, so they do not re-enter
+ ``Dataflow.__setattr__``.
+ """
+ object.__setattr__(self, "_updating_relationships", True)
+ try:
+ # Mark this as a response
+ if not self.isResponse:
+ object.__setattr__(self, "isResponse", True)
+ # Set reverse link on target
+ if target.response is None:
+ object.__setattr__(target, "response", self)
+ finally:
+ # Clear the guard flag even if linking raised.
+ object.__setattr__(self, "_updating_relationships", False)
+
+ def _link_response_from(self, source: "Dataflow") -> None:
+ """Link the source dataflow as a response to this dataflow.
+
+ Sets ``isResponse`` on ``source`` and, only when ``source.responseTo`` is
+ still None, points it at this flow. All writes go through
+ ``object.__setattr__``, so they do not re-enter ``Dataflow.__setattr__``.
+ """
+ object.__setattr__(self, "_updating_relationships", True)
+ try:
+ # Mark source as a response
+ if not source.isResponse:
+ object.__setattr__(source, "isResponse", True)
+ # Set reverse link on source
+ if source.responseTo is None:
+ object.__setattr__(source, "responseTo", self)
+ finally:
+ # Clear the guard flag even if linking raised.
+ object.__setattr__(self, "_updating_relationships", False)
+
+ @model_validator(mode="after")
+ def setup_response_relationships(self) -> "Dataflow":
+ """Set up bidirectional response relationships after validation.
+
+ Mirrors what ``_link_response_to`` / ``_link_response_from`` do on
+ assignment: an existing reverse link on the other flow is never
+ overwritten, only filled in when it is None.
+
+ Returns:
+ This instance.
+ """
+ # Set up bidirectional response relationship if responseTo is set
+ if self.responseTo is not None:
+ if not self.isResponse:
+ object.__setattr__(self, "isResponse", True)
+ if self.responseTo.response is None:
+ object.__setattr__(self.responseTo, "response", self)
+
+ # Handle reverse relationship
+ if self.response is not None:
+ if not self.response.isResponse:
+ object.__setattr__(self.response, "isResponse", True)
+ if self.response.responseTo is None:
+ object.__setattr__(self.response, "responseTo", self)
+
+ return self
+
+    def __init__(self, source: Element, sink: Element, name: str, **data):
+        """Create a Dataflow between two elements and register it with TM.
+
+        Args:
+            source: Source element of the data flow.
+            sink: Sink element of the data flow.
+            name: Name of this data flow.
+            **data: Additional field values (e.g. protocol, tlsVersion, srcPort).
+                Any ``source``, ``sink`` or ``name`` key in here is replaced by
+                the positional argument of the same name.
+        """
+        # Fold the positional arguments into the keyword payload.
+        data.update(source=source, sink=sink, name=name)
+        super().__init__(**data)
+        # Register with TM flows
+        self._register_with_tm_flows()
+
+ def _register_with_tm_flows(self):
+ """Append this dataflow to the class-level ``TM._flows`` list.
+
+ If ``.tm`` cannot be imported, registration is skipped silently.
+ """
+ try:
+ # Imported inside the method rather than at module import time.
+ from .tm import TM
+
+ TM._flows.append(self)
+ except ImportError:
+ pass
+
+    def display_name(self) -> str:
+        """Return the flow name, prefixed with ``(<order>) `` unless order is -1."""
+        unnumbered = self.order == -1
+        return self.name if unnumbered else f"({self.order}) {self.name}"
+
+ def _dfd_template(self) -> str:
+ """Return the edge template that ``dfd`` fills in with ``str.format``.
+
+ Placeholders: ``source``, ``sink``, ``color``, ``direction``, ``label``.
+ """
+ return """{source} -> {sink} [
+ color = {color};
+ fontcolor = {color};
+ dir = {direction};
+ label = "{label}";
+]
+"""
+
+ def dfd(self, mergeResponses: bool = False, **kwargs) -> str:
+ """Generate DFD representation of this dataflow.
+
+ Marks the flow as drawn (``is_drawn = True``) before any filtering.
+
+ Args:
+ mergeResponses: When True and this flow has a ``response``, the edge
+ gets ``dir = both`` and the response's label is appended.
+ **kwargs: ``levels`` and ``colormap`` are read here.
+
+ Returns:
+ The filled-in ``_dfd_template`` text, or "" when ``levels`` is
+ non-empty, shares no level with this flow, and does not share a
+ level with both the source and the sink.
+ """
+ self.is_drawn = True
+
+ levels = kwargs.get("levels", None)
+ if (
+ levels
+ and not levels & self.levels
+ and not (levels & self.source.levels and levels & self.sink.levels)
+ ):
+ return ""
+
+ color = self._color()
+
+ # With colormap enabled, the color is derived from severity instead.
+ if kwargs.get("colormap", False):
+ color = sev_to_color(self.severity)
+
+ direction = "forward"
+ label = self._label()
+ if mergeResponses and self.response is not None:
+ direction = "both"
+ label += "\n" + self.response._label()
+
+ return self._dfd_template().format(
+ source=self.source._uniq_name(),
+ sink=self.sink._uniq_name(),
+ direction=direction,
+ label=label,
+ color=color,
+ )
+
+    def hasDataLeaks(self) -> bool:
+        """Report whether any carried data exceeds an allowed classification.
+
+        A leak exists when at least one item in ``data`` has a classification
+        greater than the ``maxClassification`` of the source, the sink, or this
+        flow itself.
+        """
+        for item in self.data:
+            level = item.classification
+            if level > self.source.maxClassification:
+                return True
+            if level > self.sink.maxClassification:
+                return True
+            if level > self.maxClassification:
+                return True
+        return False
diff --git a/pytm/datastore.py b/pytm/datastore.py
new file mode 100644
index 00000000..0493406a
--- /dev/null
+++ b/pytm/datastore.py
@@ -0,0 +1,126 @@
+"""Datastore model - represents data storage elements in the threat model."""
+
+import os
+from typing import TYPE_CHECKING
+from pydantic import Field
+
+from .asset import Asset
+from .enums import DatastoreType
+
+if TYPE_CHECKING:
+ pass
+
+
+class Datastore(Asset):
+ """An entity storing data.
+
+ Attributes:
+ port (int): Default TCP port for incoming data flows
+ protocol (str): Default network protocol for incoming data flows
+ data (DataSet): pytm.Data object(s) in incoming data flows
+ inputs (List[Dataflow]): Incoming Dataflows
+ outputs (List[Dataflow]): Outgoing Dataflows
+ onAWS (bool): Is this asset on AWS?
+ handlesResources (bool): Does this asset handle resources?
+ usesEnvironmentVariables (bool): Does this asset use environment variables?
+ OS (str): Operating system
+ onRDS (bool): Is this datastore on RDS?
+ storesLogData (bool): Does this datastore store log data?
+ storesPII (bool): Personally Identifiable Information is any information relating to an identifiable person
+ storesSensitiveData (bool): Does this datastore store sensitive data?
+ isSQL (bool): Is this a SQL datastore?
+ isShared (bool): Is this datastore shared?
+ hasWriteAccess (bool): Does this datastore have write access?
+ type (DatastoreType): The type of Datastore
+ """
+
+ onRDS: bool = Field(default=False, description="Is this datastore on RDS?")
+ storesLogData: bool = Field(
+ default=False, description="Does this datastore store log data?"
+ )
+ storesPII: bool = Field(
+ default=False,
+ description="Personally Identifiable Information is any information relating to an identifiable person",
+ )
+ storesSensitiveData: bool = Field(
+ default=False, description="Does this datastore store sensitive data?"
+ )
+ isSQL: bool = Field(default=True, description="Is this a SQL datastore?")
+ isShared: bool = Field(default=False, description="Is this datastore shared?")
+ hasWriteAccess: bool = Field(
+ default=False, description="Does this datastore have write access?"
+ )
+ type: DatastoreType = Field(
+ default=DatastoreType.UNKNOWN, description="The type of Datastore"
+ )
+
+ def __init__(self, name: str = None, **data):
+ """Initialize a Datastore.
+
+ The call is forwarded unchanged to the parent initializer.
+
+ Args:
+ name (str, optional): Name of the datastore; passed positionally to the parent initializer.
+ **data: Optional datastore properties:
+ - port (int): Default TCP port for incoming data flows
+ - protocol (str): Default network protocol for incoming data flows
+ - data (DataSet): pytm.Data object(s) in incoming data flows
+ - inputs (List[Dataflow]): Incoming Dataflows
+ - outputs (List[Dataflow]): Outgoing Dataflows
+ - onAWS (bool): Is this asset on AWS?
+ - handlesResources (bool): Does this asset handle resources?
+ - usesEnvironmentVariables (bool): Does this asset use environment variables?
+ - OS (str): Operating system
+ - onRDS (bool): Is this datastore on RDS?
+ - storesLogData (bool): Does this datastore store log data?
+ - storesPII (bool): Personally Identifiable Information is any information relating to an identifiable person
+ - storesSensitiveData (bool): Does this datastore store sensitive data?
+ - isSQL (bool): Is this a SQL datastore?
+ - isShared (bool): Is this datastore shared?
+ - hasWriteAccess (bool): Does this datastore have write access?
+ - type (DatastoreType): The type of Datastore
+ """
+ super().__init__(name, **data)
+
+ def _dfd_template(self) -> str:
+ """Return the node template that ``dfd`` fills in with ``str.format``.
+
+ Placeholders: ``uniq_name``, ``shape``, ``image``, ``color``, ``label``.
+ The element's label goes into ``xlabel``; ``label`` itself is left empty.
+ """
+ return """{uniq_name} [
+ shape = {shape};
+ fixedsize = shape;
+ image = "{image}";
+ imagescale = true;
+ color = {color};
+ fontcolor = black;
+ xlabel = "{label}";
+ label = "";
+]
+"""
+
+ def _shape(self) -> str:
+ """Return the shape name substituted into the DFD template ("none")."""
+ return "none"
+
+ def dfd(self, **kwargs) -> str:
+ """Generate DFD representation of this element.
+
+ Marks the element as drawn (``is_drawn = True``) before any filtering.
+
+ Args:
+ **kwargs: ``levels`` and ``colormap`` are read here.
+
+ Returns:
+ The filled-in ``_dfd_template`` text, or "" when ``levels`` is
+ non-empty and shares no level with this element.
+ """
+ from .element import sev_to_color
+
+ self.is_drawn = True
+
+ levels = kwargs.get("levels", None)
+ if levels and not levels & self.levels:
+ return ""
+
+ color = self._color()
+ color_file = "black"
+
+ if kwargs.get("colormap", False):
+ color = sev_to_color(self.severity)
+ # Only the text before the first ";" is used in the image file name.
+ color_file = color.split(";")[0]
+
+ return self._dfd_template().format(
+ uniq_name=self._uniq_name(),
+ label=self._label(),
+ color=color,
+ shape=self._shape(),
+ # Image path: <this module's directory>/images/datastore_<color_file>.png
+ image=os.path.join(
+ os.path.dirname(__file__), "images", f"datastore_{color_file}.png"
+ ),
+ )
diff --git a/pytm/element.py b/pytm/element.py
new file mode 100644
index 00000000..6212fdb9
--- /dev/null
+++ b/pytm/element.py
@@ -0,0 +1,336 @@
+"""Element model - base class for all threat model elements."""
+
+import inspect
+import random
+import uuid as uuid_module
+from hashlib import sha224
+from textwrap import wrap
+from typing import Any, List, Optional, Set, TYPE_CHECKING
+
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+from .base import Assumption, Controls
+from .enums import Classification, TLSVersion
+
+if TYPE_CHECKING:
+ from .boundary import Boundary
+ from .dataflow import Dataflow
+ from .finding import Finding
+
+
+def sev_to_color(sev: int) -> str:
+    """Map a severity to a Graphviz color declaration.
+
+    5 gives firebrick3, 2 through 4 give gold, and 0 <= sev < 2 gives
+    darkgreen, each with a translucent fill. Any other value gives "black".
+    """
+    if sev == 5:
+        declaration = 'firebrick3; fillcolor="#b2222222"; style=filled '
+    elif 2 <= sev <= 4:
+        declaration = 'gold; fillcolor="#ffd80022"; style=filled'
+    elif 0 <= sev < 2:
+        declaration = 'darkgreen; fillcolor="#00630022"; style=filled'
+    else:
+        declaration = "black"
+    return declaration
+
+
+class Element(BaseModel):
+ """A generic element in the threat model.
+
+ Attributes:
+ name (str): Name of the element
+ description (str): Description of the element
+ inBoundary (Boundary): Trust boundary this element exists in
+ inScope (bool): Is the element in scope of the threat model?
+ maxClassification (Classification): Maximum data classification this element can handle
+ minTLSVersion (TLSVersion): Minimum TLS version required
+ findings (List[Finding]): Threats that apply to this element
+ overrides (List[Finding]): Overrides to findings, allowing to set a custom response, CVSS score or override other attributes
+ assumptions (List[Assumption]): Assumptions about the element. These optionally allow to exclude threats with the given SIDs
+ levels (Set[int]): List of levels (0, 1, 2, ...) to be drawn in the model
+ sourceFiles (List[str]): Location of the source code that describes this element relative to the directory of the model script
+ controls (Controls): Security controls for this element
+ severity (int): Severity level of threats affecting this element
+ """
+
+ model_config = ConfigDict(
+ extra="allow",
+ validate_assignment=True,
+ arbitrary_types_allowed=True,
+ use_attribute_docstrings=True,
+ )
+
+ name: str = Field(description="Name of the element")
+ description: str = Field(default="", description="Description of the element")
+ inBoundary: Optional["Boundary"] = Field(
+ default=None,
+ description="Trust boundary this element exists in",
+ )
+ inScope: bool = Field(
+ default=True, description="Is the element in scope of the threat model"
+ )
+ maxClassification: Classification = Field(
+ default=Classification.UNKNOWN,
+ description="Maximum data classification this element can handle",
+ )
+ minTLSVersion: TLSVersion = Field(
+ default=TLSVersion.NONE,
+ description="Minimum TLS version required",
+ )
+ findings: List["Finding"] = Field(
+ default_factory=list,
+ description="Threats that apply to this element",
+ )
+ overrides: List["Finding"] = Field(
+ default_factory=list,
+ description="Overrides to findings, allowing to set a custom response, CVSS score or override other attributes",
+ )
+ assumptions: List[Assumption] = Field(
+ default_factory=list,
+ description="Assumptions about the element. These optionally allow to exclude threats with the given SIDs",
+ )
+ levels: Set[int] = Field(
+ default_factory=lambda: {0},
+ description="List of levels (0, 1, 2, ...) to be drawn in the model",
+ )
+ sourceFiles: List[str] = Field(
+ default_factory=list,
+ description="Location of the source code that describes this element relative to the directory of the model script",
+ )
+ controls: Controls = Field(
+ default_factory=Controls, description="Security controls for this element"
+ )
+ severity: int = Field(
+ default=0, description="Severity level of threats affecting this element"
+ )
+
+ # Internal attributes
+ uuid: uuid_module.UUID = Field(
+ default_factory=lambda: uuid_module.UUID(int=random.getrandbits(128))
+ )
+ is_drawn: bool = Field(default=False, exclude=True)
+
+ _WRITE_ONCE_FIELDS = {"name"}
+
+    @field_validator("levels", mode="before")
+    @classmethod
+    def _coerce_levels(cls, value):
+        """Turn the raw ``levels`` input into a set.
+
+        None becomes ``{0}``; any iterable other than str/bytes (sets and
+        frozensets included) is copied into a new set; every other value is
+        wrapped in a one-element set.
+        """
+        if value is None:
+            return {0}
+        is_text = isinstance(value, (str, bytes))
+        if hasattr(value, "__iter__") and not is_text:
+            return set(value)
+        return {value}
+
+ def __setattr__(
+ self, key: str, value: Any
+ ) -> None: # noqa: D401 - keep same behaviour
+ """Assign an attribute, refusing to overwrite write-once fields.
+
+ Raises:
+ ValueError: If ``key`` is listed in ``_WRITE_ONCE_FIELDS`` and is
+ already present in the instance ``__dict__``.
+ """
+ if (
+ key in self._WRITE_ONCE_FIELDS
+ and key in self.__dict__
+ and not key.startswith("_")
+ ):
+ raise ValueError(f"cannot overwrite {type(self).__name__}.{key} value")
+ super().__setattr__(key, value)
+
+ def __init__(self, name: Optional[str] = None, **data: Any):
+ """Initialize an Element.
+
+ After the parent initializer returns, the instance registers itself via
+ ``_register_with_tm``.
+
+ Args:
+ name (str, optional): Name of the element. When not None it replaces any ``name`` key in ``data``.
+ **data: Optional element properties:
+ - description (str): Description of the element
+ - inBoundary (Boundary): Trust boundary this element exists in
+ - inScope (bool): Is the element in scope of the threat model?
+ - maxClassification (Classification): Maximum data classification this element can handle
+ - minTLSVersion (TLSVersion): Minimum TLS version required
+ - findings (List[Finding]): Threats that apply to this element
+ - overrides (List[Finding]): Overrides to findings, allowing to set a custom response, CVSS score or override other attributes
+ - assumptions (List[Assumption]): Assumptions about the element. These optionally allow to exclude threats with the given SIDs
+ - levels (Set[int]): List of levels (0, 1, 2, ...) to be drawn in the model
+ - sourceFiles (List[str]): Location of the source code that describes this element relative to the directory of the model script
+ - controls (Controls): Security controls for this element
+ - severity (int): Severity level of threats affecting this element
+ """
+ if name is not None:
+ data["name"] = name
+ super().__init__(**data)
+ self._register_with_tm()
+
+ def _register_with_tm(self) -> None:
+ """Append this element to the class-level ``TM._elements`` list.
+
+ If ``.tm`` cannot be imported, registration is skipped silently.
+ """
+ try:
+ # Imported inside the method rather than at module import time.
+ from .tm import TM
+
+ TM._elements.append(self)
+ except ImportError:
+ # TM might not be available yet during initial setup
+ pass
+
+    def __repr__(self) -> str:
+        """Return a debug string with module, class name, element name and address."""
+        class_name = type(self).__name__
+        address = hex(id(self))
+        return f"<{self.__module__}.{class_name}({self.name}) at {address}>"
+
+ def __str__(self) -> str:
+ """Return ``<ClassName>(<name>)`` using the concrete class of the instance."""
+ return f"{type(self).__name__}({self.name})"
+
+    def __hash__(self) -> int:
+        """Hash on class name, element name and uuid so elements fit in sets and dict keys."""
+        identity = (type(self).__name__, self.name, str(self.uuid))
+        return hash(identity)
+
+    def _uniq_name(self) -> str:
+        """Build an identifier from the class name, the name's letters and a uuid digest.
+
+        The result is ``<lowercased class>_<alphabetic chars of name>_<first 10
+        hex chars of sha224(uuid)>``.
+        """
+        kind = type(self).__name__.lower()
+        letters = "".join(filter(str.isalpha, self.name))
+        uuid_digest = sha224(str(self.uuid).encode("utf-8")).hexdigest()
+        return f"{kind}_{letters}_{uuid_digest[:10]}"
+
+ def check(self) -> bool:
+ """Check if the element is valid; this base implementation always returns True."""
+ return True
+
+ def _dfd_template(self) -> str:
+ """Return the node template that ``dfd`` fills in with ``str.format``.
+
+ Placeholders: ``uniq_name``, ``shape``, ``color``, ``label``.
+ """
+ return """{uniq_name} [
+ shape = {shape};
+ color = {color};
+ fontcolor = black;
+ label = "{label}";
+ margin = 0.02;
+]
+"""
+
+    def dfd(self, **kwargs: Any) -> str:
+        """Render this element as a DFD node.
+
+        The element is flagged as drawn first. If ``levels`` is given,
+        non-empty, and has no level in common with this element, an empty
+        string is returned. With ``colormap`` set, the color comes from
+        ``sev_to_color(self.severity)`` instead of ``_color()``.
+        """
+        self.is_drawn = True
+
+        wanted_levels = kwargs.get("levels")
+        if wanted_levels and not (wanted_levels & self.levels):
+            return ""
+
+        node_color = self._color()
+        use_colormap = kwargs.get("colormap", False)
+        if use_colormap:
+            node_color = sev_to_color(self.severity)
+
+        template = self._dfd_template()
+        return template.format(
+            uniq_name=self._uniq_name(),
+            label=self._label(),
+            color=node_color,
+            shape=self._shape(),
+        )
+
+ def _color(self) -> str:
+ """Get the color for this element; this base implementation returns "black"."""
+ return "black"
+
+    def oneOf(self, *elements: Any) -> bool:
+        """Tell whether this element matches any of the given candidates.
+
+        A class candidate matches by ``isinstance``; any other candidate
+        matches only if it is this very object.
+        """
+        return any(
+            isinstance(self, candidate) if inspect.isclass(candidate) else self is candidate
+            for candidate in elements
+        )
+
+ def crosses(self, *boundaries: Any) -> bool:
+ """Return True if the flow crosses any of the provided boundaries.
+
+ Only meaningful for elements that have ``source`` and ``sink``
+ attributes. A flow whose source and sink share the same ``inBoundary``
+ object never crosses. Each argument may be a boundary class or a
+ boundary instance.
+ """
+ if hasattr(self, "source") and hasattr(self, "sink"):
+ if self.source.inBoundary is self.sink.inBoundary:
+ return False
+ for boundary in boundaries:
+ if inspect.isclass(boundary):
+ if (
+ (
+ isinstance(self.source.inBoundary, boundary)
+ and not isinstance(self.sink.inBoundary, boundary)
+ )
+ or (
+ not isinstance(self.source.inBoundary, boundary)
+ and isinstance(self.sink.inBoundary, boundary)
+ )
+ # NOTE(review): if the early return above always runs first, this
+ # clause looks always true, so any class argument would match.
+ # Confirm this is intended.
+ or self.source.inBoundary is not self.sink.inBoundary
+ ):
+ return True
+ # Instance argument: exactly one endpoint is inside the boundary.
+ elif (
+ self.source.inside(boundary) and not self.sink.inside(boundary)
+ ) or (not self.source.inside(boundary) and self.sink.inside(boundary)):
+ return True
+ return False
+
+    def enters(self, *boundaries: Any) -> bool:
+        """Tell whether the flow goes from outside any boundary into one of those given.
+
+        True only when the source has no ``inBoundary`` and the sink is inside
+        one of ``boundaries``. Elements without ``source``/``sink`` give False.
+        """
+        is_flow = hasattr(self, "source") and hasattr(self, "sink")
+        if not is_flow:
+            return False
+        return self.source.inBoundary is None and self.sink.inside(*boundaries)
+
+    def exits(self, *boundaries: Any) -> bool:
+        """Tell whether the flow leaves one of the given boundaries for no boundary.
+
+        True only when the source is inside one of ``boundaries`` and the sink
+        has no ``inBoundary``. Elements without ``source``/``sink`` give False.
+        """
+        is_flow = hasattr(self, "source") and hasattr(self, "sink")
+        if not is_flow:
+            return False
+        return self.source.inside(*boundaries) and self.sink.inBoundary is None
+
+    def inside(self, *boundaries: Any) -> bool:
+        """Tell whether this element's ``inBoundary`` matches any given boundary.
+
+        A class argument matches by ``isinstance`` on ``inBoundary``; any other
+        argument matches only if it is the same object as ``inBoundary``.
+        """
+        return any(
+            isinstance(self.inBoundary, candidate)
+            if inspect.isclass(candidate)
+            else self.inBoundary is candidate
+            for candidate in boundaries
+        )
+
+ def display_name(self) -> str:
+ """Get display name for this element; the base implementation returns ``name`` unchanged."""
+ return self.name
+
+ def _label(self) -> str:
+ """Get label for DFD representation.
+
+ The display name is wrapped at 18 characters and the lines are joined
+ with a literal backslash followed by ``n`` (two characters), not a
+ newline character.
+ """
+ return "\\n".join(wrap(self.display_name(), 18))
+
+ def _shape(self) -> str:
+ """Return the shape name substituted into the DFD template ("square")."""
+ return "square"
+
+ def _safeset(self, attr: str, value: Any) -> None:
+ """Set ``attr`` to ``value``, ignoring a ValueError or TypeError raised by the assignment.
+
+ Best-effort setter: this also swallows the ValueError that
+ ``__setattr__`` raises for write-once fields.
+ """
+ try:
+ setattr(self, attr, value)
+ except (ValueError, TypeError):
+ pass
+
+ def _attr_values(self) -> dict:
+ """Return the model's attribute values as a dictionary (the result of ``model_dump()``)."""
+ return self.model_dump()
+
+    def checkTLSVersion(self, flows: List["Dataflow"]) -> bool:
+        """Report whether any flow uses a TLS version below this element's minimum."""
+        for flow in flows:
+            if flow.tlsVersion < self.minTLSVersion:
+                return True
+        return False
+
+    def _set_severity(self, sev: Any) -> None:
+        """Update ``severity`` from a numeric or textual value.
+
+        Args:
+            sev: An int replaces the current severity (negative values are
+                clamped to 0). A str is treated as a severity label, ignoring
+                case and surrounding whitespace; a known label only ever raises
+                the current severity, never lowers it. Unknown labels and
+                values of any other type are ignored.
+        """
+        if isinstance(sev, int):
+            self.severity = max(0, sev)
+            return
+
+        if not isinstance(sev, str):
+            return
+
+        # The former ``legacy_mapping`` fallback was dropped: every one of its
+        # keys (critical/high/medium/low) is also a key here, so the fallback
+        # could never be consulted.
+        label_to_level = {
+            "very high": 5,
+            "critical": 5,
+            "high": 4,
+            "medium": 3,
+            "low": 2,
+            "very low": 1,
+            "info": 0,
+        }
+        level = label_to_level.get(sev.strip().lower())
+        if level is not None and level > self.severity:
+            self.severity = level
diff --git a/pytm/enums.py b/pytm/enums.py
new file mode 100644
index 00000000..29c08403
--- /dev/null
+++ b/pytm/enums.py
@@ -0,0 +1,94 @@
+"""Enums used throughout the pytm package."""
+
+from enum import Enum
+
+
+class Action(Enum):
+ """Action taken when validating a threat model."""
+
+ # Each member's value is its own name as a string.
+ NO_ACTION = "NO_ACTION"
+ RESTRICT = "RESTRICT"
+ IGNORE = "IGNORE"
+
+
+class OrderedEnum(Enum):
+    """Enum base whose members can be ordered by their ``value``.
+
+    Ordering is defined only between members of the exact same enum class;
+    every other comparison returns NotImplemented.
+    """
+
+    def __ge__(self, other):
+        if self.__class__ is not other.__class__:
+            return NotImplemented
+        return self.value >= other.value
+
+    def __gt__(self, other):
+        if self.__class__ is not other.__class__:
+            return NotImplemented
+        return self.value > other.value
+
+    def __le__(self, other):
+        if self.__class__ is not other.__class__:
+            return NotImplemented
+        return self.value <= other.value
+
+    def __lt__(self, other):
+        if self.__class__ is not other.__class__:
+            return NotImplemented
+        return self.value < other.value
+
+
+class Classification(OrderedEnum):
+ """Data classification levels.
+
+ Members compare by their integer value (see OrderedEnum), so UNKNOWN is
+ the lowest level and TOP_SECRET the highest.
+ """
+
+ UNKNOWN = 0
+ PUBLIC = 1
+ RESTRICTED = 2
+ SENSITIVE = 3
+ SECRET = 4
+ TOP_SECRET = 5
+
+
+class Lifetime(Enum):
+    """Credential lifetime categories."""
+
+    NONE = "NONE"  # not applicable
+    UNKNOWN = "UNKNOWN"  # unknown lifetime
+    SHORT = "SHORT_LIVED"  # relatively short expiration date (time to live)
+    LONG = "LONG_LIVED"  # long or no expiration date
+    # no expiration date but revoked/invalidated automatically in some conditions
+    AUTO = "AUTO_REVOKABLE"
+    MANUAL = "MANUALLY_REVOKABLE"  # no expiration date but can be invalidated manually
+    HARDCODED = "HARDCODED"  # cannot be invalidated at all
+
+    def label(self):
+        """Return the value lowercased, with underscores turned into spaces."""
+        words = self.value.lower().split("_")
+        return " ".join(words)
+
+
+class DatastoreType(Enum):
+    """Types of datastores."""
+
+    UNKNOWN = "UNKNOWN"
+    FILE_SYSTEM = "FILE_SYSTEM"
+    SQL = "SQL"
+    LDAP = "LDAP"
+    AWS_S3 = "AWS_S3"
+
+    def label(self):
+        """Return the value lowercased, with underscores turned into spaces."""
+        words = self.value.lower().split("_")
+        return " ".join(words)
+
+
+class TLSVersion(OrderedEnum):
+ """TLS/SSL version levels.
+
+ Members compare by their integer value (see OrderedEnum), so NONE is the
+ lowest and TLSv13 the highest.
+ """
+
+ NONE = 0
+ SSLv1 = 1
+ SSLv2 = 2
+ SSLv3 = 3
+ TLSv10 = 4
+ TLSv11 = 5
+ TLSv12 = 6
+ TLSv13 = 7
diff --git a/pytm/finding.py b/pytm/finding.py
new file mode 100644
index 00000000..75d201ff
--- /dev/null
+++ b/pytm/finding.py
@@ -0,0 +1,158 @@
+"""Finding model - represents a finding linking an element to a threat."""
+
+from typing import Optional, TYPE_CHECKING
+from pydantic import BaseModel, Field, ConfigDict
+
+from .base import Assumption
+
+if TYPE_CHECKING:
+ from .element import Element
+
+
+class Finding(BaseModel):
+ """Represents a Finding - the element in question and a description of the finding.
+
+ Attributes:
+ element (Element): Element this finding applies to
+ target (str): Name of the element this finding applies to
+ description (str): Threat description
+ details (str): Threat details
+ severity (str): Threat severity
+ mitigations (str): Threat mitigations
+ example (str): Threat example
+ id (str): Finding ID
+ threat_id (str): Threat ID
+ references (str): Threat references
+ condition (str): Threat condition
+ assumption (Assumption): The assumption that caused this finding to be excluded
+ response (str): Describes how this threat matching this particular asset or dataflow is being handled. Can be one of: mitigated, transferred, avoided, accepted
+ cvss (str): The CVSS score and/or vector
+ """
+
+ model_config = ConfigDict(
+ extra="allow", validate_assignment=True, arbitrary_types_allowed=True
+ )
+
+ element: Optional["Element"] = Field(
+ default=None, description="Element this finding applies to"
+ )
+ target: str = Field(
+ default="", description="Name of the element this finding applies to"
+ )
+ description: str = Field(description="Threat description")
+ details: str = Field(description="Threat details")
+ severity: str = Field(description="Threat severity")
+ mitigations: str = Field(description="Threat mitigations")
+ example: str = Field(description="Threat example")
+ id: str = Field(description="Finding ID")
+ threat_id: str = Field(description="Threat ID")
+ references: str = Field(description="Threat references")
+ condition: str = Field(description="Threat condition")
+ assumption: Optional[Assumption] = Field(
+ default=None,
+ description="The assumption that caused this finding to be excluded",
+ )
+ response: str = Field(
+ default="",
+ description="Describes how this threat matching this particular asset or dataflow is being handled. Can be one of: mitigated, transferred, avoided, accepted",
+ )
+ cvss: str = Field(default="", description="The CVSS score and/or vector")
+
+    def __init__(self, *args, **kwargs):
+        """Initialize a Finding.
+
+        Field values are resolved in three steps:
+          1. Explicit keyword arguments are taken as given.
+          2. If ``threat`` is passed, its ``id`` becomes ``threat_id`` and its
+             descriptive attributes fill in any key not passed explicitly.
+          3. The first entry in ``element.overrides`` whose ``threat_id``
+             matches overwrites those values with its own non-empty fields.
+        Required string fields still missing afterwards default to "".
+
+        Args:
+            *args: Optionally pass the element as the first positional argument.
+            **kwargs: Finding properties:
+                - element (Element): Element this finding applies to
+                - target (str): Name of the element this finding applies to
+                - threat (Threat): Threat object to copy attributes from (description, details, severity, mitigations, example, references, condition)
+                - description (str): Threat description
+                - details (str): Threat details
+                - severity (str): Threat severity
+                - mitigations (str): Threat mitigations
+                - example (str): Threat example
+                - id (str): Finding ID
+                - threat_id (str): Threat ID
+                - references (str): Threat references
+                - condition (str): Threat condition
+                - assumption (Assumption): The assumption that caused this finding to be excluded
+                - response (str): Describes how this threat matching this particular asset or dataflow is being handled. Can be one of: mitigated, transferred, avoided, accepted
+                - cvss (str): The CVSS score and/or vector
+        """
+        # Handle positional element argument
+        if args:
+            kwargs["element"] = args[0]
+
+        element = kwargs.get("element")
+
+        # Set target from element name if element is provided
+        if element is not None and "target" not in kwargs:
+            kwargs["target"] = element.name
+
+        # Handle threat data
+        threat = kwargs.pop("threat", None)
+        if threat:
+            kwargs["threat_id"] = getattr(threat, "id", "")
+            threat_attrs = (
+                "description",
+                "details",
+                "severity",
+                "mitigations",
+                "example",
+                "references",
+                "condition",
+            )
+            for attr in threat_attrs:
+                # Don't override explicit values
+                kwargs.setdefault(attr, getattr(threat, attr, ""))
+
+        # Handle overrides from element
+        threat_id = kwargs.get("threat_id", None)
+        if hasattr(element, "overrides") and threat_id:
+            for override in element.overrides:
+                if getattr(override, "threat_id", None) != threat_id:
+                    continue
+                override_dict = (
+                    override.model_dump() if hasattr(override, "model_dump") else {}
+                )
+                for key, value in override_dict.items():
+                    if key in ("element", "target"):
+                        continue
+                    # Skip None and "" values. An override built with only a few
+                    # fields carries "" placeholders for every required field it
+                    # did not set; copying those would blank out the values
+                    # taken from the threat (and the finding id).
+                    if value is None or value == "":
+                        continue
+                    kwargs[key] = value
+                break
+
+        # Ensure all required fields have values
+        required_fields = (
+            "description",
+            "details",
+            "severity",
+            "mitigations",
+            "example",
+            "id",
+            "threat_id",
+            "references",
+            "condition",
+        )
+        for field in required_fields:
+            kwargs.setdefault(field, "")
+
+        super().__init__(**kwargs)
+
+ def _safeset(self, attr: str, value) -> None:
+ """Set ``attr`` to ``value``, ignoring a ValueError or TypeError raised by the assignment.
+
+ Best-effort setter: on those two exception types the attribute is left
+ unchanged and nothing is reported to the caller.
+ """
+ try:
+ setattr(self, attr, value)
+ except (ValueError, TypeError):
+ pass
+
+    def __repr__(self):
+        """Return a debug string with module, class name, finding id and address."""
+        class_name = type(self).__name__
+        address = hex(id(self))
+        return f"<{self.__module__}.{class_name}({self.id}) at {address}>"
+
+ def __str__(self):
+ """Return three lines: the quoted target with description, then details, then severity."""
+ return f"'{self.target}': {self.description}\n{self.details}\n{self.severity}"
diff --git a/pytm/flows.py b/pytm/flows.py
index a1e882e4..b05c0cab 100644
--- a/pytm/flows.py
+++ b/pytm/flows.py
@@ -3,7 +3,7 @@
def req_reply(src: Element, dest: Element, req_name: str, reply_name=None) -> (DF, DF):
- '''
+ """
This function creates two datflows where one dataflow is a request
and the second dataflow is the corresponding reply to the newly created request.
@@ -22,9 +22,9 @@ def req_reply(src: Element, dest: Element, req_name: str, reply_name=None) -> (D
Returns:
a tuple of two dataflows, where the first is the request and the second is the reply.
- '''
+ """
if not reply_name:
- reply_name = f'Reply to {req_name}'
+ reply_name = f"Reply to {req_name}"
req = DF(src, dest, req_name)
reply = DF(dest, src, name=reply_name)
reply.responseTo = req
@@ -32,7 +32,7 @@ def req_reply(src: Element, dest: Element, req_name: str, reply_name=None) -> (D
def reply(req: DF, **kwargs) -> DF:
- '''
+ """
This function takes a dataflow as an argument and returns a new dataflow, which is a response to the given dataflow.
Args:
@@ -45,12 +45,12 @@ def reply(req: DF, **kwargs) -> DF:
client_reply = reply(client_query)
Returns:
a Dataflow which is a reply to the given datadlow req
- '''
- if 'name' not in kwargs:
- name = f'Reply to {req.name}'
+ """
+ if "name" not in kwargs:
+ name = f"Reply to {req.name}"
else:
- name = kwargs['name']
- del kwargs['name']
+ name = kwargs["name"]
+ del kwargs["name"]
reply = DF(req.sink, req.source, name, **kwargs)
reply.responseTo = req
return req, reply
diff --git a/pytm/json.py b/pytm/json.py
index a69cd719..e9a1debf 100644
--- a/pytm/json.py
+++ b/pytm/json.py
@@ -1,22 +1,25 @@
import json
-import sys
-
-from .pytm import (
- TM,
- Boundary,
- Element,
- Dataflow,
- Server,
- ExternalEntity,
- Datastore,
- Actor,
- Process,
- SetOfProcesses,
- Action,
- Lambda,
- LLM,
- Controls,
-)
+
+from .tm import TM
+from .boundary import Boundary
+from .dataflow import Dataflow
+from .asset import Asset, Server, ExternalEntity, Lambda, LLM
+from .datastore import Datastore
+from .actor import Actor
+from .process import Process, SetOfProcesses
+from .enums import Action
+
+_ELEMENT_CLASSES = {
+ "Asset": Asset,
+ "Actor": Actor,
+ "Server": Server,
+ "ExternalEntity": ExternalEntity,
+ "Lambda": Lambda,
+ "LLM": LLM,
+ "Datastore": Datastore,
+ "Process": Process,
+ "SetOfProcesses": SetOfProcesses,
+}
def loads(s):
@@ -74,7 +77,10 @@ def decode_boundaries(flat):
def decode_elements(flat, boundaries):
elements = {}
for i, e in enumerate(flat):
- klass = getattr(sys.modules[__name__], e.pop("__class__", "Asset"))
+ class_name = e.pop("__class__", "Asset")
+ klass = _ELEMENT_CLASSES.get(class_name)
+ if klass is None:
+ raise ValueError(f"Unknown element class: {class_name}")
name = e.pop("name", None)
if name is None:
raise ValueError(f"name property missing in element {i}")
diff --git a/pytm/process.py b/pytm/process.py
new file mode 100644
index 00000000..e71cb7a3
--- /dev/null
+++ b/pytm/process.py
@@ -0,0 +1,127 @@
+"""Process model - represents processes that handle data."""
+
+from typing import TYPE_CHECKING
+from pydantic import Field
+
+from .asset import Asset
+
+if TYPE_CHECKING:
+ pass
+
+
+class Process(Asset):
+ """An entity processing data.
+
+ Attributes:
+ port (int): Default TCP port for incoming data flows
+ protocol (str): Default network protocol for incoming data flows
+ data (DataSet): pytm.Data object(s) in incoming data flows
+ inputs (List[Dataflow]): Incoming Dataflows
+ outputs (List[Dataflow]): Outgoing Dataflows
+ onAWS (bool): Is this asset on AWS?
+ handlesResources (bool): Does this asset handle resources?
+ usesEnvironmentVariables (bool): Does this asset use environment variables?
+ OS (str): Operating system
+ codeType (str): Type of code running in this process
+ implementsCommunicationProtocol (bool): Does this process implement a communication protocol?
+ tracksExecutionFlow (bool): Does this process track execution flow?
+ implementsAPI (bool): Does this process implement an API?
+ environment (str): Environment for this process
+ allowsClientSideScripting (bool): Does this process allow client-side scripting?
+ """
+
+ codeType: str = Field(
+ default="Unmanaged", description="Type of code running in this process"
+ )
+ implementsCommunicationProtocol: bool = Field(
+ default=False,
+ description="Does this process implement a communication protocol?",
+ )
+ tracksExecutionFlow: bool = Field(
+ default=False, description="Does this process track execution flow?"
+ )
+ implementsAPI: bool = Field(
+ default=False, description="Does this process implement an API?"
+ )
+ environment: str = Field(default="", description="Environment for this process")
+ allowsClientSideScripting: bool = Field(
+ default=False, description="Does this process allow client-side scripting?"
+ )
+
+ def __init__(self, name: str = None, **data):
+ """Initialize a Process.
+
+ Args:
+ name (str): Name of the process.
+ **data: Optional process properties:
+ - port (int): Default TCP port for incoming data flows
+ - protocol (str): Default network protocol for incoming data flows
+ - data (DataSet): pytm.Data object(s) in incoming data flows
+ - inputs (List[Dataflow]): Incoming Dataflows
+ - outputs (List[Dataflow]): Outgoing Dataflows
+ - onAWS (bool): Is this asset on AWS?
+ - handlesResources (bool): Does this asset handle resources?
+ - usesEnvironmentVariables (bool): Does this asset use environment variables?
+ - OS (str): Operating system
+ - codeType (str): Type of code running in this process
+ - implementsCommunicationProtocol (bool): Does this process implement a communication protocol?
+ - tracksExecutionFlow (bool): Does this process track execution flow?
+ - implementsAPI (bool): Does this process implement an API?
+ - environment (str): Environment for this process
+ - allowsClientSideScripting (bool): Does this process allow client-side scripting?
+ """
+ super().__init__(name, **data)
+
+ def _shape(self) -> str:
+ """Get shape for DFD representation."""
+ return "circle"
+
+
+class SetOfProcesses(Process):
+ """A set of processes grouped together.
+
+ Attributes:
+ port (int): Default TCP port for incoming data flows
+ protocol (str): Default network protocol for incoming data flows
+ data (DataSet): pytm.Data object(s) in incoming data flows
+ inputs (List[Dataflow]): Incoming Dataflows
+ outputs (List[Dataflow]): Outgoing Dataflows
+ onAWS (bool): Is this asset on AWS?
+ handlesResources (bool): Does this asset handle resources?
+ usesEnvironmentVariables (bool): Does this asset use environment variables?
+ OS (str): Operating system
+ codeType (str): Type of code running in this process
+ implementsCommunicationProtocol (bool): Does this process implement a communication protocol?
+ tracksExecutionFlow (bool): Does this process track execution flow?
+ implementsAPI (bool): Does this process implement an API?
+ environment (str): Environment for this process
+ allowsClientSideScripting (bool): Does this process allow client-side scripting?
+ """
+
+ def __init__(self, name: str = None, **data):
+ """Initialize a SetOfProcesses.
+
+ Args:
+ name (str): Name of the set of processes.
+ **data: Optional properties:
+ - port (int): Default TCP port for incoming data flows
+ - protocol (str): Default network protocol for incoming data flows
+ - data (DataSet): pytm.Data object(s) in incoming data flows
+ - inputs (List[Dataflow]): Incoming Dataflows
+ - outputs (List[Dataflow]): Outgoing Dataflows
+ - onAWS (bool): Is this asset on AWS?
+ - handlesResources (bool): Does this asset handle resources?
+ - usesEnvironmentVariables (bool): Does this asset use environment variables?
+ - OS (str): Operating system
+ - codeType (str): Type of code running in this process
+ - implementsCommunicationProtocol (bool): Does this process implement a communication protocol?
+ - tracksExecutionFlow (bool): Does this process track execution flow?
+ - implementsAPI (bool): Does this process implement an API?
+ - environment (str): Environment for this process
+ - allowsClientSideScripting (bool): Does this process allow client-side scripting?
+ """
+ super().__init__(name, **data)
+
+ def _shape(self) -> str:
+ """Get shape for DFD representation."""
+ return "doublecircle"
diff --git a/pytm/pytm.py b/pytm/pytm.py
index 82b56aa8..d710c8c6 100644
--- a/pytm/pytm.py
+++ b/pytm/pytm.py
@@ -1,376 +1,68 @@
import argparse
-import errno
-import inspect
-import json
-import logging
-import os
-import random
-import sys
-import uuid
import html
import copy
+import logging
+import re
+import sys
-from collections import Counter, defaultdict
-from collections.abc import Iterable
-from enum import Enum
-from functools import lru_cache, singledispatch
-from hashlib import sha224
-from itertools import combinations
-from shutil import rmtree
-from textwrap import indent, wrap
-from weakref import WeakKeyDictionary
-from datetime import datetime
-
-from .template_engine import SuperFormatter
-
-""" Helper functions """
-
-""" The base for this (descriptors instead of properties) has been
- shamelessly lifted from
- https://nbviewer.jupyter.org/urls/gist.github.com/ChrisBeaumont/5758381/raw/descriptor_writeup.ipynb
- By Chris Beaumont
-"""
-
-
-def sev_to_color(sev):
- # calculate the color depending on the severity
- if sev == 5:
- return 'firebrick3; fillcolor="#b2222222"; style=filled '
- elif sev <= 4 and sev >= 2:
- return 'gold; fillcolor="#ffd80022"; style=filled'
- elif sev < 2 and sev >= 0:
- return 'darkgreen; fillcolor="#00630022"; style=filled'
-
- return "black"
-
-
-class UIError(Exception):
- def __init__(self, e, context):
- self.error = e
- self.context = context
-
+from dataclasses import dataclass, field
+from typing import ClassVar
+
+from pydantic import ValidationError
+from pydantic.fields import PydanticUndefined
+
+from collections import defaultdict
+from collections.abc import Iterable, Mapping
+from functools import singledispatch
+
+# Import all the new Pydantic models
+from .enums import (
+ Action,
+ Classification,
+ DatastoreType,
+ Lifetime,
+ TLSVersion,
+ OrderedEnum,
+)
+from .base import Assumption, Controls
+from .element import Element
+from .data import Data
+from .threat import Threat
+from .finding import Finding
+from .asset import Asset, Lambda, LLM, Server, ExternalEntity
+from .datastore import Datastore
+from .actor import Actor
+from .process import Process, SetOfProcesses
+from .dataflow import Dataflow
+from .boundary import Boundary
+from .tm import TM, UIError
logger = logging.getLogger(__name__)
-
-class var(object):
- """A descriptor that allows setting a value only once"""
-
- def __init__(self, default, required=False, doc="", onSet=None):
- self.default = default
- self.required = required
- self.doc = doc
- self.data = WeakKeyDictionary()
- self.onSet = onSet
-
- def __get__(self, instance, owner):
- # when x.d is called we get here
- # instance = x
- # owner = type(x)
- if instance is None:
- return self
- return self.data.get(instance, self.default)
-
- def __set__(self, instance, value):
- # called when x.d = val
- # instance = x
- # value = val
- if instance in self.data:
- raise ValueError(
- "cannot overwrite {}.{} value with {}, already set to {}".format(
- instance, self.__class__.__name__, value, self.data[instance]
- )
- )
- self.data[instance] = value
- if self.onSet is not None:
- self.onSet(instance, value)
-
-
-class varString(var):
- def __set__(self, instance, value):
- if not isinstance(value, str):
- raise ValueError("expecting a String value, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varStrings(var):
- def __set__(self, instance, value):
- if not isinstance(value, Iterable) or isinstance(value, str):
- value = [value]
- for i, e in enumerate(value):
- if not isinstance(e, str):
- raise ValueError(
- f"expecting a list of str, item number {i} is a {type(e)}"
- )
- super().__set__(instance, set(value))
-
-
-class varBoundary(var):
- def __set__(self, instance, value):
- if not isinstance(value, Boundary):
- raise ValueError("expecting a Boundary value, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varBool(var):
- def __set__(self, instance, value):
- if not isinstance(value, bool):
- raise ValueError("expecting a boolean value, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varInt(var):
- def __set__(self, instance, value):
- if not isinstance(value, int):
- raise ValueError("expecting an integer value, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varInts(var):
- def __set__(self, instance, value):
- if not isinstance(value, Iterable):
- value = [value]
- for i, e in enumerate(value):
- if not isinstance(e, int):
- raise ValueError(
- f"expecting a list of int, item number {i} is a {type(e)}"
- )
- super().__set__(instance, set(value))
-
-
-class varElement(var):
- def __set__(self, instance, value):
- if not isinstance(value, Element):
- raise ValueError(
- "expecting an Element (or inherited) "
- "value, got a {}".format(type(value))
- )
- super().__set__(instance, value)
-
-
-class varElements(var):
- def __set__(self, instance, value):
- for i, e in enumerate(value):
- if not isinstance(e, Element):
- raise ValueError(
- "expecting a list of Elements, item number {} is a {}".format(
- i, type(e)
- )
- )
- super().__set__(instance, list(value))
-
-
-class varFindings(var):
- def __set__(self, instance, value):
- for i, e in enumerate(value):
- if not isinstance(e, Finding):
- raise ValueError(
- "expecting a list of Findings, item number {} is a {}".format(
- i, type(e)
- )
- )
- super().__set__(instance, list(value))
-
-
-class varAction(var):
- def __set__(self, instance, value):
- if not isinstance(value, Action):
- raise ValueError("expecting an Action, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varClassification(var):
- def __set__(self, instance, value):
- if not isinstance(value, Classification):
- raise ValueError("expecting a Classification, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varLifetime(var):
- def __set__(self, instance, value):
- if not isinstance(value, Lifetime):
- raise ValueError("expecting a Lifetime, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varDatastoreType(var):
- def __set__(self, instance, value):
- if not isinstance(value, DatastoreType):
- raise ValueError("expecting a DatastoreType, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varTLSVersion(var):
- def __set__(self, instance, value):
- if not isinstance(value, TLSVersion):
- raise ValueError("expecting a TLSVersion, got a {}".format(type(value)))
- super().__set__(instance, value)
-
-
-class varData(var):
- def __set__(self, instance, value):
- if isinstance(value, str):
- value = [
- Data(
- name="undefined",
- description=value,
- classification=Classification.UNKNOWN,
- )
- ]
- sys.stderr.write(
- "FIXME: a dataflow is using a string as the Data attribute. This has been deprecated and Data objects should be created instead.\n"
- )
-
- if not isinstance(value, Iterable):
- value = [value]
- for i, e in enumerate(value):
- if not isinstance(e, Data):
- raise ValueError(
- "expecting a list of pytm.Data, item number {} is a {}".format(
- i, type(e)
- )
- )
- super().__set__(instance, DataSet(value))
-
-
-class DataSet(set):
- def __contains__(self, item):
- if isinstance(item, str):
- return item in [d.name for d in self]
- if isinstance(item, Data):
- return super().__contains__(item)
- return NotImplemented
-
- def __eq__(self, other):
- if isinstance(other, set):
- return super().__eq__(other)
- if isinstance(other, str):
- return other in self
- return NotImplemented
-
- def __ne__(self, other):
- if isinstance(other, set):
- return super().__ne__(other)
- if isinstance(other, str):
- return other not in self
- return NotImplemented
-
- def __str__(self):
- return ", ".join(sorted(set(d.name for d in self)))
-
-
-class varControls(var):
- def __set__(self, instance, value):
- if not isinstance(value, Controls):
- raise ValueError(
- f"expecting an Controls value, got a {type(value)}"
- )
- super().__set__(instance, value)
-
-
-class varAssumptions(var):
- def __set__(self, instance, value):
- for i, e in enumerate(value):
- if isinstance(e, str):
- e = value[i] = Assumption(e)
- if not isinstance(e, Assumption):
- raise ValueError(
- f"expecting a list of Assumptions, item number {i} is a {type(e)}"
- )
- super().__set__(instance, list(value))
-
-
-class varAssumption(var):
- def __set__(self, instance, value):
- if not isinstance(value, Assumption):
- raise ValueError(
- f"expecting an Assumption value, got a {type(value)}"
- )
- super().__set__(instance, value)
-
-
-class Action(Enum):
- """Action taken when validating a threat model."""
-
- NO_ACTION = "NO_ACTION"
- RESTRICT = "RESTRICT"
- IGNORE = "IGNORE"
-
-
-class OrderedEnum(Enum):
- def __ge__(self, other):
- if self.__class__ is other.__class__:
- return self.value >= other.value
- return NotImplemented
-
- def __gt__(self, other):
- if self.__class__ is other.__class__:
- return self.value > other.value
- return NotImplemented
-
- def __le__(self, other):
- if self.__class__ is other.__class__:
- return self.value <= other.value
- return NotImplemented
-
- def __lt__(self, other):
- if self.__class__ is other.__class__:
- return self.value < other.value
- return NotImplemented
-
-
-class Classification(OrderedEnum):
- UNKNOWN = 0
- PUBLIC = 1
- RESTRICTED = 2
- SENSITIVE = 3
- SECRET = 4
- TOP_SECRET = 5
-
-
-class Lifetime(Enum):
- # not applicable
- NONE = "NONE"
- # unknown lifetime
- UNKNOWN = "UNKNOWN"
- # relatively short expiration date (time to live)
- SHORT = "SHORT_LIVED"
- # long or no expiration date
- LONG = "LONG_LIVED"
- # no expiration date but revoked/invalidated automatically in some conditions
- AUTO = "AUTO_REVOKABLE"
- # no expiration date but can be invalidated manually
- MANUAL = "MANUALLY_REVOKABLE"
- # cannot be invalidated at all
- HARDCODED = "HARDCODED"
-
- def label(self):
- return self.value.lower().replace("_", " ")
-
-
-class DatastoreType(Enum):
- UNKNOWN = "UNKNOWN"
- FILE_SYSTEM = "FILE_SYSTEM"
- SQL = "SQL"
- LDAP = "LDAP"
- AWS_S3 = "AWS_S3"
-
- def label(self):
- return self.value.lower().replace("_", " ")
-
-
-class TLSVersion(OrderedEnum):
- NONE = 0
- SSLv1 = 1
- SSLv2 = 2
- SSLv3 = 3
- TLSv10 = 4
- TLSv11 = 5
- TLSv12 = 6
- TLSv13 = 7
-
-
+# Legacy aliases for backward compatibility
+varString = str
+varStrings = list
+varBoundary = object
+varBool = bool
+varInt = int
+varInts = set
+varElement = object
+varElements = list
+varFindings = list
+varAction = Action
+varClassification = Classification
+varLifetime = Lifetime
+varDatastoreType = DatastoreType
+varTLSVersion = TLSVersion
+varData = list
+varControls = Controls
+varAssumptions = list
+varAssumption = Assumption
+
+
+# Essential helper functions preserved from original
def _sort(flows, addOrder=False):
+ """Sort flows by order."""
ordered = sorted(flows, key=lambda flow: flow.order)
if not addOrder:
return ordered
@@ -382,6 +74,7 @@ def _sort(flows, addOrder=False):
def _sort_elem(elements):
+ """Sort elements."""
if len(elements) == 0:
return elements
orders = {}
@@ -404,8 +97,148 @@ def _sort_elem(elements):
)
+def _iter_subclasses(cls):
+ """Yield all subclasses of *cls*, recursively."""
+ seen = set()
+ stack = [cls]
+
+ while stack:
+ current = stack.pop()
+ for subclass in getattr(current, "__subclasses__", lambda: [])():
+ if subclass in seen:
+ continue
+ seen.add(subclass)
+ yield subclass
+ stack.append(subclass)
+
+
+def _list_elements():
+ """List all elements usable in a threat model along with their descriptions."""
+
+ def _print_components(classes):
+ entries = sorted(classes, key=lambda cls: cls.__name__)
+ if not entries:
+ return
+
+ name_width = max(len(entry.__name__) for entry in entries)
+ for entry in entries:
+ doc = entry.__doc__ or ""
+ print(f"{entry.__name__:<{name_width}} -- {doc}")
+
+ print("Elements:")
+ _print_components(list(_iter_subclasses(Element)))
+
+ print("\nAtributes:")
+ enumerated = set(_iter_subclasses(OrderedEnum))
+ enumerated.update({Data, Action, Lifetime})
+ _print_components(list(enumerated))
+
+
+_CLASS_REGISTRY = {
+ "Action": Action,
+ "Actor": Actor,
+ "Asset": Asset,
+ "Boundary": Boundary,
+ "Classification": Classification,
+ "Data": Data,
+ "Dataflow": Dataflow,
+ "Datastore": Datastore,
+ "DatastoreType": DatastoreType,
+ "ExternalEntity": ExternalEntity,
+ "Finding": Finding,
+ "Lambda": Lambda,
+ "Lifetime": Lifetime,
+ "LLM": LLM,
+ "Process": Process,
+ "Server": Server,
+ "SetOfProcesses": SetOfProcesses,
+ "Threat": Threat,
+ "TLSVersion": TLSVersion,
+ "TM": TM,
+ "UIError": UIError,
+}
+
+
+def _describe_classes(class_names):
+ """Describe available classes and their attributes for CLI users."""
+
+ registry = dict(_CLASS_REGISTRY)
+
+ for cls in _iter_subclasses(Element):
+ registry.setdefault(cls.__name__, cls)
+
+ for name in class_names:
+ klass = registry.get(name)
+ if klass is None:
+ logger.error("No such class to describe: %s", name)
+ sys.exit(1)
+
+ print(f"{name} class attributes:")
+
+ model_fields = getattr(klass, "model_fields", None)
+ if model_fields:
+ field_names = sorted(model_fields.keys())
+ if not field_names:
+ print(" (no attributes)")
+ else:
+ longest = len(max(field_names, key=len)) + 2
+ lpadding = f'\n{" ":<{longest+2}}'
+ for field_name in field_names:
+ field_info = model_fields[field_name]
+ docs: list[str] = []
+ description = field_info.description or ""
+ if description:
+ docs.extend(description.split("\n"))
+ if field_info.is_required():
+ docs.append("required")
+ default = field_info.default
+ if default is not PydanticUndefined:
+ docs.append(f"default: {default!r}")
+ elif field_info.default_factory is not None:
+ factory = field_info.default_factory
+ factory_name = getattr(factory, "__name__", repr(factory))
+ docs.append(f"default factory: {factory_name}")
+
+ if docs:
+ print(f" {field_name:<{longest}}{lpadding.join(docs)}")
+ else:
+ print(f" {field_name}")
+ elif hasattr(klass, "__members__"):
+ members = getattr(klass, "__members__", {})
+ if not members:
+ print(" (no members)")
+ else:
+ for member in members:
+ print(f" {member}")
+ else:
+ attrs = [
+ attr
+ for attr in dir(klass)
+ if not attr.startswith("_") and not callable(getattr(klass, attr))
+ ]
+ if not attrs:
+ print(" (no attributes)")
+ else:
+ longest = len(max(attrs, key=len)) + 2
+ lpadding = f'\n{" ":<{longest+2}}'
+ for attr in sorted(attrs):
+ value = getattr(klass, attr)
+ docs = []
+ doc_attr = getattr(value, "__doc__", None)
+ if isinstance(doc_attr, str):
+ stripped = doc_attr.strip()
+ if stripped:
+ docs.append(stripped)
+ if docs:
+ print(f" {attr:<{longest}}{lpadding.join(docs)}")
+ else:
+ print(f" {attr}")
+
+ print()
+
+
def _match_responses(flows):
- """Ensure that responses are pointing to requests"""
+ """Ensure that responses are pointing to requests."""
index = defaultdict(list)
for e in flows:
key = (e.source, e.sink)
@@ -433,1618 +266,257 @@ def _match_responses(flows):
return flows
-def _apply_defaults(flows, data):
- inputs = defaultdict(list)
- outputs = defaultdict(list)
- carriers = defaultdict(set)
- processors = defaultdict(set)
-
- for d in data:
- for e in d.carriedBy:
- try:
- setattr(e, "data", d)
- except ValueError:
- e.data.add(d)
-
- for e in flows:
- if e.source.data:
- try:
- setattr(e, "data", e.source.data.copy())
- except ValueError:
- e.data.update(e.source.data)
-
- for d in e.data:
- carriers[d].add(e)
- processors[d].add(e.source)
- processors[d].add(e.sink)
-
- e._safeset("levels", e.source.levels & e.sink.levels)
+def _add_data(container, value):
+ """Attach Data objects to a container supporting add/append semantics."""
+ if container is None or value is None:
+ return
- try:
- e.overrides = e.sink.overrides
- e.overrides.extend(
- f
- for f in e.source.overrides
- if f.threat_id not in (f.threat_id for f in e.overrides)
- )
- except ValueError:
- pass
+ if isinstance(value, Data):
+ items = [value]
+ elif hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
+ items = list(value)
+ else:
+ items = [value]
- if e.isResponse:
- e._safeset("protocol", e.source.protocol)
- e._safeset("srcPort", e.source.port)
- e.controls._safeset("isEncrypted", e.source.controls.isEncrypted)
+ for item in items:
+ if item is None:
continue
+ if hasattr(container, "add"):
+ container.add(item)
+ elif hasattr(container, "append"):
+ container.append(item)
- e._safeset("protocol", e.sink.protocol)
- e._safeset("dstPort", e.sink.port)
- if hasattr(e.sink.controls, "isEncrypted"):
- e.controls._safeset("isEncrypted", e.sink.controls.isEncrypted)
- e.controls._safeset(
- "authenticatesDestination", e.source.controls.authenticatesDestination
- )
- e.controls._safeset(
- "checksDestinationRevocation", e.source.controls.checksDestinationRevocation
- )
-
- for d in e.data:
- if d.isStored:
- if hasattr(e.sink.controls, "isEncryptedAtRest"):
- for d in e.data:
- d._safeset(
- "isDestEncryptedAtRest", e.sink.controls.isEncryptedAtRest
- )
- if hasattr(e.source, "isEncryptedAtRest"):
- for d in e.data:
- d._safeset(
- "isSourceEncryptedAtRest",
- e.source.controls.isEncryptedAtRest,
- )
- if d.credentialsLife != Lifetime.NONE and not d.isCredentials:
- d._safeset("isCredentials", True)
- if d.isCredentials and d.credentialsLife == Lifetime.NONE:
- d._safeset("credentialsLife", Lifetime.UNKNOWN)
-
- outputs[e.source].append(e)
- inputs[e.sink].append(e)
-
- for e, flows in inputs.items():
- try:
- e.inputs = flows
- except (AttributeError, ValueError):
- pass
- for e, flows in outputs.items():
- try:
- e.outputs = flows
- except (AttributeError, ValueError):
- pass
-
- for d, flows in carriers.items():
- flows = sorted(flows, key=lambda f: f.name)
- try:
- setattr(d, "carriedBy", list(flows))
- except ValueError:
- for e in flows:
- if e not in d.carriedBy:
- d.carriedBy.append(e)
- for d, elements in processors.items():
- elements = sorted(elements, key=lambda e: e.name)
- try:
- setattr(d, "processedBy", elements)
- except ValueError:
- for e in elements:
- if e not in d.processedBy:
- d.processedBy.append(e)
-
-
-def _describe_classes(classes):
- for name in classes:
- klass = getattr(sys.modules[__name__], name, None)
- if klass is None:
- logger.error("No such class to describe: %s\n", name)
- sys.exit(1)
- print("{} class attributes:".format(name))
- attrs = []
- for i in dir(klass):
- if i.startswith("_") or callable(getattr(klass, i)):
- continue
- attrs.append(i)
- longest = len(max(attrs, key=len)) + 2
- for i in attrs:
- attr = getattr(klass, i, {})
- docs = []
- if isinstance(attr, var):
- if attr.doc:
- docs.extend(attr.doc.split("\n"))
- if attr.required:
- docs.append("required")
- if attr.default or isinstance(attr.default, bool):
- docs.append("default: {}".format(attr.default))
- lpadding = f'\n{" ":<{longest+2}}'
- print(f" {i:<{longest}}{lpadding.join(docs)}")
- print()
-
-
-def _list_elements():
- """List all elements which can be used in a threat model with the corresponding description"""
-
- def all_subclasses(cls):
- """Get all sub classes of a class"""
- subclasses = set(cls.__subclasses__())
- return subclasses.union((s for c in subclasses for s in all_subclasses(c)))
- def print_components(cls_list):
- elements = sorted(cls_list, key=lambda c: c.__name__)
- max_len = max((len(e.__name__) for e in elements))
- for sc in elements:
- doc = sc.__doc__ if sc.__doc__ is not None else ""
- print(f"{sc.__name__:<{max_len}} -- {doc}")
+@dataclass
+class _FlowDefaultsBuilder:
+ """Collect and apply default relationships for data flows."""
- # print all elements
- print("Elements:")
- print_components(all_subclasses(Element))
-
- # Print Attributes
- print("\nAtributes:")
- print_components(all_subclasses(OrderedEnum) | {Data, Action, Lifetime})
-
-
-def _get_elements_and_boundaries(flows):
- """filter out elements and boundaries not used in this TM"""
- elements = set()
- boundaries = set()
- for e in flows:
- elements.add(e)
- elements.add(e.source)
- elements.add(e.sink)
- if e.source.inBoundary is not None:
- elements.add(e.source.inBoundary)
- boundaries.add(e.source.inBoundary)
- for b in e.source.inBoundary.parents():
- elements.add(b)
- boundaries.add(b)
- if e.sink.inBoundary is not None:
- elements.add(e.sink.inBoundary)
- boundaries.add(e.sink.inBoundary)
- for b in e.sink.inBoundary.parents():
- elements.add(b)
- boundaries.add(b)
- return (list(elements), list(boundaries))
-
-
-""" End of help functions """
-
-
-class Threat:
- """Represents a possible threat"""
-
- id = varString("", required=True)
- description = varString("")
- condition = varString(
- "",
- doc="""a Python expression that should evaluate
-to a boolean True or False""",
- )
- details = varString("")
- likelihood = varString("")
- severity = varString("")
- mitigations = varString("")
- prerequisites = varString("")
- example = varString("")
- references = varString("")
- target = ()
-
- def __init__(self, **kwargs):
- self.id = kwargs["SID"]
- self.description = kwargs.get("description", "")
- self.likelihood = kwargs.get("Likelihood Of Attack", "")
- self.condition = kwargs.get("condition", "True")
- target = kwargs.get("target", "Element")
- if not isinstance(target, str) and isinstance(target, Iterable):
- target = tuple(target)
- else:
- target = (target,)
- self.target = tuple(getattr(sys.modules[__name__], x) for x in target)
- self.details = kwargs.get("details", "")
- self.severity = kwargs.get("severity", "")
- self.mitigations = kwargs.get("mitigations", "")
- self.prerequisites = kwargs.get("prerequisites", "")
- self.example = kwargs.get("example", "")
- self.references = kwargs.get("references", "")
-
- def _safeset(self, attr, value):
- try:
- setattr(self, attr, value)
- except ValueError:
- pass
-
- def __repr__(self):
- return "<{0}.{1}({2}) at {3}>".format(
- self.__module__, type(self).__name__, self.id, hex(id(self))
- )
-
- def __str__(self):
- return "{0}({1})".format(type(self).__name__, self.id)
-
- def apply(self, target):
- if not isinstance(target, self.target):
- return None
- return eval(self.condition)
-
-
-class Finding:
- """Represents a Finding - the element in question
- and a description of the finding"""
-
- element = varElement(None, required=True, doc="Element this finding applies to")
- target = varString("", doc="Name of the element this finding applies to")
- description = varString("", required=True, doc="Threat description")
- details = varString("", required=True, doc="Threat details")
- severity = varString("", required=True, doc="Threat severity")
- mitigations = varString("", required=True, doc="Threat mitigations")
- example = varString("", required=True, doc="Threat example")
- id = varString("", required=True, doc="Finding ID")
- threat_id = varString("", required=True, doc="Threat ID")
- references = varString("", required=True, doc="Threat references")
- condition = varString("", required=True, doc="Threat condition")
- assumption = varAssumption(None, required=False, doc="The assumption, that caused this finding to be excluded")
- response = varString(
- "",
- required=False,
- doc="""Describes how this threat matching this particular asset or dataflow is being handled.
-Can be one of:
-* mitigated - there were changes made in the modeled system to reduce the probability of this threat occurring or the impact when it does,
-* transferred - users of the system are required to mitigate this threat,
-* avoided - this asset or dataflow is removed from the system,
-* accepted - no action is taken as the probability and/or impact is very low
-""",
- )
- cvss = varString("", required=False, doc="The CVSS score and/or vector")
- likelihood = varString("", required=False, doc="Likelihood of the threat")
-
- def __init__(
- self,
- *args,
- **kwargs,
- ):
- if args:
- element = args[0]
- else:
- element = kwargs.pop("element", Element("invalid"))
-
- self.target = element.name
- self.element = element
- attrs = [
- "description",
- "details",
- "severity",
- "mitigations",
- "example",
- "references",
- "condition",
- "likelihood",
- ]
- threat = kwargs.pop("threat", None)
- if threat:
- kwargs["threat_id"] = getattr(threat, "id")
- for a in attrs:
- # copy threat attrs into kwargs to allow to override them in next step
- kwargs[a] = getattr(threat, a)
-
- threat_id = kwargs.get("threat_id", None)
- for f in element.overrides:
- if f.threat_id != threat_id:
- continue
- for i in dir(f.__class__):
- attr = getattr(f.__class__, i)
- if (
- i in ("element", "target")
- or i.startswith("_")
- or callable(attr)
- or not isinstance(attr, var)
- ):
- continue
- if f in attr.data:
- kwargs[i] = attr.data[f]
- break
-
- for k, v in kwargs.items():
- setattr(self, k, v)
-
- def _safeset(self, attr, value):
- try:
- setattr(self, attr, value)
- except ValueError:
- pass
-
- def __repr__(self):
- return "<{0}.{1}({2}) at {3}>".format(
- self.__module__, type(self).__name__, self.id, hex(id(self))
- )
-
- def __str__(self):
- return f"'{self.target}': {self.description}\n{self.details}\n{self.severity}"
-
-
-class TM:
- """Describes the threat model administratively,
- and holds all details during a run"""
-
- _flows = []
- _elements = []
- _actors = []
- _assets = []
- _threats = []
- _boundaries = []
- _data = []
- _threatsExcluded = []
- _sf = None
- _duplicate_ignored_attrs = (
- "name",
- "note",
- "order",
- "response",
- "responseTo",
- "controls",
- )
- name = varString("", required=True, doc="Model name")
- description = varString("", required=True, doc="Model description")
- threatsFile = varString(
- os.path.dirname(__file__) + "/threatlib/threats.json",
- onSet=lambda i, v: i._init_threats(),
- doc="JSON file with custom threats",
+ inputs: defaultdict[Element, list[Dataflow]] = field(
+ default_factory=lambda: defaultdict(list)
)
- isOrdered = varBool(False, doc="Automatically order all Dataflows")
- mergeResponses = varBool(False, doc="Merge response edges in DFDs")
- ignoreUnused = varBool(False, doc="Ignore elements not used in any Dataflow")
- findings = varFindings([], doc="Threats found for elements of this model")
- excluded_findings = varFindings(
- [],
- doc="Threats found for elements of this model, "
- "that were excluded on a per-element basis, using the Assumptions class"
+ outputs: defaultdict[Element, list[Dataflow]] = field(
+ default_factory=lambda: defaultdict(list)
)
- onDuplicates = varAction(
- Action.NO_ACTION,
- doc="""How to handle duplicate Dataflow
-with same properties, except name and notes""",
+ carriers: defaultdict[Data, set[Dataflow]] = field(
+ default_factory=lambda: defaultdict(set)
)
- assumptions = varAssumptions(
- [],
- required=False,
- doc="A list of assumptions about the design/model.",
+ processors: defaultdict[Data, set[Element]] = field(
+ default_factory=lambda: defaultdict(set)
)
- _colormap = False
-
- def __init__(self, name, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
- self.name = name
- self._sf = SuperFormatter()
- self._add_threats()
- # make sure generated diagrams do not change, makes sense if they're commited
- random.seed(0)
-
- @classmethod
- def reset(cls):
- cls._flows = []
- cls._elements = []
- cls._actors = []
- cls._assets = []
- cls._threats = []
- cls._boundaries = []
- cls._data = []
- cls._threatsExcluded = []
-
- def _init_threats(self):
- TM._threats = []
- self._add_threats()
-
- def _add_threats(self):
- try:
- with open(self.threatsFile, "r", encoding="utf8") as threat_file:
- threats_json = json.load(threat_file)
- except (FileNotFoundError, PermissionError, IsADirectoryError) as e:
- raise UIError(
- e, f"while trying to open the the threat file ({self.threatsFile})."
- )
- active_threats = (threat for threat in threats_json if "DEPRECATED" not in threat)
- for threat in active_threats:
- TM._threats.append(Threat(**threat))
-
- def resolve(self):
- finding_count = 0
- excluded_finding_count = 0
- findings = []
- excluded_findings = []
- # We just need the assumptions with SIDs to exclude
- global_assumptions = [a for a in self.assumptions if len(a.exclude) > 0]
- elements = defaultdict(list)
- for e in TM._elements:
- if not e.inScope:
- e.findings = findings
- continue
-
- override_ids = set(f.threat_id for f in e.overrides)
- # if element is a dataflow filter out overrides from source and sink
- # because they will be always applied there anyway
- try:
- override_ids -= set(
- f.threat_id for f in e.source.overrides + e.sink.overrides
- )
- except AttributeError:
- pass
-
- for t in TM._threats:
- if not t.apply(e) and t.id not in override_ids:
- continue
-
- if t.id in TM._threatsExcluded:
- continue
-
- _continue = False
- for assumption in e.assumptions + global_assumptions: # type: Assumption
- if t.id in assumption.exclude:
- excluded_finding_count += 1
- f = Finding(e, id=str(excluded_finding_count), threat=t, assumption=assumption)
- excluded_findings.append(f)
- _continue = True
- break
- if _continue:
- continue
-
- finding_count += 1
- f = Finding(e, id=str(finding_count), threat=t)
- findings.append(f)
- elements[e].append(f)
- e._set_severity(f.severity)
- self.findings = findings
- self.excluded_findings = excluded_findings
- for e, findings in elements.items():
- e.findings = findings
-
- def check(self):
- if self.description is None:
- raise ValueError(
- """Every threat model should have at least
-a brief description of the system being modeled."""
- )
- TM._flows = _match_responses(_sort(TM._flows, self.isOrdered))
-
- self._check_duplicates(TM._flows)
-
- _apply_defaults(TM._flows, TM._data)
-
- for e in TM._elements:
- top = Counter(f.threat_id for f in e.overrides).most_common(1)
- if not top:
- continue
- threat_id, count = top[0]
- if count != 1:
- raise ValueError(
- f"Finding {threat_id} have more than one override in {e}"
- )
-
- if self.ignoreUnused:
- TM._elements, TM._boundaries = _get_elements_and_boundaries(TM._flows)
-
- result = True
- for e in TM._elements:
- if not e.check():
- result = False
+ assignment_errors: ClassVar[tuple[type[Exception], ...]] = (
+ ValueError,
+ AttributeError,
+ TypeError,
+ ValidationError,
+ )
- if self.ignoreUnused:
- # cannot rely on user defined order if assets are re-used in multiple models
- TM._elements = _sort_elem(TM._elements)
+ def seed_data_relationships(self, data_items: Iterable[Data]) -> None:
+ """Ensure data instances are referenced by existing carriers."""
+ for datum in data_items:
+ for flow in getattr(datum, "carriedBy", []):
+ _add_data(getattr(flow, "data", None), datum)
- return result
+ def process_flow(self, flow: Dataflow) -> None:
+ """Apply defaults and collect relationships for a single flow."""
+ self._inherit_source_data(flow)
+ self._index_flow_relationships(flow)
+ self._sync_levels(flow)
+ self._merge_overrides(flow)
- def _check_duplicates(self, flows):
- if self.onDuplicates == Action.NO_ACTION:
+ if getattr(flow, "isResponse", False):
+ self._apply_response_defaults(flow)
return
- index = defaultdict(list)
- for e in flows:
- key = (e.source, e.sink)
- index[key].append(e)
-
- for flows in index.values():
- for left, right in combinations(flows, 2):
- left_attrs = left._attr_values()
- right_attrs = right._attr_values()
- for a in self._duplicate_ignored_attrs:
- del left_attrs[a], right_attrs[a]
- if left_attrs != right_attrs:
- continue
- if self.onDuplicates == Action.IGNORE:
- right._is_drawn = True
- continue
-
- left_controls_attrs = left.controls._attr_values()
- right_controls_attrs = right.controls._attr_values()
- # for a in self._duplicate_ignored_attrs:
- # del left_controls_attrs[a], right_controls_attrs[a]
- if left_controls_attrs != right_controls_attrs:
- continue
- if self.onDuplicates == Action.IGNORE:
- right._is_drawn = True
- continue
-
- raise ValueError(
- "Duplicate Dataflow found between {} and {}: "
- "{} is same as {}".format(
- left.source,
- left.sink,
- left,
- right,
- )
- )
-
- def _dfd_template(self):
- return """digraph tm {{
- graph [
- fontname = Arial;
- fontsize = 14;
- ]
- node [
- fontname = Arial;
- fontsize = 14;
- rankdir = lr;
- ]
- edge [
- shape = none;
- arrowtail = onormal;
- fontname = Arial;
- fontsize = 12;
- ]
- labelloc = "t";
- fontsize = 20;
- nodesep = 1;
-
-{edges}
-}}"""
-
- def dfd(self, **kwargs):
- if "levels" in kwargs:
- levels = kwargs["levels"]
- if not isinstance(kwargs["levels"], Iterable):
- kwargs["levels"] = [levels]
- kwargs["levels"] = set(levels)
-
- edges = []
- # since boundaries can be nested sort them by level and start from top
- parents = set(b.inBoundary for b in TM._boundaries if b.inBoundary)
-
- # TODO boundaries should not be drawn if they don't contain elements matching requested levels
- # or contain only empty boundaries
- boundary_levels = defaultdict(set)
- max_level = 0
- for b in TM._boundaries:
- if b in parents:
- continue
- boundary_levels[0].add(b)
- for i, p in enumerate(b.parents()):
- i = i + 1
- boundary_levels[i].add(p)
- if i > max_level:
- max_level = i
-
- for i in range(max_level, -1, -1):
- for b in sorted(boundary_levels[i], key=lambda b: b.name):
- edges.append(b.dfd(**kwargs))
-
- if self.mergeResponses:
- for e in TM._flows:
- if e.response is not None:
- e.response._is_drawn = True
- kwargs["mergeResponses"] = self.mergeResponses
- for e in TM._elements:
- if not e._is_drawn and not isinstance(e, Boundary) and e.inBoundary is None:
- edges.append(e.dfd(**kwargs))
-
- return self._dfd_template().format(
- edges=indent("\n".join(filter(len, edges)), " ")
- )
-
- def _seq_template(self):
- return """@startuml
-{participants}
-
-{messages}
-@enduml"""
-
- def seq(self):
- participants = []
- for e in TM._elements:
- if isinstance(e, Actor):
- participants.append(
- 'actor {0} as "{1}"'.format(e._uniq_name(), e.display_name())
- )
- elif isinstance(e, Datastore):
- participants.append(
- 'database {0} as "{1}"'.format(e._uniq_name(), e.display_name())
- )
- elif not isinstance(e, Dataflow) and not isinstance(e, Boundary):
- participants.append(
- 'entity {0} as "{1}"'.format(e._uniq_name(), e.display_name())
- )
-
- messages = []
- for e in TM._flows:
- message = "{0} -> {1}: {2}".format(
- e.source._uniq_name(), e.sink._uniq_name(), e.display_name()
- )
- note = ""
- if e.note != "":
- note = "\nnote left\n{}\nend note".format(e.note)
- messages.append("{}{}".format(message, note))
+ self._apply_forward_defaults(flow)
+ self._enrich_data_attributes(flow)
+ self.inputs[flow.sink].append(flow)
+ self.outputs[flow.source].append(flow)
- return self._seq_template().format(
- participants="\n".join(participants), messages="\n".join(messages)
- )
+ def finalize_assets(self) -> None:
+ """Populate inputs/outputs on elements once all flows are processed."""
+ for asset, flow_list in self.inputs.items():
+ self._set_sequence(asset, "inputs", flow_list)
- def report(self, template_path):
- try:
- with open(template_path) as file:
- template = file.read()
- except (FileNotFoundError, PermissionError, IsADirectoryError) as e:
- raise UIError(
- e, f"while trying to open the report template file ({template_path})."
- )
+ for asset, flow_list in self.outputs.items():
+ self._set_sequence(asset, "outputs", flow_list)
- threats = encode_threat_data(TM._threats)
- findings = encode_threat_data(self.findings)
-
- elements = encode_element_threat_data(TM._elements)
- assets = encode_element_threat_data(TM._assets)
- actors = encode_element_threat_data(TM._actors)
- boundaries = encode_element_threat_data(TM._boundaries)
- flows = encode_element_threat_data(TM._flows)
-
- data = {
- "tm": self,
- "dataflows": flows,
- "threats": threats,
- "findings": findings,
- "elements": elements,
- "assets": assets,
- "actors": actors,
- "boundaries": boundaries,
- "data": TM._data,
- }
-
- return self._sf.format(template, **data)
-
- def process(self):
- try:
- self._process()
- except UIError as e:
- erromsg = f"""Failed to excecute
- {e.context}
- {e.error}
-"""
- sys.stderr.write(erromsg)
- sys.exit(127)
-
- def _process(self):
- self.check()
- result = get_args()
- logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
-
- if result.debug:
- logger.setLevel(logging.DEBUG)
-
- if result.exclude is not None:
- TM._threatsExcluded = result.exclude.split(",")
-
- if result.seq is True:
- print(self.seq())
-
- if result.dfd is True:
- if result.colormap is True:
- self.resolve()
- print(self.dfd(colormap=result.colormap, levels=(result.levels or set())))
-
- if (
- result.report is not None
- or result.json is not None
- or result.sqldump is not None
- or result.stale_days is not None
- ):
- self.resolve()
-
- if result.sqldump is not None:
- self.sqlDump(result.sqldump)
-
- if result.json:
+ def finalize_data_relationships(self) -> None:
+ """Attach carrier and processor metadata to data objects."""
+ for datum, flow_list in self.carriers.items():
+ ordered = sorted(flow_list, key=lambda f: f.name)
try:
- with open(result.json, "w", encoding="utf8") as f:
- json.dump(self, f, default=to_serializable)
- except (FileExistsError, PermissionError, IsADirectoryError) as e:
- raise UIError(
- e, f"while trying to write to the result file ({result.json})"
- )
-
- if result.report is not None:
- print(self.report(result.report))
-
- if result.describe is not None:
- _describe_classes(result.describe.split())
-
- if result.list_elements:
- _list_elements()
-
- if result.list is True:
- [print("{} - {}".format(t.id, t.description)) for t in TM._threats]
-
- if result.stale_days is not None:
- print(self._stale(result.stale_days))
-
- def _stale(self, days):
- try:
- base_path = os.path.dirname(sys.argv[0])
- tm_mtime = datetime.fromtimestamp(
- os.stat(base_path + f"/{sys.argv[0]}").st_mtime
- )
- except os.error as err:
- sys.stderr.write(f"{sys.argv[0]} - {err}\n")
- sys.stderr.flush()
- return "[ERROR]"
-
- print(f"Checking for code {days} days older than this model.")
-
- for e in TM._elements:
- for src in e.sourceFiles:
- try:
- src_mtime = datetime.fromtimestamp(
- os.stat(base_path + f"/{src}").st_mtime
- )
- except os.error as err:
- sys.stderr.write(f"{sys.argv[0]} - {err}\n")
- sys.stderr.flush()
- continue
-
- age = (src_mtime - tm_mtime).days
-
- # source code is older than model by more than the speficied delta
- if (age) >= days:
- print(f"This model is {age} days older than {base_path}/{src}.")
- elif age <= -days:
- print(
- f"Model script {sys.argv[0]}"
- + " is only "
- + str(-1 * age)
- + " days newer than source code file "
- + f"{base_path}/{src}"
- )
-
- return ""
-
- def sqlDump(self, filename):
- try:
- from pydal import DAL, Field
- except ImportError as e:
- raise UIError(
- e, """This feature requires the pyDAL package,
- Please install the package via pip or your packagemanger of choice.
- """
- )
-
- @lru_cache(maxsize=None)
- def get_table(db, klass):
- name = klass.__name__
- fields = [
- Field("SID" if i == "id" else i)
- for i in dir(klass)
- if not i.startswith("_") and not callable(getattr(klass, i))
- ]
- return db.define_table(name, fields)
-
+ setattr(datum, "carriedBy", list(ordered))
+ except self.assignment_errors:
+ for flow in ordered:
+ existing = getattr(datum, "carriedBy", [])
+ if flow not in existing:
+ existing.append(flow)
+
+ for datum, elements in self.processors.items():
+ ordered = sorted(elements, key=lambda el: el.name)
+ try:
+ setattr(datum, "processedBy", list(ordered))
+ except self.assignment_errors:
+ for element in ordered:
+ existing = getattr(datum, "processedBy", [])
+ if element not in existing:
+ existing.append(element)
+
+ def _inherit_source_data(self, flow: Dataflow) -> None:
+ source_data = getattr(flow.source, "data", None)
+ if source_data:
+ _add_data(getattr(flow, "data", None), source_data)
+
+ def _index_flow_relationships(self, flow: Dataflow) -> None:
+ for datum in list(getattr(flow, "data", [])):
+ self.carriers[datum].add(flow)
+ self.processors[datum].add(flow.source)
+ self.processors[datum].add(flow.sink)
+
+ @staticmethod
+ def _sync_levels(flow: Dataflow) -> None:
try:
- rmtree("./sqldump")
- os.mkdir("./sqldump")
- except OSError as e:
- if e.errno != errno.ENOENT:
- raise
- else:
- os.mkdir("./sqldump")
-
- db = DAL("sqlite://" + filename, folder="sqldump")
-
- for klass in (
- Server,
- ExternalEntity,
- Dataflow,
- Datastore,
- Actor,
- Process,
- SetOfProcesses,
- Boundary,
- TM,
- Threat,
- Lambda,
- Data,
- Finding,
- ):
- get_table(db, klass)
-
- for e in TM._threats + TM._data + TM._elements + self.findings + [self]:
- table = get_table(db, e.__class__)
- row = {}
- for k, v in serialize(e).items():
- if k == "id":
- k = "SID"
- row[k] = ", ".join(str(i) for i in v) if isinstance(v, list) else v
- db[table].bulk_insert([row])
-
- db.close()
-
-
-
-class Controls:
- """Controls implemented by/on and Element"""
-
- authenticatesDestination = varBool(
- False,
- doc="""Verifies the identity of the destination,
-for example by verifying the authenticity of a digital certificate.""",
- )
- authenticatesSource = varBool(False)
- authenticationScheme = varString("")
- authorizesSource = varBool(False)
- checksDestinationRevocation = varBool(
- False,
- doc="""Correctly checks the revocation status
-of credentials used to authenticate the destination""",
- )
- checksInputBounds = varBool(False)
- definesConnectionTimeout = varBool(False)
- disablesDTD = varBool(False)
- disablesiFrames = varBool(False)
- encodesHeaders = varBool(False)
- encodesOutput = varBool(False)
- encryptsCookies = varBool(False)
- encryptsSessionData = varBool(False)
- handlesCrashes = varBool(False)
- handlesInterruptions = varBool(False)
- handlesResourceConsumption = varBool(False)
- hasAccessControl = varBool(False)
- implementsAuthenticationScheme = varBool(False)
- implementsCSRFToken = varBool(False)
- implementsNonce = varBool(
- False,
- doc="""Nonce is an arbitrary number
-that can be used just once in a cryptographic communication.
-It is often a random or pseudo-random number issued in an authentication protocol
-to ensure that old communications cannot be reused in replay attacks.
-They can also be useful as initialization vectors and in cryptographic
-hash functions.""",
- )
- implementsPOLP = varBool(
- False,
- doc="""The principle of least privilege (PoLP),
-also known as the principle of minimal privilege or the principle of least authority,
-requires that in a particular abstraction layer of a computing environment,
-every module (such as a process, a user, or a program, depending on the subject)
-must be able to access only the information and resources
-that are necessary for its legitimate purpose.""",
- )
- implementsServerSideValidation = varBool(False)
- implementsStrictHTTPValidation = varBool(False)
- invokesScriptFilters = varBool(False)
- isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted")
- isEncryptedAtRest = varBool(False, doc="Stored data is encrypted at rest")
- isHardened = varBool(False)
- isResilient = varBool(False)
- providesConfidentiality = varBool(False)
- providesIntegrity = varBool(False)
- sanitizesInput = varBool(False)
- tracksExecutionFlow = varBool(False)
- usesCodeSigning = varBool(False)
- usesEncryptionAlgorithm = varString("")
- usesMFA = varBool(
- False,
- doc="""Multi-factor authentication is an authentication method
-in which a computer user is granted access only after successfully presenting two
-or more pieces of evidence (or factors) to an authentication mechanism: knowledge
-(something the user and only the user knows), possession (something the user
-and only the user has), and inherence (something the user and only the user is).""",
- )
- usesParameterizedInput = varBool(False)
- usesSecureFunctions = varBool(False)
- usesStrongSessionIdentifiers = varBool(False)
- usesVPN = varBool(False)
- validatesContentType = varBool(False)
- validatesHeaders = varBool(False)
- validatesInput = varBool(False)
- verifySessionIdentifiers = varBool(False)
-
- def _attr_values(self):
- klass = self.__class__
- result = {}
- for i in dir(klass):
- if i.startswith("_") or callable(getattr(klass, i)):
- continue
- attr = getattr(klass, i, {})
- if isinstance(attr, var):
- value = attr.data.get(self, attr.default)
- else:
- value = getattr(self, i)
- result[i] = value
- return result
+ level_intersection = flow.source.levels & flow.sink.levels
+ except TypeError:
+ level_intersection = set()
+ if level_intersection:
+ flow._safeset("levels", level_intersection)
- def _safeset(self, attr, value):
+ def _merge_overrides(self, flow: Dataflow) -> None:
try:
- setattr(self, attr, value)
- except ValueError:
+ sink_overrides = list(getattr(flow.sink, "overrides", []))
+ source_overrides = list(getattr(flow.source, "overrides", []))
+ combined = list(sink_overrides)
+ existing_ids = {getattr(finding, "threat_id", None) for finding in combined}
+ for finding in source_overrides:
+ sid = getattr(finding, "threat_id", None)
+ if sid not in existing_ids:
+ combined.append(finding)
+ existing_ids.add(sid)
+ flow.overrides = combined
+ except self.assignment_errors:
pass
-
-class Assumption:
- """
- Assumption used by an Element.
- Used to exclude threats on a per-element basis.
- """
- name = varString("", required=True)
- exclude = varStrings([], doc="A list of threat SIDs to exclude for this assumption. For example: INP01")
- description = varString("", doc="An additional description of the assumption")
-
- def __init__(self, name, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
- self.name = name
-
- def __str__(self):
- return self.name
-
-
-class Element:
- """A generic element"""
-
- name = varString("", required=True)
- description = varString("")
- inBoundary = varBoundary(None, doc="Trust boundary this element exists in")
- inScope = varBool(True, doc="Is the element in scope of the threat model")
- maxClassification = varClassification(
- Classification.UNKNOWN,
- required=False,
- doc="Maximum data classification this element can handle.",
- )
- minTLSVersion = varTLSVersion(
- TLSVersion.NONE,
- required=False,
- doc="""Minimum TLS version required.""",
- )
- findings = varFindings([], doc="Threats that apply to this element")
- overrides = varFindings(
- [],
- doc="""Overrides to findings, allowing to set
-a custom response, CVSS score or override other attributes.""",
- )
- assumptions = varAssumptions(
- [],
- doc="Assumptions about the element. These optionally allow to exclude threats with the given SIDs.",
- )
- levels = varInts({0}, doc="List of levels (0, 1, 2, ...) to be drawn in the model.")
- sourceFiles = varStrings(
- [],
- required=False,
- doc="Location of the source code that describes this element relative to the directory of the model script.",
- )
- controls = varControls(None)
- severity = 0
-
- def __init__(self, name, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
- self.name = name
- self.controls = Controls()
- self.uuid = uuid.UUID(int=random.getrandbits(128))
- self._is_drawn = False
- TM._elements.append(self)
-
- def __repr__(self):
- return "<{0}.{1}({2}) at {3}>".format(
- self.__module__, type(self).__name__, self.name, hex(id(self))
+ def _apply_response_defaults(self, flow: Dataflow) -> None:
+ flow._safeset(
+ "protocol", getattr(flow.source, "protocol", getattr(flow, "protocol", ""))
)
-
- def __str__(self):
- return "{0}({1})".format(type(self).__name__, self.name)
-
- def _uniq_name(self):
- """transform name and uuid into a unique string"""
- h = sha224(str(self.uuid).encode("utf-8")).hexdigest()
- name = "".join(x for x in self.name if x.isalpha())
- return "{0}_{1}_{2}".format(type(self).__name__.lower(), name, h[:10])
-
- def check(self):
- return True
-
- def _dfd_template(self):
- return """{uniq_name} [
- shape = {shape};
- color = {color};
- fontcolor = black;
- label = "{label}";
- margin = 0.02;
-]
-"""
-
- def dfd(self, **kwargs):
- self._is_drawn = True
-
- levels = kwargs.get("levels", None)
- if levels and not levels & self.levels:
- return ""
-
- color = self._color()
- if kwargs.get("colormap", False):
- color = sev_to_color(self.severity)
-
- return self._dfd_template().format(
- uniq_name=self._uniq_name(),
- label=self._label(),
- color=color,
- shape=self._shape(),
+ flow._safeset(
+ "srcPort", getattr(flow.source, "port", getattr(flow, "srcPort", -1))
)
+ if hasattr(flow.source, "controls"):
+ flow.controls._safeset(
+ "isEncrypted", getattr(flow.source.controls, "isEncrypted", False)
+ )
- def _color(self):
- if self.inScope is True:
- return "black"
- else:
- return "grey69"
-
- def display_name(self):
- return self.name
-
- def _label(self):
- return "\\n".join(wrap(self.display_name(), 18))
-
- def _shape(self):
- return "square"
-
- def _safeset(self, attr, value):
- try:
- setattr(self, attr, value)
- except ValueError:
- pass
+ def _apply_forward_defaults(self, flow: Dataflow) -> None:
+ flow._safeset(
+ "protocol", getattr(flow.sink, "protocol", getattr(flow, "protocol", ""))
+ )
+ flow._safeset(
+ "dstPort", getattr(flow.sink, "port", getattr(flow, "dstPort", -1))
+ )
+ if hasattr(flow.sink, "controls"):
+ flow.controls._safeset(
+ "isEncrypted", getattr(flow.sink.controls, "isEncrypted", False)
+ )
+ if hasattr(flow.source, "controls"):
+ flow.controls._safeset(
+ "authenticatesDestination",
+ getattr(flow.source.controls, "authenticatesDestination", False),
+ )
+ flow.controls._safeset(
+ "checksDestinationRevocation",
+ getattr(flow.source.controls, "checksDestinationRevocation", False),
+ )
- def oneOf(self, *elements):
- """Is self one of a list of Elements"""
- for element in elements:
- if inspect.isclass(element):
- if isinstance(self, element):
- return True
- elif self is element:
- return True
- return False
-
- def crosses(self, *boundaries):
- """Does self (dataflow) cross any of the list of boundaries"""
- if self.source.inBoundary is self.sink.inBoundary:
- return False
- for boundary in boundaries:
- if inspect.isclass(boundary):
- if (
- (
- isinstance(self.source.inBoundary, boundary)
- and not isinstance(self.sink.inBoundary, boundary)
- )
- or (
- not isinstance(self.source.inBoundary, boundary)
- and isinstance(self.sink.inBoundary, boundary)
+ def _enrich_data_attributes(self, flow: Dataflow) -> None:
+ for datum in list(getattr(flow, "data", [])):
+ if getattr(datum, "isStored", False):
+ if hasattr(flow.sink, "controls") and hasattr(
+ flow.sink.controls, "isEncryptedAtRest"
+ ):
+ datum._safeset(
+ "isDestEncryptedAtRest", flow.sink.controls.isEncryptedAtRest
)
- or self.source.inBoundary is not self.sink.inBoundary
+ if hasattr(flow.source, "controls") and hasattr(
+ flow.source.controls, "isEncryptedAtRest"
):
- return True
- elif (self.source.inside(boundary) and not self.sink.inside(boundary)) or (
- not self.source.inside(boundary) and self.sink.inside(boundary)
- ):
- return True
- return False
-
- def enters(self, *boundaries):
- """does self (dataflow) enter into one of the list of boundaries"""
- return self.source.inBoundary is None and self.sink.inside(*boundaries)
-
- def exits(self, *boundaries):
- """does self (dataflow) exit one of the list of boundaries"""
- return self.source.inside(*boundaries) and self.sink.inBoundary is None
-
- def inside(self, *boundaries):
- """is self inside of one of the list of boundaries"""
- for boundary in boundaries:
- if inspect.isclass(boundary):
- if isinstance(self.inBoundary, boundary):
- return True
- elif self.inBoundary is boundary:
- return True
- return False
-
- def _attr_values(self):
- klass = self.__class__
- result = {}
- for i in dir(klass):
- if i.startswith("_") or callable(getattr(klass, i)):
- continue
- attr = getattr(klass, i, {})
- if isinstance(attr, var):
- value = attr.data.get(self, attr.default)
- else:
- value = getattr(self, i)
- result[i] = value
- return result
-
- def checkTLSVersion(self, flows):
- return any(f.tlsVersion < self.minTLSVersion for f in flows)
-
- def _set_severity(self, sev):
- sevs = {
- "very high": 5,
- "high": 4,
- "medium": 3,
- "low": 2,
- "very low": 1,
- "info": 0,
- }
-
- if sev.lower() not in sevs.keys():
- return
-
- if self.severity < sevs[sev.lower()]:
- self.severity = sevs[sev.lower()]
- return
-
-
-class Data:
- """Represents a single piece of data that traverses the system"""
-
- name = varString("", required=True)
- description = varString("")
- format = varString("")
- classification = varClassification(
- Classification.UNKNOWN,
- required=True,
- doc="Level of classification for this piece of data",
- )
- isPII = varBool(
- False,
- doc="""Does the data contain personally identifyable information.
-Should always be encrypted both in transmission and at rest.""",
- )
- isCredentials = varBool(
- False,
- doc="""Does the data contain authentication information,
-like passwords or cryptographic keys, with or without expiration date.
-Should always be encrypted in transmission. If stored, they should be hashed
-using a cryptographic hash function.""",
- )
- credentialsLife = varLifetime(
- Lifetime.NONE,
- doc="""Credentials lifetime, describing if and how
-credentials can be revoked. One of:
-* NONE - not applicable
-* UNKNOWN - unknown lifetime
-* SHORT - relatively short expiration date, with an allowed maximum
-* LONG - long or no expiration date
-* AUTO - no expiration date but can be revoked/invalidated automatically
- in some conditions
-* MANUAL - no expiration date but can be revoked/invalidated manually
-* HARDCODED - cannot be invalidated at all""",
- )
- isStored = varBool(
- False,
- doc="""Is the data going to be stored by the target or only processed.
-If only derivative data is stored (a hash) it can be set to False.""",
- )
- isDestEncryptedAtRest = varBool(False, doc="Is data encrypted at rest at dest")
- isSourceEncryptedAtRest = varBool(False, doc="Is data encrypted at rest at source")
- carriedBy = varElements([], doc="Dataflows that carries this piece of data")
- processedBy = varElements([], doc="Elements that store/process this piece of data")
-
- def __init__(self, name, **kwargs):
- for key, value in kwargs.items():
- setattr(self, key, value)
- self.name = name
- TM._data.append(self)
-
- def __repr__(self):
- return "<{0}.{1}({2}) at {3}>".format(
- self.__module__, type(self).__name__, self.name, hex(id(self))
- )
+ datum._safeset(
+ "isSourceEncryptedAtRest",
+ flow.source.controls.isEncryptedAtRest,
+ )
- def __str__(self):
- return "{0}({1})".format(type(self).__name__, self.name)
+ if getattr(
+ datum, "credentialsLife", Lifetime.NONE
+ ) != Lifetime.NONE and not getattr(datum, "isCredentials", False):
+ datum._safeset("isCredentials", True)
+ if (
+ getattr(datum, "isCredentials", False)
+ and getattr(datum, "credentialsLife", Lifetime.NONE) == Lifetime.NONE
+ ):
+ datum._safeset("credentialsLife", Lifetime.UNKNOWN)
- def _safeset(self, attr, value):
+ def _set_sequence(
+ self, obj: Element, attr: str, values: Iterable[Dataflow]
+ ) -> None:
+ if not hasattr(obj, attr):
+ return
+ ordered = list(values)
try:
- setattr(self, attr, value)
- except ValueError:
- pass
-
-
-class Asset(Element):
- """An asset with outgoing or incoming dataflows"""
-
- port = varInt(-1, doc="Default TCP port for incoming data flows")
- protocol = varString("", doc="Default network protocol for incoming data flows")
- data = varData([], doc="pytm.Data object(s) in incoming data flows")
- inputs = varElements([], doc="incoming Dataflows")
- outputs = varElements([], doc="outgoing Dataflows")
- onAWS = varBool(False)
- handlesResources = varBool(False)
- usesEnvironmentVariables = varBool(False)
- OS = varString("")
-
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
- TM._assets.append(self)
-
-
-class Lambda(Asset):
- """A lambda function running in a Function-as-a-Service (FaaS) environment"""
-
- onAWS = varBool(True)
- environment = varString("")
- implementsAPI = varBool(False)
-
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
-
- def _dfd_template(self):
- return """{uniq_name} [
- shape = {shape};
-
- color = {color};
- fontcolor = "black";
- label = <
-
- >;
-]
-"""
-
- def dfd(self, **kwargs):
- self._is_drawn = True
-
- levels = kwargs.get("levels", None)
- if levels and not levels & self.levels:
- return ""
-
- color = self._color()
-
- if kwargs.get("colormap", False):
- color = sev_to_color(self.severity)
-
- return self._dfd_template().format(
- uniq_name=self._uniq_name(),
- label=self._label(),
- color=color,
- shape=self._shape(),
- )
-
- def _shape(self):
- return "rectangle; style=rounded"
-
-
-class LLM(Asset):
- """A Large Language Model element, either third-party or self-hosted"""
-
- isThirdParty = varBool(True)
- isSelfHosted = varBool(False)
- processesPersonalData = varBool(False)
- retainsUserData = varBool(False)
- hasAgentCapabilities = varBool(False)
- hasAccessToSensitiveSystems = varBool(False)
- executesCode = varBool(False)
- hasContentFiltering = varBool(False)
- hasSystemPrompt = varBool(True)
- processesUntrustedInput = varBool(True)
- hasRAG = varBool(False)
- hasFineTuning = varBool(False)
-
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
-
- def _shape(self):
- return "hexagon"
-
-
-class Server(Asset):
- """An entity processing data"""
-
- usesSessionTokens = varBool(False)
- usesCache = varBool(False)
- usesVPN = varBool(False)
- usesXMLParser = varBool(False)
-
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
-
- def _shape(self):
- return "circle"
-
-
-class ExternalEntity(Asset):
- hasPhysicalAccess = varBool(False)
-
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
-
-
-class Datastore(Asset):
- """An entity storing data"""
-
- onRDS = varBool(False)
- storesLogData = varBool(False)
- storesPII = varBool(
- False,
- doc="""Personally Identifiable Information
-is any information relating to an identifiable person.""",
- )
- storesSensitiveData = varBool(False)
- isSQL = varBool(True)
- isShared = varBool(False)
- hasWriteAccess = varBool(False)
- type = varDatastoreType(
- DatastoreType.UNKNOWN,
- doc="""The type of Datastore, values may be one of:
-* UNKNOWN - unknown applicable
-* FILE_SYSTEM - files on a file system
-* SQL - A SQL Database
-* LDAP - An LDAP Server
-* AWS_S3 - An S3 Bucket within AWS""",
- )
-
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
-
- def _dfd_template(self):
- return """{uniq_name} [
- shape = {shape};
- fixedsize = shape;
- image = "{image}";
- imagescale = true;
- color = {color};
- fontcolor = black;
- xlabel = "{label}";
- label = "";
-]
-"""
-
- def _shape(self):
- return "none"
-
- def dfd(self, **kwargs):
- self._is_drawn = True
-
- levels = kwargs.get("levels", None)
- if levels and not levels & self.levels:
- return ""
-
- color = self._color()
- color_file = "black"
-
- if kwargs.get("colormap", False):
- color = sev_to_color(self.severity)
- color_file = color.split(";")[0]
-
- return self._dfd_template().format(
- uniq_name=self._uniq_name(),
- label=self._label(),
- color=color,
- shape=self._shape(),
- image=os.path.join(
- os.path.dirname(__file__), "images", f"datastore_{color_file}.png"
- ),
- )
-
-
-class Actor(Element):
- """An entity usually initiating actions.
-
- Actors represent users or external systems that initiate
- interactions with the system being modeled.
-
- Attributes:
- port (int): Default TCP port for outgoing data flows.
- protocol (str): Default network protocol for outgoing data flows.
- data (list): pytm.Data objects carried in outgoing data flows.
- inputs (list): Incoming Dataflows.
- outputs (list): Outgoing Dataflows.
- isAdmin (bool): Indicates whether the actor has administrative privileges.
- """
-
- port = varInt(-1, doc="Default TCP port for outgoing data flows")
- protocol = varString("", doc="Default network protocol for outgoing data flows")
- data = varData([], doc="pytm.Data object(s) in outgoing data flows")
- inputs = varElements([], doc="incoming Dataflows")
- outputs = varElements([], doc="outgoing Dataflows")
- isAdmin = varBool(False)
-
- def __init__(self, name, **kwargs):
- """
- Initialize an Actor.
-
- Args:
- name (str): Name of the actor.
- **kwargs: Optional actor properties.
- port (int): Default TCP port for outgoing data flows.
- protocol (str): Default network protocol for outgoing data flows.
- data (list): pytm.Data objects in outgoing data flows.
- inputs (list): Incoming Dataflows.
- outputs (list): Outgoing Dataflows.
- isAdmin (bool): Indicates administrative privileges.
- """
- super().__init__(name, **kwargs)
- TM._actors.append(self)
-
-
-
-class Process(Asset):
- """An entity processing data"""
-
- codeType = varString("Unmanaged")
- implementsCommunicationProtocol = varBool(False)
- tracksExecutionFlow = varBool(False)
- implementsAPI = varBool(False)
- environment = varString("")
- allowsClientSideScripting = varBool(False)
-
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
-
- def _shape(self):
- return "circle"
-
-
-class SetOfProcesses(Process):
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
-
- def _shape(self):
- return "doublecircle"
-
-
-class Dataflow(Element):
- """A data flow from a source to a sink"""
-
- source = varElement(None, required=True)
- sink = varElement(None, required=True)
- isResponse = varBool(False, doc="Is a response to another data flow")
- response = varElement(None, doc="Another data flow that is a response to this one")
- responseTo = varElement(None, doc="Is a response to this data flow")
- srcPort = varInt(-1, doc="Source TCP port")
- dstPort = varInt(-1, doc="Destination TCP port")
- tlsVersion = varTLSVersion(
- TLSVersion.NONE,
- required=True,
- doc="TLS version used.",
- )
- protocol = varString("", doc="Protocol used in this data flow")
- data = varData([], doc="pytm.Data object(s) in incoming data flows")
- order = varInt(-1, doc="Number of this data flow in the threat model")
- implementsCommunicationProtocol = varBool(False)
- note = varString("")
- usesVPN = varBool(False)
- usesSessionTokens = varBool(False)
- severity = 0
-
- def __init__(self, source, sink, name, **kwargs):
- self.source = source
- self.sink = sink
- super().__init__(name, **kwargs)
- TM._flows.append(self)
-
- def display_name(self):
- if self.order == -1:
- return self.name
- return "({}) {}".format(self.order, self.name)
-
- def _dfd_template(self):
- return """{source} -> {sink} [
- color = {color};
- fontcolor = {color};
- dir = {direction};
- label = "{label}";
-]
-"""
-
- def dfd(self, mergeResponses=False, **kwargs):
- self._is_drawn = True
-
- levels = kwargs.get("levels", None)
- if (
- levels
- and not levels & self.levels
- and not (levels & self.source.levels and levels & self.sink.levels)
- ):
- return ""
-
- color = self._color()
-
- if kwargs.get("colormap", False):
- color = sev_to_color(self.severity)
-
- direction = "forward"
- label = self._label()
- if mergeResponses and self.response is not None:
- direction = "both"
- label += "\n" + self.response._label()
-
- return self._dfd_template().format(
- source=self.source._uniq_name(),
- sink=self.sink._uniq_name(),
- direction=direction,
- label=label,
- color=color,
- )
-
- def hasDataLeaks(self):
- return any(
- d.classification > self.source.maxClassification
- or d.classification > self.sink.maxClassification
- or d.classification > self.maxClassification
- for d in self.data
- )
+ setattr(obj, attr, ordered)
+ except self.assignment_errors:
+ existing = getattr(obj, attr)
+ try:
+ existing[:] = ordered
+ except TypeError:
+ if hasattr(existing, "clear") and hasattr(existing, "extend"):
+ existing.clear()
+ existing.extend(ordered)
+ else:
+ for item in ordered:
+ if item not in existing:
+ existing.append(item)
-class Boundary(Element):
- """Trust boundary groups elements and data with the same trust level."""
+def _apply_defaults(flows, data):
+ """Apply default values to flows and data."""
+ builder = _FlowDefaultsBuilder()
+ builder.seed_data_relationships(data)
- def __init__(self, name, **kwargs):
- super().__init__(name, **kwargs)
- if name not in TM._boundaries:
- TM._boundaries.append(self)
+ for flow in flows:
+ builder.process_flow(flow)
- def _dfd_template(self):
- return """subgraph cluster_{uniq_name} {{
- graph [
- fontsize = 10;
- fontcolor = black;
- style = dashed;
- color = {color};
- label = <{label}>;
- ]
+ builder.finalize_assets()
+ builder.finalize_data_relationships()
-{edges}
-}}
-"""
- def dfd(self, **kwargs):
- if self._is_drawn:
- return ""
+def _get_elements_and_boundaries(flows):
+ """Get elements and boundaries used in flows."""
+ elements = set()
+ boundaries = set()
- self._is_drawn = True
+ for flow in flows:
+ elements.add(flow)
+ elements.add(flow.source)
+ elements.add(flow.sink)
- edges = []
- for e in TM._elements:
- if e.inBoundary != self or e._is_drawn:
- continue
- # The content to draw can include Boundary objects
- edges.append(e.dfd(**kwargs))
-
- return self._dfd_template().format(
- uniq_name=self._uniq_name(),
- label=self._label(),
- color=self._color(**kwargs),
- edges=indent("\n".join(edges), " "),
- )
+ source_boundary = getattr(flow.source, "inBoundary", None)
+ if source_boundary is not None:
+ boundaries.add(source_boundary)
+ for parent in getattr(source_boundary, "parents", lambda: [])():
+ elements.add(parent)
+ boundaries.add(parent)
- def _color(self, **kwargs):
- if kwargs.get("colormap", False):
- return "black"
- else:
- return "firebrick2"
+ sink_boundary = getattr(flow.sink, "inBoundary", None)
+ if sink_boundary is not None:
+ boundaries.add(sink_boundary)
+ for parent in getattr(sink_boundary, "parents", lambda: [])():
+ elements.add(parent)
+ boundaries.add(parent)
- def parents(self):
- result = []
- parent = self.inBoundary
- while parent is not None:
- result.append(parent)
- parent = parent.inBoundary
- return result
+ return list(elements), list(boundaries)
@singledispatch
@@ -2055,6 +527,7 @@ def to_serializable(val):
@to_serializable.register(TM)
def ts_tm(obj):
+ """Serialize TM object."""
return serialize(obj, nested=True)
@@ -2064,73 +537,154 @@ def ts_tm(obj):
@to_serializable.register(Element)
@to_serializable.register(Finding)
def ts_element(obj):
+ """Serialize element objects."""
return serialize(obj, nested=False)
def serialize(obj, nested=False):
"""Used if *obj* is an instance of TM, Element, Threat or Finding."""
- klass = obj.__class__
+
result = {}
+ klass = obj.__class__
+
if isinstance(obj, (Actor, Asset)):
result["__class__"] = klass.__name__
- for i in dir(obj):
- if (
- i.startswith("__")
- or callable(getattr(klass, i, {}))
- or (
- isinstance(obj, TM)
- and i in ("_sf", "_duplicate_ignored_attrs", "_threats")
- )
- or (isinstance(obj, Element) and i in ("_is_drawn", "uuid"))
- or (isinstance(obj, Finding) and i == "element")
- ):
+
+ attribute_names = set()
+
+ if hasattr(obj, "__dict__"):
+ attribute_names.update(
+ name for name in obj.__dict__.keys() if not name.startswith("__")
+ )
+
+ model_fields = getattr(klass, "model_fields", {})
+ attribute_names.update(model_fields.keys())
+
+ computed_fields = getattr(klass, "model_computed_fields", {})
+ attribute_names.update(computed_fields.keys())
+
+ if isinstance(obj, TM):
+ attribute_names.update(
+ {
+ "_actors",
+ "_assets",
+ "_elements",
+ "_flows",
+ "_data",
+ "_boundaries",
+ "assumptions",
+ "findings",
+ "excluded_findings",
+ "_threatsExcluded",
+ }
+ )
+
+ skip_attrs = {
+ "_sf",
+ "_duplicate_ignored_attrs",
+ "_threats",
+ "model_fields",
+ "model_computed_fields",
+ "model_config",
+ "model_post_init",
+ "model_extra",
+ "model_json_schema",
+ "schema",
+ "copy",
+ "dict",
+ "json",
+ "parse_file",
+ "parse_obj",
+ "parse_raw",
+ "construct",
+ "model_copy",
+ "validate",
+ "abc_impl",
+ "register",
+ }
+
+ for attr_name in sorted(attribute_names):
+ if attr_name in skip_attrs:
+ continue
+
+ if isinstance(obj, Element) and attr_name in {"uuid", "_is_drawn", "is_drawn"}:
continue
- value = getattr(obj, i)
- if isinstance(obj, TM) and i == "_elements":
+ if isinstance(obj, Finding) and attr_name == "element":
+ continue
+
+ try:
+ value = getattr(obj, attr_name)
+ except AttributeError:
+ continue
+
+ key = attr_name.lstrip("_")
+
+ if isinstance(obj, TM) and attr_name == "_elements":
value = [e for e in value if isinstance(e, (Actor, Asset))]
- if value is not None:
- if isinstance(value, (Element, Data)):
- value = value.name
- elif isinstance(obj, Threat) and i == "target":
- value = [v.__name__ for v in value]
- elif i in ("levels", "sourceFiles", "assumptions"):
- value = list(value)
- elif (
- not nested
- and not isinstance(value, str)
- and isinstance(value, Iterable)
- ):
- value = [v.id if isinstance(v, Finding) else v.name for v in value]
- result[i.lstrip("_")] = value
+
+ if value is None:
+ result[key] = None
+ continue
+
+ if isinstance(value, (Element, Data)):
+ value = value.name
+ elif hasattr(value, "model_dump") and not isinstance(value, TM):
+ value = value.model_dump()
+ elif isinstance(obj, Threat) and attr_name == "target":
+ coerced_targets = []
+ for target in value:
+ if hasattr(target, "__name__"):
+ coerced_targets.append(target.__name__)
+ else:
+ coerced_targets.append(str(target))
+ value = coerced_targets
+ elif attr_name in {"levels", "sourceFiles", "assumptions"}:
+ value = list(value)
+ elif (
+ not nested
+ and not isinstance(value, (str, bytes))
+ and isinstance(value, Iterable)
+ and not isinstance(value, Mapping)
+ ):
+ coerced = []
+ for item in value:
+ if isinstance(item, Finding):
+ coerced.append(item.id)
+ elif isinstance(item, (Element, Data)):
+ coerced.append(item.name)
+ else:
+ coerced.append(item)
+ value = coerced
+
+ result[key] = value
+
return result
def encode_element_threat_data(obj):
- """Used to html encode threat data from a list of Elements"""
- encoded_elements = []
- if type(obj) is not list:
- raise ValueError("expecting a list value, got a {}".format(type(obj)))
-
- for o in obj:
- c = copy.deepcopy(o)
- for a in o._attr_values():
- if a == "findings":
- encoded_findings = encode_threat_data(o.findings)
- c._safeset("findings", encoded_findings)
+ """Encode element threat data."""
+ result = []
+ if hasattr(obj, "__iter__"):
+ for item in obj:
+ if hasattr(item, "model_dump"):
+ result.append(item.model_dump())
else:
- v = getattr(o, a)
- if type(v) is not list or (type(v) is list and len(v) != 0):
- c._safeset(a, v)
-
- encoded_elements.append(c)
-
- return encoded_elements
+ result.append(serialize(item))
+ return result
def encode_threat_data(obj):
- """Used to html encode threat data from a list of threats or findings"""
+ """HTML-encode threat data while preserving attribute access."""
encoded_threat_data = []
+ if obj is None:
+ return encoded_threat_data
+
+ if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
+ candidates = list(obj)
+ else:
+ candidates = [obj]
+
attrs = [
"description",
"details",
@@ -2143,73 +697,109 @@ def encode_threat_data(obj):
"condition",
"cvss",
"response",
- "likelihood",
]
- if type(obj) is Finding or (len(obj) != 0 and type(obj[0]) is Finding):
+ from .finding import Finding
+
+ if candidates and isinstance(candidates[0], Finding):
attrs.append("target")
- for e in obj:
- t = copy.deepcopy(e)
+ def _escape_markdown(text: str) -> str:
+ return re.sub(r"(? str:
+ """Return the parent boundary name for *element* or an empty string."""
+ from pytm import Boundary # Local import to avoid circular dependency
+
+ if not isinstance(element, Boundary):
+ return (
+ f"ERROR: getParentName method is not valid for {type(element).__name__}"
+ )
+ parent = element.inBoundary
+ return parent.name if parent is not None else ""
@staticmethod
- def getNamesOfParents(element):
+ def getNamesOfParents(element: Any) -> List[str] | str:
+ """Return a list of parent boundary names for *element*."""
from pytm import Boundary
- if (isinstance(element, Boundary)):
- parents = [p.name for p in element.parents()]
- return parents
- else:
- return "ERROR: getNamesOfParents method is not valid for " + element.__class__.__name__
-
+
+ if not isinstance(element, Boundary):
+ return f"ERROR: getNamesOfParents method is not valid for {type(element).__name__}"
+
+ return [parent.name for parent in element.parents()]
@staticmethod
- def getInScopeFindings(element):
- """
- Return only findings that:
- 1. Belong to an in-scope element
- 2. Target an in-scope element
- """
+ def getInScopeFindings(element: Any) -> list:
+ """Return only findings that belong to an in-scope element and target an in-scope element."""
from pytm import Element
if not isinstance(element, Element):
@@ -39,7 +43,6 @@ def getInScopeFindings(element):
return []
in_scope_findings = []
-
for finding in element.findings:
target = getattr(finding, "target", None)
if target is not None and getattr(target, "inScope", False):
@@ -47,69 +50,78 @@ def getInScopeFindings(element):
return in_scope_findings
-
@staticmethod
- def getFindingCount(element):
+ def getFindingCount(element: Any) -> str:
+ """Return the count of findings for *element* as a string."""
from pytm import Element
+
if not isinstance(element, Element):
- return "ERROR: getFindingCount method is not valid for " + element.__class__.__name__
- return str(len(ReportUtils.getInScopeFindings(element)))
+ return f"ERROR: getFindingCount method is not valid for {type(element).__name__}"
+ return str(len(list(element.findings)))
@staticmethod
- def getElementType(element):
+ def getElementType(element: Any) -> str:
+ """Return the class name for *element*."""
from pytm import Element
- if (isinstance(element, Element)):
- return str(element.__class__.__name__)
- else:
- return "ERROR: getElementType method is not valid for " + element.__class__.__name__
+ if not isinstance(element, Element):
+ return f"ERROR: getElementType method is not valid for {type(element).__name__}"
+
+ return element.__class__.__name__
@staticmethod
- def getThreatId(obj):
+ def getThreatId(obj: Any) -> str:
from pytm import Finding
+
if isinstance(obj, Finding):
return obj.threat_id
return ""
@staticmethod
- def getFindingDescription(obj):
+ def getFindingDescription(obj: Any) -> str:
from pytm import Finding
+
if isinstance(obj, Finding):
return obj.description
return ""
@staticmethod
- def getFindingTarget(obj):
+ def getFindingTarget(obj: Any) -> Any:
from pytm import Finding
+
if isinstance(obj, Finding):
return obj.target
return ""
@staticmethod
- def getFindingSeverity(obj):
+ def getFindingSeverity(obj: Any) -> str:
from pytm import Finding
+
if isinstance(obj, Finding):
return obj.severity
return ""
@staticmethod
- def getFindingMitigations(obj):
+ def getFindingMitigations(obj: Any) -> str:
from pytm import Finding
+
if isinstance(obj, Finding):
return obj.mitigations
return ""
@staticmethod
- def getFindingReferences(obj):
+ def getFindingReferences(obj: Any) -> str:
from pytm import Finding
+
if isinstance(obj, Finding):
return obj.references
return ""
-
+
@staticmethod
- def getFindingExample(obj):
+ def getFindingExample(obj: Any) -> str:
from pytm import Finding
+
if isinstance(obj, Finding):
return obj.example
return ""
diff --git a/pytm/template_engine.py b/pytm/template_engine.py
index 9118a5d4..172d1521 100644
--- a/pytm/template_engine.py
+++ b/pytm/template_engine.py
@@ -2,84 +2,76 @@
# but modified to include support to call methods which return lists, to call external utility methods, use
# if operator with methods and added a not operator.
+from __future__ import annotations
+
import string
+from collections.abc import Iterable
+from functools import lru_cache
+from typing import Any, Callable
class SuperFormatter(string.Formatter):
- """World's simplest Template engine."""
+ """Lightweight formatter with helpers for reports and templates."""
- def format_field(self, value, spec):
+ def format_field(
+ self, value: Any, spec: str
+ ) -> Any: # noqa: D401 - same semantics as base
+ if not spec:
+ return super().format_field(value, spec)
- spec_parts = spec.split(":")
if spec.startswith("repeat"):
- # Example usage, format, count of spec_parts, exampple format
- # object:repeat:template 2 {item.findings:repeat:{{item.id}}, }
-
- template = spec.partition(":")[-1]
- if type(value) is dict:
- value = value.items()
- return "".join([self.format(template, item=item) for item in value])
-
- elif spec.startswith("call:") and hasattr(value, "__call__"):
- # Example usage, format, exampple format
- # methood:call {item.display_name:call:}
- # methood:call:template {item.parents:call:{{item.name}}, }
- result = value()
-
- if type(result) is list:
- template = spec.partition(":")[-1]
- return "".join([self.format(template, item=item) for item in result])
+ return self._format_repeat(value, spec)
- return result
+ if spec.startswith("call:"):
+ return self._format_call(value, spec)
- elif spec.startswith("call:"):
- # Example usage, format, exampple format
- # object:call:method_name {item:call:getFindingCount}
- # object:call:method_name:template {item:call:getNamesOfParents:
- # {{item}}
- # }
+ if spec.startswith("if") or spec.startswith("not"):
+ return self._format_conditional(value, spec)
- method_name = spec_parts[1]
+ return super().format_field(value, spec)
- result = self.call_util_method(method_name, value)
+    def _format_repeat(self, value: Any, spec: str) -> str:
+        """Render the template that follows ``repeat:`` once per item of *value*.
+
+        Args:
+            value: The object to iterate. A ``dict`` is iterated as
+                ``(key, value)`` pairs; strings, bytes and non-iterables
+                produce no output.
+            spec: The format spec; everything after the first ``:`` is the
+                per-item template, formatted with the item bound to ``item``.
+
+        Returns:
+            The concatenation of the rendered template for every item.
+        """
+        _, _, item_template = spec.partition(":")
+        if isinstance(value, dict):
+            entries: Iterable[Any] = value.items()
+        elif isinstance(value, (str, bytes)) or not isinstance(value, Iterable):
+            # Text and scalars are never expanded item by item.
+            entries = []
+        else:
+            entries = value
+        rendered = [self.format(item_template, item=entry) for entry in entries]
+        return "".join(rendered)
- if type(result) is list:
- template = spec.partition(":")[-1]
- template = template.partition(":")[-1]
- return "".join([self.format(template, item=item) for item in result])
-
- return result
-
- elif (spec.startswith("if") or spec.startswith("not")):
- # Example usage, format, exampple format
- # object.bool:if:template {item.inScope:if:Is in scope}
- # object:if:template {item.findings:if:Has Findings}
- # object.method:if:template {item.parents:if:Has Parents}
- #
- # object.bool:not:template {item.inScope:not:Is not in scope}
- # object:not:template {item.findings:not:Has No Findings}
- # object.method:not:template {item.parents:not:Has No Parents}
-
- template = spec.partition(":")[-1]
- if (hasattr(value, "__call__")):
- result = value()
- else:
- result = value
-
- if (spec.startswith("if")):
- return (result and template or "")
- else:
- return (not result and template or "")
+ def _format_call(self, value: Any, spec: str) -> Any:
+ """Evaluate callable values or report utility helpers."""
+ _, _, remainder = spec.partition(":")
+ if callable(value):
+ result = value()
+ template = remainder
else:
- return super(SuperFormatter, self).format_field(value, spec)
-
- def call_util_method(self, method_name, object):
- module_name = "pytm.report_util"
- klass_name = "ReportUtils"
- module = __import__(module_name, fromlist=['ReportUtils'])
- klass = getattr(module, klass_name)
- method = getattr(klass, method_name)
+ method_name, _, template = remainder.partition(":")
+ result = self.call_util_method(method_name, value)
- result = method(object)
+ if isinstance(result, list) and template:
+ return "".join(self.format(template, item=item) for item in result)
return result
+
+    def _format_conditional(self, value: Any, spec: str) -> str:
+        """Return the template after ``if:``/``not:`` depending on *value*.
+
+        Args:
+            value: The value to test. If it is callable it is called with no
+                arguments and the call result is tested instead.
+            spec: The format spec. A spec starting with ``if`` emits the
+                template when the tested value is truthy; any other spec
+                routed here emits it when the tested value is falsy.
+
+        Returns:
+            The text after the first ``:`` of *spec*, or ``""``.
+        """
+        template = spec.partition(":")[2]
+        outcome = bool(value() if callable(value) else value)
+        # "if" keeps the truthiness as is; everything else negates it.
+        wanted = outcome if spec.startswith("if") else not outcome
+        return template if wanted else ""
+
+    def call_util_method(self, method_name: str, obj: Any) -> Any:
+        """Call the :mod:`pytm.report_util` helper *method_name* with *obj*.
+
+        Args:
+            method_name: Name of the ``ReportUtils`` attribute to call.
+            obj: The single argument handed to that helper.
+
+        Returns:
+            Whatever the helper returns.
+        """
+        return self._resolve_report_method(method_name)(obj)
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def _resolve_report_method(method_name: str) -> Callable[[Any], Any]:
+        """Return the ``ReportUtils`` attribute called *method_name*.
+
+        Successful lookups are memoised per *method_name* by ``lru_cache``.
+
+        Raises:
+            AttributeError: If ``ReportUtils`` has no such attribute.
+        """
+        # Imported at call time rather than at module import time
+        # (presumably to avoid an import cycle - TODO confirm).
+        from pytm.report_util import ReportUtils
+
+        return getattr(ReportUtils, method_name)
diff --git a/pytm/threat.py b/pytm/threat.py
new file mode 100644
index 00000000..e972b8ea
--- /dev/null
+++ b/pytm/threat.py
@@ -0,0 +1,331 @@
+"""Threat model - represents possible threats in the system."""
+
+from __future__ import annotations
+
+import ast
+import sys
+from types import CodeType
+from typing import Any, ClassVar, Tuple, List
+from collections.abc import Iterable
+
+import builtins
+
+from pydantic import (
+ BaseModel,
+ Field,
+ ConfigDict,
+ model_validator,
+ PrivateAttr,
+)
+
+
+class _ConditionValidator(ast.NodeVisitor):
+    """Validate threat conditions to ensure they only use safe constructs.
+
+    The visitor walks the AST of a condition expression and raises
+    ``ValueError`` when it meets a node type that is not listed in
+    ``_ALLOWED_NODES``, an attribute whose name starts with ``__``, a call
+    that uses keyword arguments, or a call whose callee is not whitelisted.
+    """
+
+    # Bare function names a condition is allowed to call.
+    SAFE_CALL_NAMES: ClassVar[set[str]] = {"any", "all", "len", "min", "max", "sum"}
+    # Method names that may be called at the end of an attribute chain.
+    ALLOWED_TARGET_METHODS: ClassVar[set[str]] = {
+        "oneOf",
+        "crosses",
+        "enters",
+        "exits",
+        "inside",
+        "checkTLSVersion",
+        "hasDataLeaks",
+    }
+    # Every AST node type a condition may contain; anything else is rejected
+    # by ``visit`` before any node-specific handler runs.
+    _ALLOWED_NODES: ClassVar[tuple[type[ast.AST], ...]] = (
+        ast.Expression,
+        ast.BoolOp,
+        ast.BinOp,
+        ast.UnaryOp,
+        ast.Compare,
+        ast.Name,
+        ast.Load,
+        ast.Constant,
+        ast.Attribute,
+        ast.Call,
+        ast.Subscript,
+        ast.List,
+        ast.Tuple,
+        ast.Set,
+        ast.Dict,
+        ast.ListComp,
+        ast.GeneratorExp,
+        ast.comprehension,
+        ast.IfExp,
+        ast.And,
+        ast.Or,
+        ast.Not,
+        ast.Eq,
+        ast.NotEq,
+        ast.Lt,
+        ast.LtE,
+        ast.Gt,
+        ast.GtE,
+        ast.Is,
+        ast.IsNot,
+        ast.In,
+        ast.NotIn,
+        ast.Add,
+        ast.Sub,
+        ast.Mult,
+        ast.Div,
+        ast.Mod,
+        ast.Pow,
+        ast.USub,
+        ast.UAdd,
+        ast.BitAnd,
+        ast.BitOr,
+        ast.BitXor,
+        ast.FloorDiv,
+        ast.Slice,
+    )
+
+    def __init__(self, allowed_names: set[str]) -> None:
+        """Store *allowed_names* plus ``target`` and the literal constants."""
+        super().__init__()
+        self.allowed_names = allowed_names | {"target", "True", "False", "None"}
+
+    def visit(self, node: ast.AST) -> Any:  # type: ignore[override]
+        """Reject nodes outside ``_ALLOWED_NODES``, then dispatch as usual.
+
+        Raises:
+            ValueError: If *node* is not an instance of an allowed node type.
+        """
+        if not isinstance(node, self._ALLOWED_NODES):
+            raise ValueError(
+                f"Unsupported syntax in threat condition: {type(node).__name__}"
+            )
+        return super().visit(node)
+
+    def visit_Attribute(self, node: ast.Attribute) -> Any:  # noqa: D401
+        """Reject attribute names starting with ``__``, then visit children.
+
+        Raises:
+            ValueError: If the accessed attribute name starts with ``__``.
+        """
+        if isinstance(node.attr, str) and node.attr.startswith("__"):
+            raise ValueError(
+                "Access to dunder attributes is not permitted in threat conditions"
+            )
+        return self.generic_visit(node)
+
+    def visit_Call(self, node: ast.Call) -> Any:  # noqa: D401
+        """Allow only whitelisted callees and positional arguments.
+
+        A callee that is a bare name must be in ``SAFE_CALL_NAMES``; a callee
+        that is an attribute must end in a name from
+        ``ALLOWED_TARGET_METHODS``. Any other callee expression is rejected.
+
+        Raises:
+            ValueError: On keyword arguments or a callee that is not allowed.
+        """
+        if node.keywords:
+            raise ValueError("Keyword arguments are not permitted in threat conditions")
+
+        func = node.func
+        if isinstance(func, ast.Name):
+            if func.id not in self.SAFE_CALL_NAMES:
+                raise ValueError(
+                    f"Call to '{func.id}' is not permitted in threat conditions"
+                )
+        elif isinstance(func, ast.Attribute):
+            # The chain is ordered root-first, so its last item is the method name.
+            chain = self._attribute_chain(func)
+            if chain[-1] not in self.ALLOWED_TARGET_METHODS:
+                raise ValueError(
+                    f"Call to target method '{chain[-1]}' is not permitted"
+                )
+        else:
+            raise ValueError("Unsupported call target in threat condition")
+
+        # Arguments and the callee expression are validated recursively.
+        return self.generic_visit(node)
+
+    def visit_Name(self, node: ast.Name) -> Any:  # noqa: D401
+        """Accept every name; nothing is rejected here.
+
+        NOTE(review): both paths below return ``None``, so ``allowed_names``
+        never causes a rejection in this method. The existing comment says
+        unknown names are deliberately left to fail at evaluation time.
+        """
+        if (
+            isinstance(node.ctx, ast.Load)
+            and node.id not in self.allowed_names
+            and node.id not in self.SAFE_CALL_NAMES
+        ):
+            # Allow names introduced by comprehensions; they will fail at runtime if undefined.
+            return
+        return None
+
+    @staticmethod
+    def _attribute_chain(node: ast.Attribute) -> List[str]:
+        """Return the dotted path of *node* as a list, root name first.
+
+        Raises:
+            ValueError: If an intermediate attribute name starts with ``__``
+                or the chain is not rooted at a plain name.
+        """
+        chain: List[str] = [node.attr]
+        current = node.value
+        # Walk from the outermost attribute towards the root of the chain.
+        while isinstance(current, ast.Attribute):
+            if isinstance(current.attr, str) and current.attr.startswith("__"):
+                raise ValueError(
+                    "Access to dunder attributes is not permitted in threat conditions"
+                )
+            chain.append(current.attr)
+            current = current.value
+        if isinstance(current, ast.Name):
+            chain.append(current.id)
+        else:
+            raise ValueError(
+                "Only attribute access on names is permitted in threat conditions"
+            )
+        # Collected leaf-first; callers want root-first order.
+        chain.reverse()
+        return chain
+
+
+class Threat(BaseModel):
+    """Represents a possible threat.
+
+    The ``condition`` expression is parsed, validated by
+    ``_ConditionValidator`` and compiled once when the model is created;
+    ``apply`` then evaluates the compiled code against a target.
+
+    Attributes:
+        id (str): Threat identifier (SID)
+        description (str): Description of the threat
+        condition (str): A Python expression that should evaluate to a boolean True or False
+        details (str): Detailed information about the threat
+        likelihood (str): Likelihood of the threat occurring
+        severity (str): Severity level of the threat
+        mitigations (str): Possible mitigations for the threat
+        prerequisites (str): Prerequisites for the threat
+        example (str): Example of the threat
+        references (str): References for the threat
+        target (Tuple): Target classes for this threat
+    """
+
+    model_config = ConfigDict(
+        extra="allow", validate_assignment=True, arbitrary_types_allowed=True
+    )
+
+    id: str = Field(description="Threat identifier (SID)")
+    description: str = Field(default="", description="Description of the threat")
+    condition: str = Field(
+        default="True",
+        description="A Python expression that should evaluate to a boolean True or False",
+    )
+    details: str = Field(
+        default="", description="Detailed information about the threat"
+    )
+    likelihood: str = Field(
+        default="", description="Likelihood of the threat occurring"
+    )
+    severity: str = Field(default="", description="Severity level of the threat")
+    mitigations: str = Field(
+        default="", description="Possible mitigations for the threat"
+    )
+    prerequisites: str = Field(default="", description="Prerequisites for the threat")
+    example: str = Field(default="", description="Example of the threat")
+    references: str = Field(default="", description="References for the threat")
+    target: Tuple = Field(default=(), description="Target classes for this threat")
+
+    # Code object for ``condition``; ``None`` when the condition is empty.
+    _compiled_condition: CodeType | None = PrivateAttr(default=None)
+    # Lazily built, class-wide globals used when evaluating conditions.
+    _eval_globals: ClassVar[dict[str, Any] | None] = None
+    # The only builtins visible to an evaluated condition.
+    _SAFE_BUILTINS: ClassVar[dict[str, Any]] = {
+        name: getattr(builtins, name) for name in _ConditionValidator.SAFE_CALL_NAMES
+    }
+
+    @model_validator(mode="before")
+    @classmethod
+    def _normalize_input(cls, data: Any) -> Any:
+        """Map legacy keys and normalise ``target`` before field validation.
+
+        Args:
+            data: Raw constructor input. Anything that is not a ``dict`` is
+                returned untouched.
+
+        Returns:
+            A new dict with ``SID``/``Likelihood Of Attack`` mapped to
+            ``id``/``likelihood`` and ``target`` turned into a tuple in which
+            class names found on the ``pytm`` module are replaced by the
+            class objects.
+        """
+        if not isinstance(data, dict):
+            return data
+
+        # Work on a shallow copy: the keys popped below must not disappear
+        # from a mapping the caller still owns (e.g. a parsed threats file).
+        data = dict(data)
+
+        # Map legacy field names from the threats.json format
+        if "SID" in data:
+            data.setdefault("id", data.pop("SID"))
+        if "Likelihood Of Attack" in data:
+            data.setdefault("likelihood", data.pop("Likelihood Of Attack"))
+
+        # Normalise target to a tuple
+        target = data.get("target", "Element")
+        if isinstance(target, str) or not isinstance(target, Iterable):
+            target = (target,)
+        else:
+            target = tuple(target)
+
+        # Resolve target name strings to actual Python classes
+        pytm_module = sys.modules.get("pytm")
+        resolved = []
+        for name in target:
+            if isinstance(name, type) or not isinstance(name, str):
+                # Classes are already resolved; other non-string entries are
+                # passed through because getattr() only accepts string names.
+                resolved.append(name)
+            else:
+                klass = getattr(pytm_module, name, None)
+                # Unknown names stay as strings; apply() compares those
+                # against the target's class name.
+                resolved.append(klass if klass is not None else name)
+        data["target"] = tuple(resolved)
+
+        return data
+
+    def model_post_init(self, __context: Any) -> None:  # noqa: D401
+        """Validate and compile ``condition`` once the model is built.
+
+        Raises:
+            ValueError: If the condition is not valid Python or uses a
+                construct rejected by ``_ConditionValidator``.
+        """
+        if not self.condition:
+            self._compiled_condition = None
+            return
+
+        try:
+            tree = ast.parse(self.condition, mode="eval")
+            validator = _ConditionValidator(self._allowed_global_names())
+            validator.visit(tree)
+            # A descriptive pseudo-filename makes the code object traceable
+            # to the threat it belongs to.
+            self._compiled_condition = compile(
+                tree, filename=f"<threat {self.id}>", mode="eval"
+            )
+        except ValueError as exc:  # pragma: no cover - defensive, surfaced via tests
+            raise ValueError(f"Invalid condition for threat {self.id}: {exc}") from exc
+        except SyntaxError as exc:  # noqa: D401
+            raise ValueError(
+                f"Invalid syntax in condition for threat {self.id}: {exc}"
+            ) from exc
+
+    def _safeset(self, attr: str, value) -> None:
+        """Set *attr* to *value*, ignoring ``ValueError`` and ``TypeError``."""
+        try:
+            setattr(self, attr, value)
+        except (ValueError, TypeError):
+            # Deliberate best-effort: a rejected value leaves the attribute as is.
+            pass
+
+    def __repr__(self):
+        return (
+            f"<{self.__module__}.{type(self).__name__}({self.id}) at {hex(id(self))}>"
+        )
+
+    def __str__(self):
+        return f"{type(self).__name__}({self.id})"
+
+    @classmethod
+    def _build_eval_globals(cls) -> dict[str, Any]:
+        """Return the cached globals dict used to evaluate conditions.
+
+        The dict is built on first use and stored on the class.
+        """
+        if cls._eval_globals is None:
+            import pytm
+
+            globals_dict: dict[str, Any] = {
+                "__builtins__": cls._SAFE_BUILTINS,
+                "Actor": pytm.Actor,
+                "Asset": pytm.Asset,
+                "Boundary": pytm.Boundary,
+                "Dataflow": pytm.Dataflow,
+                "Datastore": pytm.Datastore,
+                "DatastoreType": pytm.DatastoreType,
+                "Element": pytm.Element,
+                "ExternalEntity": pytm.ExternalEntity,
+                "Lambda": pytm.Lambda,
+                "Process": pytm.Process,
+                "Server": pytm.Server,
+                "SetOfProcesses": pytm.SetOfProcesses,
+                "TM": pytm.TM,
+                "TLSVersion": pytm.TLSVersion,
+                "Classification": pytm.Classification,
+                "Action": pytm.Action,
+                "Lifetime": pytm.Lifetime,
+            }
+
+            # Expose safe builtins as globals as well for convenience
+            globals_dict.update(cls._SAFE_BUILTINS)
+            cls._eval_globals = globals_dict
+
+        return cls._eval_globals
+
+    @classmethod
+    def _allowed_global_names(cls) -> set[str]:
+        """Return every evaluation-global name except ``__builtins__``."""
+        globals_dict = cls._build_eval_globals()
+        return {key for key in globals_dict.keys() if key != "__builtins__"}
+
+    def apply(self, target):
+        """Apply the threat condition to a target.
+
+        Args:
+            target: The object the condition is evaluated against; it is
+                visible to the condition as ``target``.
+
+        Returns:
+            bool: ``True`` when *target* matches one of ``self.target`` (or
+            ``self.target`` is empty) and the condition evaluates truthy.
+            ``False`` when the type does not match, when there is no compiled
+            condition, or when evaluating the condition raises.
+        """
+        # Check if target matches any of the target types
+        if self.target:
+            target_matches = False
+            for target_type in self.target:
+                if isinstance(target_type, str):
+                    # String comparison for backward compatibility
+                    if target_type == type(target).__name__:
+                        target_matches = True
+                        break
+                elif isinstance(target_type, type):
+                    # Class type comparison
+                    if isinstance(target, target_type):
+                        target_matches = True
+                        break
+
+            if not target_matches:
+                return False
+
+        if self._compiled_condition is None:
+            return False
+
+        try:
+            # Copy the shared globals so an evaluation cannot alter the cache.
+            globals_dict = dict(self._build_eval_globals())
+            locals_dict = {"target": target}
+            # NOTE(security): eval() runs only code that passed
+            # _ConditionValidator at construction time, with builtins
+            # restricted to _SAFE_BUILTINS.
+            return bool(eval(self._compiled_condition, globals_dict, locals_dict))
+        except Exception:
+            # Deliberate best-effort: a condition that fails on this target
+            # (e.g. a missing attribute) counts as "does not apply".
+            return False
diff --git a/pytm/tm.py b/pytm/tm.py
new file mode 100644
index 00000000..fb75e0ca
--- /dev/null
+++ b/pytm/tm.py
@@ -0,0 +1,813 @@
+"""TM (Threat Model) - the main container for all threat model elements."""
+
+from __future__ import annotations
+
+import copy
+import json
+import logging
+import os
+import random
+import re
+import sys
+from collections import defaultdict, Counter
+from dataclasses import dataclass, field
+from datetime import datetime
+from itertools import combinations
+from textwrap import indent
+from typing import ClassVar, Dict, Iterable, List, TYPE_CHECKING
+from html import escape as html_escape
+
+from pydantic import BaseModel, Field, ConfigDict, field_validator
+
+from .enums import Action
+from .base import Assumption
+from .template_engine import SuperFormatter
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from .element import Element
+ from .asset import Asset
+ from .actor import Actor
+ from .dataflow import Dataflow
+ from .boundary import Boundary
+ from .data import Data
+ from .threat import Threat
+ from .finding import Finding
+
+
+class UIError(Exception):
+ """Exception for UI-related errors."""
+
+ # ``error`` holds the underlying exception and ``context`` a human-readable
+ # description of what was being attempted; ``TM.process`` prints both.
+ def __init__(self, e, context):
+ self.error = e
+ self.context = context
+
+
+@dataclass
+class TMState:
+ """Mutable registry for TM-owned collections."""
+
+ # Every field defaults to its own empty list (``default_factory``), so a new
+ # ``TMState()`` never shares list objects with a previous one.
+ flows: List["Dataflow"] = field(default_factory=list)
+ elements: List["Element"] = field(default_factory=list)
+ actors: List["Actor"] = field(default_factory=list)
+ assets: List["Asset"] = field(default_factory=list)
+ threats: List["Threat"] = field(default_factory=list)
+ boundaries: List["Boundary"] = field(default_factory=list)
+ data: List["Data"] = field(default_factory=list)
+ # Threat ids (strings) to skip when findings are resolved.
+ threats_excluded: List[str] = field(default_factory=list)
+
+
+class _StateAttribute:
+ """Descriptor that proxies attribute access to the shared TM state."""
+
+ # ``field_name`` is the attribute of the owner's ``_state`` object that this
+ # descriptor reads and writes.
+ def __init__(self, field_name: str):
+ self.field_name = field_name
+ self.owner: type["TM"] | None = None
+
+ # Called by Python when the descriptor is bound to a class attribute: remember
+ # the owning class and, if the class offers ``_register_state_attribute``,
+ # register this descriptor under the attribute name.
+ def __set_name__(self, owner, name):
+ self.owner = owner
+ register = getattr(owner, "_register_state_attribute", None)
+ if callable(register):
+ register(name, self)
+
+ # Reads always go to ``owner._state``; the instance argument is not used, so
+ # class-level and instance-level access return the same object.
+ def __get__(self, instance, owner=None):
+ owner = owner or self.owner
+ if owner is None:
+ raise AttributeError("State attribute descriptor is unbound")
+ return getattr(owner._state, self.field_name)
+
+ # Writes replace the field on ``owner._state``. With ``instance`` None the
+ # class recorded in ``__set_name__`` is used; otherwise the instance's type.
+ def __set__(self, instance, value):
+ owner = self.owner if instance is None else type(instance)
+ if owner is None:
+ raise AttributeError("State attribute descriptor is unbound")
+ setattr(owner._state, self.field_name, value)
+
+
+class TMModelMetaclass(type(BaseModel)):
+ """Metaclass that keeps TM state descriptors intact on class assignment."""
+
+ # A plain ``TM._flows = [...]`` would replace the descriptor object stored on
+ # the class. For names registered in ``_state_attributes`` the assignment is
+ # routed through the descriptor's ``__set__`` instead; every other name falls
+ # through to the base metaclass.
+ def __setattr__(cls, name, value):
+ state_attrs = getattr(cls, "_state_attributes", None)
+ if state_attrs and name in state_attrs:
+ descriptor = state_attrs[name]
+ descriptor.__set__(None, value)
+ return
+ super().__setattr__(name, value)
+
+
+class TM(BaseModel, metaclass=TMModelMetaclass):
+ """Describes the threat model administratively, and holds all details during a run."""
+
+ # Pydantic configuration: unknown keyword arguments are kept as extra
+ # attributes, assignments to fields are validated, and field types that
+ # pydantic has no schema for are allowed.
+ model_config = ConfigDict(
+ extra="allow", validate_assignment=True, arbitrary_types_allowed=True
+ )
+
+ # Class-level state shared by the class and all of its instances.
+ _state: ClassVar[TMState] = TMState()
+ # Attribute name -> descriptor, filled in by ``_register_state_attribute``.
+ _state_attributes: ClassVar[Dict[str, _StateAttribute]] = {}
+
+ @classmethod
+ def _register_state_attribute(cls, name: str, descriptor: _StateAttribute) -> None:
+ # Record the descriptor under its attribute name in ``_state_attributes``.
+ cls._state_attributes[name] = descriptor
+
+ # Each name below proxies the ``TMState`` field given as the constructor
+ # argument; note that ``_threatsExcluded`` maps to ``threats_excluded``.
+ _flows: ClassVar[_StateAttribute] = _StateAttribute("flows")
+ _elements: ClassVar[_StateAttribute] = _StateAttribute("elements")
+ _actors: ClassVar[_StateAttribute] = _StateAttribute("actors")
+ _assets: ClassVar[_StateAttribute] = _StateAttribute("assets")
+ _threats: ClassVar[_StateAttribute] = _StateAttribute("threats")
+ _boundaries: ClassVar[_StateAttribute] = _StateAttribute("boundaries")
+ _data: ClassVar[_StateAttribute] = _StateAttribute("data")
+ _threatsExcluded: ClassVar[_StateAttribute] = _StateAttribute("threats_excluded")
+
+ @classmethod
+ def _get_state(cls) -> TMState:
+ """Return the mutable shared state for this TM class."""
+ # Returns the live object, not a copy: mutations are visible everywhere.
+ return cls._state
+
+ # Declared model fields. Because ``validate_assignment=True`` is set in
+ # ``model_config``, later assignments to these names are validated too.
+ name: str = Field(description="Model name")
+ description: str = Field(description="Model description")
+ # Defaults to ``threatlib/threats.json`` next to this module.
+ threatsFile: str = Field(
+ default_factory=lambda: os.path.dirname(__file__) + "/threatlib/threats.json",
+ description="JSON file with custom threats",
+ )
+ isOrdered: bool = Field(
+ default=False, description="Automatically order all Dataflows"
+ )
+ mergeResponses: bool = Field(
+ default=False, description="Merge response edges in DFDs"
+ )
+ ignoreUnused: bool = Field(
+ default=False, description="Ignore elements not used in any Dataflow"
+ )
+ findings: List["Finding"] = Field(
+ default_factory=list, description="Threats found for elements of this model"
+ )
+ excluded_findings: List["Finding"] = Field(
+ default_factory=list,
+ description="Threats found for elements of this model, that were excluded on a per-element basis, using the Assumptions class",
+ )
+ onDuplicates: Action = Field(
+ default=Action.NO_ACTION,
+ description="How to handle duplicate Dataflow with same properties, except name and notes",
+ )
+ assumptions: List[Assumption] = Field(
+ default_factory=list, description="A list of assumptions about the design/model"
+ )
+ # ``exclude=True`` keeps this flag out of pydantic's serialized output.
+ colormap: bool = Field(default=False, exclude=True)
+
+ @field_validator("assumptions", mode="before")
+ @classmethod
+ def _normalize_assumptions(cls, value):
+ """Allow string assumptions to be converted into Assumption models."""
+ # Runs before pydantic's own validation (mode="before"). Always returns a
+ # list; raises TypeError for unsupported inputs.
+ if value is None or value == []:
+ return []
+
+ # Per-item normalisation: Assumption instances and dicts pass through
+ # unchanged, a bare string becomes ``{"name": <string>}``.
+ def convert(item):
+ if isinstance(item, Assumption):
+ return item
+ if isinstance(item, dict):
+ return item
+ if isinstance(item, str):
+ return {"name": item}
+ raise TypeError(
+ "assumptions must be strings, dicts, or Assumption instances"
+ )
+
+ # A single value (not wrapped in a list) is accepted and wrapped here.
+ if isinstance(value, (str, dict, Assumption)):
+ return [convert(value)]
+
+ # Any other iterable is converted item by item; bytes are rejected below.
+ if isinstance(value, Iterable) and not isinstance(value, (bytes, str)):
+ return [convert(item) for item in value]
+
+ raise TypeError(
+ "assumptions must be provided as an iterable of supported types"
+ )
+
+ def __init__(self, name: str, description: str = "", **data):
+ """Initialize the threat model."""
+ # ``name`` and ``description`` override any same-named keys in ``data``.
+ data.update({"name": name, "description": description})
+
+ # ``_initializing_tm`` is written with ``object.__setattr__`` so it bypasses
+ # this class's ``__setattr__``, which consults the flag to decide whether a
+ # ``threatsFile`` assignment should reload the threats.
+ object.__setattr__(self, "_initializing_tm", True)
+ super().__init__(**data)
+ object.__setattr__(self, "_initializing_tm", False)
+
+ self._sf = SuperFormatter()
+ # NOTE: seeds the process-wide ``random`` generator, not a private one.
+ random.seed(0)
+
+ # Load the threat library once; a UIError is re-raised unchanged and the
+ # temporary flag is removed whether or not loading succeeded.
+ try:
+ self._init_threats()
+ except UIError as e:
+ raise e
+ finally:
+ if hasattr(self, "_initializing_tm"):
+ object.__delattr__(self, "_initializing_tm")
+
+ def __setattr__(self, name, value):
+ # Assigning ``threatsFile`` outside of ``__init__`` reloads the threat
+ # library from the new path; every other attribute is set normally.
+ if name == "threatsFile" and not getattr(self, "_initializing_tm", False):
+ current_value = getattr(self, "threatsFile", None)
+ # Same path as before: set it and skip the reload.
+ if current_value == value:
+ return super().__setattr__(name, value)
+
+ super().__setattr__(name, value)
+ try:
+ self._init_threats()
+ except UIError as e:
+ # Loading from the new path failed: restore the previous path (with the
+ # ``_initializing_tm`` flag raised around the restore) ...
+ object.__setattr__(self, "_initializing_tm", True)
+ try:
+ super().__setattr__(name, current_value)
+ finally:
+ object.__setattr__(self, "_initializing_tm", False)
+
+ # ... then try to reload the previous library; if that fails as well the
+ # shared threat list is left empty. The original error is re-raised.
+ if current_value is not None:
+ try:
+ self._init_threats()
+ except UIError:
+ TM._get_state().threats.clear()
+ raise e
+ return
+
+ super().__setattr__(name, value)
+
+ @classmethod
+ def reset(cls):
+ """Reset all class variables."""
+ # Replaces the shared state object with a fresh, empty ``TMState``; nothing
+ # else on the class is touched.
+ cls._state = TMState()
+
+ def _init_threats(self):
+ """Initialize threats from file."""
+ # The shared threat list is emptied in place before reloading, so a failed
+ # load leaves it empty rather than holding the previous threats.
+ TM._get_state().threats.clear()
+ self._add_threats()
+
+ def _add_threats(self):
+ """Add threats from the threats file."""
+ # Raises UIError when the file is missing, unreadable or a directory.
+ # Other failures (for example malformed JSON) are not caught here and
+ # propagate as raised.
+ try:
+ with open(self.threatsFile, "r", encoding="utf8") as threat_file:
+ threats_json = json.load(threat_file)
+ except (FileNotFoundError, PermissionError, IsADirectoryError) as e:
+ raise UIError(
+ e, f"while trying to open the threat file ({self.threatsFile})."
+ )
+
+ # Imported here rather than at module level; presumably to avoid a
+ # circular import -- TODO confirm.
+ from .threat import Threat
+
+ # Entries that contain a "DEPRECATED" key are skipped.
+ active_threats = (
+ threat for threat in threats_json if "DEPRECATED" not in threat
+ )
+ # Each remaining entry is expanded as keyword arguments to ``Threat`` and
+ # appended to the shared threat list.
+ for threat in active_threats:
+ TM._threats.append(Threat(**threat))
+
+ def check(self):
+ """Check the threat model for consistency and completeness."""
+ # Returns True when every element's own ``check()`` is truthy, else False.
+ # Raises ValueError for a missing description or duplicate overrides.
+ # NOTE(review): ``description`` is declared as ``str`` and ``__init__``
+ # defaults it to "", so this guard only fires if it is somehow None.
+ if self.description is None:
+ raise ValueError(
+ """Every threat model should have at least
+a brief description of the system being modeled."""
+ )
+
+ from . import pytm as pytm_module
+
+ # The shared flow list is replaced by the result of the ``_sort`` and
+ # ``_match_responses`` helpers from the ``pytm`` module.
+ state = TM._get_state()
+ state.flows = pytm_module._match_responses(
+ pytm_module._sort(state.flows, getattr(self, "isOrdered", False))
+ )
+
+ self._check_duplicates(state.flows)
+
+ pytm_module._apply_defaults(state.flows, state.data)
+
+ # An element may override a given threat id at most once: look at the most
+ # frequent ``threat_id`` among its overrides and reject a count above one.
+ # Elements without overrides are skipped.
+ for element in state.elements:
+ top = Counter(
+ getattr(f, "threat_id", None) for f in getattr(element, "overrides", [])
+ ).most_common(1)
+ if not top:
+ continue
+ threat_id, count = top[0]
+ if count != 1:
+ raise ValueError(
+ f"Finding {threat_id} have more than one override in {element}"
+ )
+
+ # With ``ignoreUnused`` the shared element and boundary lists are replaced
+ # by what ``_get_elements_and_boundaries`` derives from the flows.
+ if getattr(self, "ignoreUnused", False):
+ elements, boundaries = pytm_module._get_elements_and_boundaries(state.flows)
+ state.elements = elements
+ state.boundaries = boundaries
+
+ # Every element is checked (no early exit) so all of them get to run.
+ result = True
+ for element in state.elements:
+ if not element.check():
+ result = False
+
+ if getattr(self, "ignoreUnused", False):
+ state.elements = pytm_module._sort_elem(state.elements)
+
+ return result
+
+ def resolve(self):
+ """Resolve threats and generate findings."""
+ # Evaluates every loaded threat against every in-scope element and stores
+ # the outcome on ``self.findings`` / ``self.excluded_findings`` and on each
+ # element's ``findings`` attribute. Returns nothing.
+ from .finding import Finding
+ from collections import defaultdict
+
+ finding_count = 0
+ excluded_finding_count = 0
+ findings = []
+ excluded_findings = []
+
+ # Get global assumptions with exclusions
+ global_assumptions = [a for a in self.assumptions if len(a.exclude) > 0]
+ # element -> findings raised for that element
+ elements = defaultdict(list)
+
+ for e in TM._elements:
+ if not getattr(e, "inScope", True):
+ # Out-of-scope elements carry no findings. Assign a fresh empty list:
+ # handing over the shared ``findings`` accumulator would attach other
+ # elements' findings to this one.
+ e.findings = []
+ continue
+
+ override_ids = set(f.threat_id for f in getattr(e, "overrides", []))
+
+ # Filter out overrides from source and sink for dataflows
+ # (elements without ``source``/``sink`` raise AttributeError and keep
+ # their override ids unchanged)
+ try:
+ source_overrides = set(
+ f.threat_id for f in getattr(e.source, "overrides", [])
+ )
+ sink_overrides = set(
+ f.threat_id for f in getattr(e.sink, "overrides", [])
+ )
+ override_ids -= source_overrides | sink_overrides
+ except AttributeError:
+ pass
+
+ for t in TM._threats:
+ # A threat is considered when its condition applies or when the element
+ # explicitly overrides it.
+ if not t.apply(e) and t.id not in override_ids:
+ continue
+
+ if t.id in TM._threatsExcluded:
+ continue
+
+ # The first assumption (element-level first, then global) that excludes
+ # this threat turns it into an excluded finding instead of a finding.
+ _continue = False
+ element_assumptions = getattr(e, "assumptions", [])
+ for assumption in element_assumptions + global_assumptions:
+ if hasattr(assumption, "exclude") and t.id in assumption.exclude:
+ excluded_finding_count += 1
+ f = Finding(
+ e,
+ id=str(excluded_finding_count),
+ threat=t,
+ assumption=assumption,
+ )
+ excluded_findings.append(f)
+ _continue = True
+ break
+ if _continue:
+ continue
+
+ # Finding ids are 1-based running counters rendered as strings.
+ finding_count += 1
+ f = Finding(e, id=str(finding_count), threat=t)
+ findings.append(f)
+ elements[e].append(f)
+
+ # Set severity on element
+ if hasattr(e, "_set_severity"):
+ e._set_severity(getattr(f, "severity", 0))
+
+ self.findings = findings
+ self.excluded_findings = excluded_findings
+
+ # Only elements that received at least one finding are updated here.
+ for e, findings in elements.items():
+ e.findings = findings
+
+ def process(self):
+ """Entry point mirroring the legacy CLI workflow."""
+ # Runs ``_process``; a UIError is reported on stderr (context, then the
+ # underlying error) and converted into ``SystemExit(127)``.
+ try:
+ self._process()
+ except UIError as e: # pragma: no cover - mirrors historical behaviour
+ message = "Failed to execute\n" f" {e.context}\n" f" {e.error}\n"
+ sys.stderr.write(message)
+ raise SystemExit(127) from e
+
+ def _process(self):
+ """Execute the CLI workflow (check, resolve, render outputs)."""
+ # Raises UIError when the JSON result file cannot be written; errors from
+ # the individual renderers propagate as raised.
+ from . import pytm as pytm_module
+
+ self.check()
+
+ # Parsed command-line options; each one is read with ``getattr`` and a
+ # default so a missing attribute behaves like an unset option.
+ result = pytm_module.get_args()
+
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+ if getattr(result, "debug", False):
+ logger.setLevel(logging.DEBUG)
+
+ # ``exclude`` may be one string, a list, or another value; every token is
+ # split on commas and blank ids are dropped.
+ exclude_raw = getattr(result, "exclude", None)
+ if exclude_raw:
+ if isinstance(exclude_raw, str):
+ tokens = [exclude_raw]
+ elif isinstance(exclude_raw, list):
+ tokens = [str(item) for item in exclude_raw]
+ else:
+ tokens = [str(exclude_raw)]
+
+ exclusions: list[str] = []
+ for token in tokens:
+ exclusions.extend(
+ sid.strip() for sid in token.split(",") if sid and sid.strip()
+ )
+ TM._threatsExcluded = exclusions
+
+ if getattr(result, "seq", False):
+ print(self.seq())
+
+ if getattr(result, "dfd", False):
+ # Findings are resolved before drawing only when a colormap is requested.
+ if getattr(result, "colormap", False):
+ self.resolve()
+ levels = set(getattr(result, "levels", []) or [])
+ print(self.dfd(colormap=getattr(result, "colormap", False), levels=levels))
+
+ needs_resolution = any(
+ getattr(result, attr, None) for attr in ("report", "json", "stale_days")
+ )
+
+ if needs_resolution:
+ self.resolve()
+
+ if getattr(result, "json", None):
+ try:
+ with open(result.json, "w", encoding="utf8") as f:
+ json.dump(self, f, default=pytm_module.to_serializable)
+ # FileNotFoundError covers a missing parent directory, which is the usual
+ # way opening a file for writing fails; without it the user would get a raw
+ # traceback instead of the UIError report.
+ except (
+ FileNotFoundError,
+ FileExistsError,
+ PermissionError,
+ IsADirectoryError,
+ ) as exc:
+ raise UIError(
+ exc, f"while trying to write to the result file ({result.json})"
+ ) from exc
+
+ if getattr(result, "report", None):
+ print(self.report(result.report))
+
+ # ``describe`` is a whitespace-separated list of class names.
+ describe_targets = getattr(result, "describe", None)
+ if describe_targets:
+ names = describe_targets.split()
+ pytm_module._describe_classes(names)
+
+ if getattr(result, "list_elements", False):
+ pytm_module._list_elements()
+
+ if getattr(result, "list", False):
+ for threat in TM._threats:
+ print(
+ f"{getattr(threat, 'id', '')} - {getattr(threat, 'description', '')}"
+ )
+
+ # ``is not None`` so that an explicit 0 still runs the staleness report.
+ if getattr(result, "stale_days", None) is not None:
+ print(self._stale(result.stale_days))
+
+ def _stale(self, days: int) -> str:
+ """Report source files whose age diverges from the model script."""
+ # Prints one line per source file whose modification time differs from the
+ # model script's by at least ``days`` days. Returns "" normally, or
+ # "[ERROR]" when the model script itself cannot be stat'ed.
+ base_path = os.path.dirname(sys.argv[0])
+ try:
+ # Join the directory with the script's basename. Joining it with the full
+ # ``sys.argv[0]`` repeats the directory part for relative invocations such
+ # as ``models/tm.py`` (-> ``models/models/tm.py``), which then cannot be
+ # stat'ed.
+ tm_path = os.path.join(base_path, os.path.basename(sys.argv[0]))
+ tm_mtime = datetime.fromtimestamp(os.stat(tm_path).st_mtime)
+ except OSError as err:
+ sys.stderr.write(f"{sys.argv[0]} - {err}\n")
+ sys.stderr.flush()
+ return "[ERROR]"
+
+ print(f"Checking for code {days} days older than this model.")
+
+ for element in TM._elements:
+ source_files = getattr(element, "sourceFiles", [])
+ for src in source_files:
+ # Source paths are resolved relative to the model script's directory;
+ # files that cannot be stat'ed are reported on stderr and skipped.
+ try:
+ src_path = os.path.join(base_path, src)
+ src_mtime = datetime.fromtimestamp(os.stat(src_path).st_mtime)
+ except OSError as err:
+ sys.stderr.write(f"{sys.argv[0]} - {err}\n")
+ sys.stderr.flush()
+ continue
+
+ # Positive age: the source file was modified after the model script.
+ age = (src_mtime - tm_mtime).days
+ if age >= days:
+ print(f"This model is {age} days older than {src_path}.")
+ elif age <= -days:
+ print(
+ f"Model script {sys.argv[0]} is only {-age} days newer than source code file {src_path}"
+ )
+
+ return ""
+
+ def _dfd_template(self):
+ """Template for DFD generation."""
+ # The result is a ``str.format`` template: ``{edges}`` is the only
+ # placeholder, and the doubled braces render as literal ``{`` / ``}``.
+ return (
+ "digraph tm {{\n"
+ " graph [\n"
+ " fontname = Arial;\n"
+ " fontsize = 14;\n"
+ " ]\n"
+ " node [\n"
+ " fontname = Arial;\n"
+ " fontsize = 14;\n"
+ " rankdir = lr;\n"
+ " ]\n"
+ " edge [\n"
+ " shape = none;\n"
+ " arrowtail = onormal;\n"
+ " fontname = Arial;\n"
+ " fontsize = 12;\n"
+ " ]\n"
+ ' labelloc = "t";\n'
+ " fontsize = 20;\n"
+ " nodesep = 1;\n"
+ "\n"
+ "{edges}\n"
+ "\n"
+ "}}"
+ )
+
+ def dfd(self, **kwargs):
+ """Generate Data Flow Diagram."""
+ # Returns the filled-in ``_dfd_template`` as a string. ``kwargs`` are passed
+ # on unchanged to each boundary's / element's own ``dfd`` method, apart from
+ # the ``levels`` and ``mergeResponses`` handling below.
+ from collections import defaultdict
+ from .boundary import Boundary
+
+ # Normalise ``levels`` to a set; a single value or a string is wrapped first.
+ if "levels" in kwargs:
+ levels = kwargs["levels"]
+ if not hasattr(levels, "__iter__") or isinstance(levels, str):
+ kwargs["levels"] = [levels]
+ kwargs["levels"] = set(kwargs["levels"])
+
+ edges = []
+ # Since boundaries can be nested sort them by level and start from top
+ # ``parents`` is the set of boundaries that contain another boundary.
+ parents = set(b.inBoundary for b in TM._boundaries if b.inBoundary)
+
+ # Collect boundary levels
+ # Level 0 holds boundaries that contain no other boundary; what each one's
+ # ``parents()`` yields is placed at levels 1, 2, ... in iteration order.
+ boundary_levels = defaultdict(set)
+ max_level = 0
+ for b in TM._boundaries:
+ if b in parents:
+ continue
+ boundary_levels[0].add(b)
+ for i, p in enumerate(getattr(b, "parents", lambda: [])(), 1):
+ boundary_levels[i].add(p)
+ if i > max_level:
+ max_level = i
+
+ # Draw boundaries from highest level to lowest
+ # (sorted by name within a level so the output order is deterministic)
+ for i in range(max_level, -1, -1):
+ for b in sorted(boundary_levels[i], key=lambda b: b.name):
+ edges.append(b.dfd(**kwargs))
+
+ # Handle response merging
+ # Response flows are flagged ``is_drawn``; the loop below skips anything
+ # already flagged.
+ if getattr(self, "mergeResponses", False):
+ for e in TM._flows:
+ if getattr(e, "response", None) is not None:
+ e.response.is_drawn = True
+ kwargs["mergeResponses"] = getattr(self, "mergeResponses", False)
+
+ # Draw elements that are not boundaries and not inside boundaries
+ for e in TM._elements:
+ if (
+ not getattr(e, "is_drawn", False)
+ and not isinstance(e, Boundary)
+ and getattr(e, "inBoundary", None) is None
+ ):
+ edges.append(e.dfd(**kwargs))
+
+ # ``filter(len, ...)`` drops empty fragments before joining.
+ return self._dfd_template().format(
+ edges=indent("\n".join(filter(len, edges)), " ").rstrip("\n")
+ )
+
+ def _seq_template(self):
+ """Template for sequence diagram generation."""
+ # A ``str.format`` template with two placeholders, ``participants`` and
+ # ``messages``, wrapped in ``@startuml`` / ``@enduml`` markers.
+ return """@startuml
+{participants}
+
+{messages}
+@enduml"""
+
+ def seq(self):
+ """Generate sequence diagram."""
+ # Returns the filled-in ``_seq_template`` as a string.
+ from .actor import Actor
+ from .datastore import Datastore
+ from .boundary import Boundary
+ from .dataflow import Dataflow
+
+ # One participant line per element: Actor -> ``actor``, Datastore ->
+ # ``database``, anything else except Dataflow and Boundary -> ``entity``.
+ # The label comes from ``display_name()`` when the element has that method,
+ # otherwise from ``name``.
+ participants = []
+ for e in TM._elements:
+ if isinstance(e, Actor):
+ participants.append(
+ 'actor {0} as "{1}"'.format(
+ e._uniq_name(), getattr(e, "display_name", lambda: e.name)()
+ )
+ )
+ elif isinstance(e, Datastore):
+ participants.append(
+ 'database {0} as "{1}"'.format(
+ e._uniq_name(), getattr(e, "display_name", lambda: e.name)()
+ )
+ )
+ elif not isinstance(e, (Dataflow, Boundary)):
+ participants.append(
+ 'entity {0} as "{1}"'.format(
+ e._uniq_name(), getattr(e, "display_name", lambda: e.name)()
+ )
+ )
+
+ # One message line per flow, in the order of the shared flow list.
+ messages = []
+ for e in TM._flows:
+ message = "{0} -> {1}: {2}".format(
+ e.source._uniq_name(),
+ e.sink._uniq_name(),
+ getattr(e, "display_name", lambda: e.name)(),
+ )
+ # A non-empty ``note`` is appended as a ``note left`` block.
+ note = ""
+ if getattr(e, "note", "") != "":
+ note = "\nnote left\n{}\nend note".format(e.note)
+ messages.append("{}{}".format(message, note))
+
+ return self._seq_template().format(
+ participants="\n".join(participants), messages="\n".join(messages)
+ )
+
+ def report(self, template_path):
+ """Generate report from template."""
+
+ try:
+ with open(template_path) as file:
+ template = file.read()
+ except (FileNotFoundError, PermissionError, IsADirectoryError) as e:
+ from .pytm import UIError
+
+ raise UIError(
+ e, f"while trying to open the report template file ({template_path})."
+ )
+
+ def _clone(obj):
+ copy_method = getattr(obj, "model_copy", None)
+ if callable(copy_method):
+ return copy_method(deep=True)
+ return copy.deepcopy(obj)
+
+ def encode_threat_data(obj):
+ """Encode threat data for HTML output."""
+ encoded_threat_data = []
+ attrs = [
+ "description",
+ "details",
+ "severity",
+ "mitigations",
+ "example",
+ "id",
+ "threat_id",
+ "references",
+ "condition",
+ "cvss",
+ "response",
+ ]
+
+ items = obj if isinstance(obj, list) else [obj]
+
+ def _escape_markdown(text: str) -> str:
+ return re.sub(r"(? sqlite3.Connection:
- db_path = tmp_path / "sqldump" / "test.db"
- return sqlite3.connect(db_path)
-
-
-def test_sql_dump_creates_serialized_columns(sample_tm, tmp_path, monkeypatch):
- monkeypatch.chdir(tmp_path)
-
- sample_tm.sqlDump("test.db")
-
- with _open_connection(tmp_path) as conn:
- column_names = {
- column_info[1].lower()
- for column_info in conn.execute("PRAGMA table_info(Boundary)")
- }
-
- assert {"name", "inscope", "inboundary"}.issubset(column_names)
-
-
-def test_sql_dump_persists_element_and_finding_data(sample_tm, tmp_path, monkeypatch):
- monkeypatch.chdir(tmp_path)
-
- sample_tm.sqlDump("test.db")
-
- with _open_connection(tmp_path) as conn:
- boundary_rows = conn.execute(
- "SELECT name, inBoundary FROM Boundary ORDER BY id"
- ).fetchall()
- server_rows = conn.execute(
- "SELECT name, inBoundary FROM Server ORDER BY id"
- ).fetchall()
- finding_rows = conn.execute(
- "SELECT threat_id FROM Finding ORDER BY id"
- ).fetchall()
-
- assert ("Internet", None) in boundary_rows
- assert ("Server/DB", "Internet") in boundary_rows
- assert ("Web Server", "Server/DB") in server_rows
- assert [row[0] for row in finding_rows] == ["SRV001"]
\ No newline at end of file
diff --git a/tm.py b/tm.py
index 9f0cb464..cd6fdc4c 100755
--- a/tm.py
+++ b/tm.py
@@ -134,4 +134,3 @@
if __name__ == "__main__":
tm.process()
-