Add EnumAction

This commit is contained in:
Maks Snegov 2022-02-09 09:41:31 +03:00
parent 0f009edd8e
commit 2802de909d
6 changed files with 142 additions and 93 deletions

1
requirements-dev.txt Normal file
View File

@ -0,0 +1 @@
requests

View File

@ -1,6 +1,4 @@
from .helpers import (
nested_dataclass,
ExtraDataclass,
singleton,
ConstEnum
)
from .dataclasses import ExtraDataclass, nested_dataclass
from .enum import ConstEnum, EnumAction
from helpers import singleton
from httplog import logRoundtrip, HttpFormatter

View File

@ -0,0 +1,46 @@
from typing import Sequence
from dataclasses import dataclass, fields, asdict, is_dataclass
@dataclass
class ExtraDataclass:
@classmethod
def from_dict(cls, data: dict):
""" Create dataclass from dict ignoring extra arguments. """
cl_fields = set([f.name for f in fields(cls)])
filtered_data = dict()
for k, v in data.items():
if k in cl_fields:
filtered_data[k] = v
return cls(**filtered_data)
@classmethod
def from_list(cls, ent_list: Sequence[dict]) -> list:
""" Create list of dataclass instances from list of dicts. """
return [cls.from_dict(ent) for ent in ent_list]
def to_dict(self) -> dict:
""" Returns dict with not None values """
return asdict(
self, dict_factory=lambda d: {k: v for k, v in d if v is not None}
)
def nested_dataclass(*args, **kwargs):
def wrapper(cls):
cls = dataclass(cls, **kwargs)
original_init = cls.__init__
def __init__(self, *args, **kwargs):
for name, value in kwargs.items():
field_type = cls.__annotations__.get(name, None)
if is_dataclass(field_type) and isinstance(value, dict):
new_obj = field_type(**value)
kwargs[name] = new_obj
original_init(self, *args, **kwargs)
cls.__init__ = __init__
return cls
return wrapper(args[0]) if args else wrapper

75
spqr/kieran/enum.py Normal file
View File

@ -0,0 +1,75 @@
import argparse
import enum
class ConstEnum(str, enum.Enum):
"""
Class is used for gathering string constants in one place.
Attribute values are generated from attribute names.
All attributes values will be considered as strings.
Class name is omitted from attribute value.
>>> from spqr.kieran.enum import ConstEnum
>>> Booze = ConstEnum("Booze", "Whiskey Beer Vodka")
>>> Booze.Beer
Beer
>>> print(Booze.Beer)
Beer
>>> type(Booze.Vodka)
<enum 'Booze'>
>>> repr(Booze.Whiskey)
'Whiskey'
Also could be used with usual class creation:
>>> from enum import auto
>>> class NotBooze(ConstEnum):
>>> Juice = auto()
>>> Tea = auto()
>>> Coffee = auto()
>>> repr(NotBooze.Tea)
'Tea'
"""
def _generate_next_value_(name, start, count, last_values):
return name
def __repr__(self):
return self.name
def __str__(self):
return self.name
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
Usage:
>>> import enum
>>> class Do(enum.Enum):
>>> Foo = "foo"
>>> Bar = "bar"
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument('do', type=Do, action=EnumAction)
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
kwargs.setdefault("choices", tuple(e.value for e in enum_type))
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)

View File

@ -1,52 +1,18 @@
from dataclasses import dataclass, is_dataclass, fields, asdict
from enum import Enum
from typing import Sequence
def nested_dataclass(*args, **kwargs):
def wrapper(cls):
cls = dataclass(cls, **kwargs)
original_init = cls.__init__
def __init__(self, *args, **kwargs):
for name, value in kwargs.items():
field_type = cls.__annotations__.get(name, None)
if is_dataclass(field_type) and isinstance(value, dict):
new_obj = field_type(**value)
kwargs[name] = new_obj
original_init(self, *args, **kwargs)
cls.__init__ = __init__
return cls
return wrapper(args[0]) if args else wrapper
@dataclass
class ExtraDataclass:
@classmethod
def from_dict(cls, data: dict):
""" Create dataclass from dict ignoring extra arguments. """
cl_fields = set([f.name for f in fields(cls)])
filtered_data = dict()
for k, v in data.items():
if k in cl_fields:
filtered_data[k] = v
return cls(**filtered_data)
@classmethod
def from_list(cls, ent_list: Sequence[dict]) -> list:
""" Create list of dataclass instances from list of dicts. """
return [cls.from_dict(ent) for ent in ent_list]
def to_dict(self) -> dict:
""" Returns dict with not None values """
return asdict(
self, dict_factory=lambda d: {k: v for k, v in d if v is not None}
)
def singleton(class_):
"""
Singleton wrapper.
Usage:
>>> from spqr.kieran.helpers import singleton
>>> @singleton
>>> class DBConnection:
>>> def __init__(self, db_uri):
>>> self.uri = db_uri
>>> db1 = DBConnection("example.com")
>>> db2 = DBConnection("example.net")
>>> assert db1.uri == db2.uri
"""
instances = {}
def get_instance(*args, **kwargs):
@ -55,41 +21,3 @@ def singleton(class_):
return instances[class_]
return get_instance
class ConstEnum(str, Enum):
"""
Class is used for gathering string constants in one place.
Attribute values are generated from attribute names.
All attributes values will be considered as strings.
Class name is omitted from attribute value.
>>> Booze = ConstEnum("Booze", "Whiskey Beer Vodka")
>>> Booze.Beer
Beer
>>> print(Booze.Beer)
Beer
>>> type(Booze.Vodka)
<enum 'Booze'>
>>> repr(Booze.Whiskey)
'Whiskey'
Also could be used with usual class creation:
>>> from enum import auto
>>> class NotBooze(ConstEnum):
>>> Juice = auto()
>>> Tea = auto()
>>> Coffee = auto()
>>> repr(NotBooze.Tea)
'Tea'
"""
def _generate_next_value_(name, start, count, last_values):
return name
def __repr__(self):
return self.name
def __str__(self):
return self.name

View File

@ -4,12 +4,13 @@ Enable HTTP requests/responses logging.
Usage:
>>> import logging
>>> import requests
>>> from spqr.kieran.httplog import logRoundtrip, HttpFormatter
>>> formatter = HttpFormatter("{asctime}|{levelname}|{threadName}|{message}", style="{")
>>> handler = logging.StreamHandler()
>>> handler.setFormatter(formatter)
>>> logging.basicConfig(level=loglevel, handlers=[handler])
>>> logging.basicConfig(level="INFO", handlers=[handler])
>>> session = requests.Session()
>>> session.hooks['response'].append(logRoundtrip)