Skip to content

Commit

Permalink
Support env variables in config
Browse files Browse the repository at this point in the history
  • Loading branch information
4Kaylum committed Jan 14, 2024
1 parent f8f9b9c commit f6fa03b
Showing 1 changed file with 70 additions and 5 deletions.
75 changes: 70 additions & 5 deletions novus/ext/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import glob
import logging
from typing import TYPE_CHECKING, Any, NoReturn, TypeAlias
import os
from typing import TYPE_CHECKING, Any, NoReturn, TypeAlias, Union, overload

import dotenv
from typing_extensions import Self

import novus
Expand All @@ -31,12 +33,16 @@

Extended: TypeAlias = dict[str, dict[str, Any]]

DictValue: TypeAlias = Union[str, int, bool]
NestedDict: TypeAlias = dict[str, Union[DictValue, "NestedDict"]]

__all__ = (
'Config',
)


log = logging.getLogger("novus.ext.client.config")
dotenv.load_dotenv(dotenv.find_dotenv(usecwd=True))


class Config:
Expand Down Expand Up @@ -65,14 +71,18 @@ def __init__(
self.shard_count: int = shard_count
self.intents: novus.Intents = intents or novus.Intents()
self.plugins: list[str] = plugins or []
self.extended = {}
for k, v in kwargs.items():
setattr(self, k.replace("-", "_"), v)
if self.plugins:
from .client import Client
client_logger = logging.getLogger("novus.ext.client")
current_level = client_logger.level
client_logger.setLevel(logging.ERROR)
bot = Client(self)
for p in bot.plugins:
self.extended[p.__name__] = p.CONFIG.copy()
for k, v in kwargs.items():
setattr(self, k.replace("-", "_"), v)
self.extended = {}
client_logger.setLevel(current_level)

if TYPE_CHECKING:

Expand Down Expand Up @@ -198,7 +208,62 @@ def from_file(cls, filename: str | None) -> Self:
def from_dict(cls, data: dict[str, Any]) -> Self:
intent_dict = data.pop("intents")
intents = novus.Intents(**intent_dict)
return cls(**data, intents=intents)
new_data = cls.exchange_dict_for_env(data)
return cls(**new_data, intents=intents) # type: ignore

@classmethod
def exchange_str_for_env(cls, value: str) -> str:
"""
Get an environment variable if a string starts with "$". Otherwise,
return the given value.
... -> ...
$... -> ENV[...]
Any leading backslashes will be exchanged for the same number of
backslashes minus one.
"""

if value.startswith("$"):
name = value[1:]
return os.getenv(name, "")
elif value.startswith("\\"):
backslash_count = len(value) - len(value.lstrip("\\"))
return ("\\" * (backslash_count - 1)) + value.lstrip("\\")
else:
return value

@overload
@classmethod
def exchange_dict_for_env(cls, value: DictValue | list[DictValue]) -> DictValue | list[DictValue]:
...

@overload
@classmethod
def exchange_dict_for_env(cls, value: NestedDict) -> NestedDict:
...

@classmethod
def exchange_dict_for_env(
cls,
value: DictValue | NestedDict | list[DictValue]) -> DictValue | NestedDict | list[DictValue]:
"""
Deep iterate through a given dictionary and work out the environment
vars from it.
"""

if isinstance(value, str):
return cls.exchange_str_for_env(value)
elif isinstance(value, list):
return [
cls.exchange_str_for_env(i) if isinstance(i, str) else i
for i in value
]
elif isinstance(value, (bool, int)) or value is None:
return value
new = {}
for k, v in value.items():
new[k] = cls.exchange_dict_for_env(v)
return new # type: ignore

def to_dict(self) -> dict[str, Any]:
v = {
Expand Down

0 comments on commit f6fa03b

Please sign in to comment.