diff --git a/hyperglass/models/directive.py b/hyperglass/models/directive.py index b5f0049..d44e22d 100644 --- a/hyperglass/models/directive.py +++ b/hyperglass/models/directive.py @@ -315,10 +315,10 @@ class NativeDirective(Directive): DirectiveT = t.Union[NativeDirective, Directive] -class Directives(HyperglassMultiModel[DirectiveT]): +class Directives(HyperglassMultiModel[Directive]): """Collection of directives.""" - def __init__(self, *items: t.Dict[str, t.Any]) -> None: + def __init__(self, *items: t.Union[DirectiveT, t.Dict[str, t.Any]]) -> None: """Initialize base class and validate objects.""" super().__init__(*items, model=Directive, accessor="id") diff --git a/hyperglass/models/main.py b/hyperglass/models/main.py index c71d458..d5915d4 100644 --- a/hyperglass/models/main.py +++ b/hyperglass/models/main.py @@ -128,13 +128,15 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]): extra = "forbid" validate_assignment = True - def __init__(self, *items: t.Dict[str, t.Any], model: MultiModelT, accessor: str) -> None: + def __init__( + self, *items: t.Union[MultiModelT, t.Dict[str, t.Any]], model: MultiModelT, accessor: str + ) -> None: """Validate items.""" - items = self._valid_items(*items, model=model, accessor=accessor) - super().__init__(__root__=items) - self._count = len(items) self._accessor = accessor self._model = model + valid = self._valid_items(*items) + super().__init__(__root__=valid) + self._count = len(self.__root__) def __iter__(self) -> t.Iterator[MultiModelT]: """Iterate items.""" @@ -179,31 +181,27 @@ class HyperglassMultiModel(GenericModel, t.Generic[MultiModelT]): """Access item count.""" return self._count - @staticmethod def _valid_items( - *to_validate: t.List[t.Union[MultiModelT, t.Dict[str, t.Any]]], - model: MultiModelT, - accessor: str + self, *to_validate: t.List[t.Union[MultiModelT, t.Dict[str, t.Any]]] ) -> t.List[MultiModelT]: items = [ item for item in to_validate if any( ( - (isinstance(item, dict) and accessor in item), - (isinstance(item, model) and hasattr(item, accessor)), + (isinstance(item, self.model) and hasattr(item, self.accessor)), + (isinstance(item, t.Dict) and self.accessor in item), ), ) ] - for index, item in enumerate(items): - if isinstance(item, dict): - items[index] = model(**item) + if isinstance(item, t.Dict): + items[index] = self.model(**item) return items def add(self, *items, unique_by: t.Optional[str] = None) -> None: """Add an item to the model.""" - to_add = self._valid_items(*items, model=self.model, accessor=self.accessor) + to_add = self._valid_items(*items) if unique_by is not None: unique_by_values = { getattr(obj, unique_by) for obj in (*self, *to_add) if hasattr(obj, unique_by)