Skip to content
Merged
2 changes: 2 additions & 0 deletions fmpose3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
HRNetConfig,
InferenceConfig,
ModelConfig,
SupportedModel,
PipelineConfig,
)

Expand Down Expand Up @@ -57,6 +58,7 @@
"HRNetConfig",
"InferenceConfig",
"ModelConfig",
"SupportedModel",
"PipelineConfig",
# Aggregation methods
"average_aggregation",
Expand Down
2 changes: 2 additions & 0 deletions fmpose3d/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .config import (
PipelineConfig,
ModelConfig,
SupportedModel,
FMPose3DConfig,
HRNetConfig,
Pose2DConfig,
Expand Down Expand Up @@ -48,6 +49,7 @@
"HRNetConfig",
"Pose2DConfig",
"ModelConfig",
"SupportedModel",
"DatasetConfig",
"TrainingConfig",
"InferenceConfig",
Expand Down
21 changes: 19 additions & 2 deletions fmpose3d/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,28 @@

import math
from dataclasses import dataclass, field, fields, asdict
from enum import Enum
from typing import Dict, List

# ---------------------------------------------------------------------------
# Dataclass configuration groups
# ---------------------------------------------------------------------------


class SupportedModel(str, Enum):
"""Supported FMPose3D pose-estimation model types."""
FMPOSE3D_HUMANS = "fmpose3d_humans"
FMPOSE3D_ANIMALS = "fmpose3d_animals"

@classmethod
def _missing_(cls, value: str) -> "SupportedModel":
valid = ", ".join(repr(m.value) for m in cls)
raise ValueError(
f"{value!r} is not a valid {cls.__name__}. "
f"Valid values are: {valid}"
)


@dataclass
class ModelConfig:
"""Model architecture configuration."""
Expand Down Expand Up @@ -51,7 +66,7 @@ class ModelConfig:

@dataclass
class FMPose3DConfig(ModelConfig):
model_type: str = "fmpose3d_humans"
model_type: SupportedModel = SupportedModel.FMPOSE3D_HUMANS
model: str = ""
layers: int = 5
channel: int = 512
Expand All @@ -67,6 +82,8 @@ class FMPose3DConfig(ModelConfig):
frames: int = 1

def __post_init__(self):
if not isinstance(self.model_type, SupportedModel):
self.model_type = SupportedModel(self.model_type)
defaults = _FMPOSE3D_DEFAULTS.get(self.model_type)
if defaults is None:
supported = ", ".join(sorted(_FMPOSE3D_DEFAULTS))
Expand Down Expand Up @@ -321,7 +338,7 @@ def _pick(dc_class, src: dict):

kwargs = {}
for group_name, dc_class in _SUB_CONFIG_CLASSES.items():
if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d_humans') in _FMPOSE3D_DEFAULTS:
if group_name == "model_cfg" and raw.get("model_type", "fmpose3d_humans") in _FMPOSE3D_DEFAULTS:
dc_class = FMPose3DConfig
elif group_name == "pose2d_cfg":
p2d = raw.get("pose2d_model", "hrnet")
Expand Down
Loading
Loading