import os
import sys
import inspect
import traceback
from argparse import ArgumentParser, Action
from celery.signals import worker_init
from celery.utils.log import get_task_logger
from firexapp.common import delimit2list
from firexkit.task import REPLACEMENT_TASK_NAME_POSTFIX
import importlib.util
from celery import current_app
logger = get_task_logger(__name__)
PLUGINS_ENV_NAME = "firex_plugins"
[docs]
class PluginLoadError(Exception):
pass
[docs]
def get_short_name(long_name: str) -> str:
return long_name.split('.')[-1]
[docs]
def find_plugin_file(file_path):
# is it a full path?
if os.path.isabs(file_path):
plugin_file = file_path
else:
# Maybe it's relative?
plugin_file = os.path.abspath(file_path)
if os.path.isfile(plugin_file):
return plugin_file
raise FileNotFoundError(file_path)
[docs]
def cdl2list(plugin_files):
if not plugin_files:
return []
if not isinstance(plugin_files, list):
plugin_files = [file.strip() for file in plugin_files.split(",")]
plugin_files = [find_plugin_file(file) for file in plugin_files if file]
return plugin_files
[docs]
def get_plugin_module_name(plugin_file):
return os.path.splitext(os.path.basename(plugin_file))[0]
[docs]
def get_plugin_module_names(plugin_files):
files = cdl2list(plugin_files)
if not files:
return []
return [get_plugin_module_name(file) for file in files]
[docs]
def get_plugin_module_names_from_env():
plugin_files = get_active_plugins()
return get_plugin_module_names(plugin_files)
# noinspection PyUnusedLocal
@worker_init.connect()
def _worker_init_signal(*args, **kwargs):
try:
load_plugin_modules_from_env()
except PluginLoadError:
traceback.print_exc()
# We will just exit, the pid file won't be written,
# and the bringup of the worker will eventually timeout
exit(-2)
# there is no way of copying the signals without coupling with the internals of celery signals
# noinspection PyProtectedMember
def _get_signals_with_connections():
from celery.utils.dispatch.signal import Signal, NONE_ID
import celery.signals as sigs
# get all official signals
signals = [s for s in sigs.__dict__.values() if type(s) is Signal]
# only use the ones registered to specific microservices (as opposed to sender=None)
signals = [s for s in signals if len(s.receivers) > len(s._live_receivers(None))]
# now get the task specific registrations
def from_sender_only(sig):
return [k for k in sig.receivers if k[0][1] != NONE_ID]
signals = {s: from_sender_only(s) for s in signals}
signals = {s: k for s, k in signals.items() if k}
return signals
[docs]
def create_replacement_task(original, name_postfix, sigs):
new_name = original.name + name_postfix
bound = inspect.ismethod(original.undecorated)
func = original.run if not bound else original.run.__func__
options = {key: getattr(original, key) for key in ["acks_late",
"default_retry_delay",
"expires",
"ignore_result",
"max_retries",
"reject_on_worker_lost",
"resultrepr_maxsize",
"soft_time_limit",
"store_errors_even_if_ignored",
"soft_time_limit",
"time_limit",
"track_started",
"trail",
"typing",
"returns",
"flame",
"use_cache",
"pending_child_strategy",
"from_plugin"] if key in dir(original)}
new_task = current_app.task(name=new_name,
bind=bound,
base=inspect.getmro(original.__class__)[1],
check_name_for_override_posfix=False,
**options)(fun=func)
if hasattr(original, "orig"):
new_task.orig = original.orig
if hasattr(original, "report_meta"):
new_task.report_meta = original.report_meta
try:
# there is no way of copying the signals without coupling with the internals of celery signals
# noinspection PyProtectedMember
from celery.utils.dispatch.signal import _make_id
orig_task_id = _make_id(original)
for s, receivers in sigs.items():
for r in receivers:
# format is ((id(receiver), id(sender)), ref(receiver))
# locate any registered signal against the original microservice
if r[0][1] == orig_task_id:
# new entry only replaces the
entry = ((r[0][0], _make_id(new_task)), r[1])
s.receivers.append(entry)
except Exception as e:
logger.error("Unable to copy signals while overriding %s:\n%s" % (original.name, str(e)))
return new_task
def _unregister_duplicate_tasks():
sigs = _get_signals_with_connections()
becomes = identify_duplicate_tasks(current_app.tasks, get_plugin_module_names_from_env())
for substitutions in becomes:
prime_overrider = substitutions[-1] # the last item in the list is the last override
for index in range(0, len(substitutions)-1):
original_name = substitutions[index]
original = current_app.tasks[original_name]
current_app.tasks[original_name] = current_app.tasks[prime_overrider]
postfix = REPLACEMENT_TASK_NAME_POSTFIX * (len(substitutions) - index - 1)
new_task = create_replacement_task(original, postfix, sigs)
overrider = substitutions[index+1]
current_app.tasks[overrider].orig = new_task
[docs]
def identify_duplicate_tasks(all_tasks, priority_modules: list) -> [[]]:
"""
Returns a list of substitution. Each substitution is a list of microservices. The last will be the 'dominant' one.
It will be the one used.
"""
unique_names = set([get_short_name(long_name) for long_name in all_tasks])
unique_names = {name: list() for name in unique_names}
for long_name in all_tasks:
unique_names[get_short_name(long_name)].append(long_name)
def priority_index(micro_name):
try:
return priority_modules.index(os.path.splitext(micro_name)[0])
except ValueError:
return -1
overrides = [sorted(long_names, key=priority_index) for long_names in unique_names.values() if len(long_names) > 1]
return overrides
[docs]
def import_plugin_file(plugin_file):
def _import_plugin(module_name, plugin_file):
spec = importlib.util.spec_from_file_location(module_name, plugin_file)
module = importlib.util.module_from_spec(spec)
module_directory = os.path.dirname(os.path.realpath(plugin_file))
if module_directory not in sys.path:
sys.path.append(module_directory)
mod = sys.modules[module_name] = module
try:
spec.loader.exec_module(module)
except BaseException:
raise PluginLoadError(f'Fatal Error loading plugin {plugin_file!r}')
print(f"Plugins module {module_name!r} imported from {plugin_file!r}")
return mod
plugin_file = find_plugin_file(plugin_file)
module_name = get_plugin_module_name(plugin_file)
if module_name in sys.modules:
module_source = sys.modules[module_name].__file__
if module_source != plugin_file:
logger.error(f'Plugin module {module_name!r} was NOT imported from {plugin_file!r}. '
f'A module with the same name was already imported from {module_source!r}')
else:
logger.warning(f'Plugin module {module_name!r} was already imported from {module_source!r}.')
else:
return _import_plugin(module_name, plugin_file)
[docs]
def import_plugin_files(plugin_files) -> set[str]:
plugin_files = cdl2list(plugin_files)
if not plugin_files:
return set()
original_tasks = set(current_app.tasks)
for plugin_file in plugin_files:
import_plugin_file(plugin_file)
new_tasks = set(current_app.tasks) - original_tasks
new_tasks_modules = {t.rsplit('.', 1)[0] for t in new_tasks}
if new_tasks_modules:
print(f'{len(new_tasks)} new service{"s" if len(new_tasks)>1 else ""} imported '
f'from {len(new_tasks_modules)} plugin module{"s" if len(new_tasks_modules)>1 else ""} '
f'[{", ".join(new_tasks_modules)}]')
else:
print(f'No new tasks/services imported from {plugin_files}!')
return new_tasks
[docs]
def set_plugins_env(plugin_files):
plugin_files = cdl2list(plugin_files)
os.environ[PLUGINS_ENV_NAME] = ",".join(plugin_files)
[docs]
def get_active_plugins():
return os.environ.get(PLUGINS_ENV_NAME, "")
[docs]
def load_plugin_modules(plugin_files):
set_plugins_env(plugin_files)
new_tasks_imported = import_plugin_files(plugin_files)
# Mark the newly imported tasks with "from_plugin"
for t in new_tasks_imported:
current_app.tasks[t].from_plugin = True
if plugin_files:
_unregister_duplicate_tasks()
[docs]
def load_plugin_modules_from_env():
plugin_files = get_active_plugins()
if plugin_files:
load_plugin_modules(plugin_files)
[docs]
def merge_plugins(*plugin_lists) -> []:
"""Merge comma delimited lists of plugins into a single list. Right-handed most significant plugin"""
combined_list = []
for plugin_list in plugin_lists:
combined_list += delimit2list(plugin_list)
new_list = []
for next_idx, plugin in enumerate(combined_list, start=1):
if plugin not in combined_list[next_idx:]:
new_list.append(plugin)
return new_list
[docs]
class CommaDelimitedListAction(Action):
def __init__(self, option_strings, dest, nargs=None, **kwargs):
self.is_default = True
if nargs is not None:
raise ValueError("nargs not allowed")
super(CommaDelimitedListAction, self).__init__(option_strings, dest, **kwargs)
def __call__(self, parser, namespace, values, option_string=None):
old_value = getattr(namespace, self.dest) if hasattr(namespace, self.dest) and not self.is_default else ""
self.is_default = False
if old_value:
old_value += ","
new_value = ",".join(merge_plugins(old_value, values))
setattr(namespace, self.dest, new_value)
plugin_support_parser = ArgumentParser(add_help=False)
plugin_support_parser.add_argument("--external", "--plugins", '-external', '-plugins', "--plugin",
help="Comma delimited list of plugins files to load",
default="",
dest='plugins',
action=CommaDelimitedListAction)