import os
import sys
import logging
from collections import OrderedDict, namedtuple
from typing import Dict, List, Tuple
from entrypoints import EntryPoint
logger = logging.getLogger(__name__)

# Name of the environment variable that may point at an extra directory of
# task modules (consulted by find_firex_task_bundles).
TASKS_DIRECTORY = "firex_tasks_directory"

# Per-path cache populated by _load_firex_entry_points:
# str(path) -> {entry-point group name: {EntryPoint: loaded object}}.
_loaded_firex_bundles = dict()
[docs]
class PkgVersionInfo(namedtuple('PkgVersionInfo', 'pkg version commit', defaults=(None,) * 3)):
    """Immutable (pkg, version, commit) record; every field defaults to None."""

    def __str__(self):
        # Show the release version when it is truthy, otherwise fall back to the commit id.
        shown = self.version if self.version else self.commit
        return f'{self.pkg}: {shown}'
def _get_paths_without_cwd() -> List[str]:
    """Return a copy of ``sys.path`` without any entry equal to the current working directory.

    Celery temporarily adds the cwd to ``sys.path`` via a context manager, and our
    discovery takes place inside that context. Having the cwd on the search path can
    slow discovery down significantly without any benefit, so it is filtered out here.
    ``sys.path`` itself is left unmodified.
    """
    try:
        cwd = os.getcwd()
    except FileNotFoundError:
        # The working directory no longer exists, so there is nothing to compare
        # against; hand back sys.path unchanged instead of crashing discovery.
        return list(sys.path)
    # Drop *every* occurrence; list.remove() would only drop the first one.
    return [p for p in sys.path if p != cwd]
#
# When duplicate entry points are found, keep only one for each
# (name, module_name, object_name) tuple, preferring an entry point that
# has a distro. This prevents duplicate argument-registration failures
# when sys.path causes the same service to be found twice.
#
[docs]
def prune_duplicate_module_entry_points(entry_points) -> List[EntryPoint]:
    """Collapse entry points that share the same (name, module_name, object_name).

    The first entry point seen for each key is kept, except that a stored entry
    point whose ``distro`` is None is replaced by a later duplicate that has a
    distro. Output order follows the first appearance of each key.

    :param entry_points: iterable of entry points (objects exposing ``name``,
        ``module_name``, ``object_name`` and ``distro``).
    :return: list of de-duplicated entry points.
    """
    id_to_entry_points = OrderedDict()
    for e in entry_points:
        key = (e.name, e.module_name, e.object_name)
        if key not in id_to_entry_points:
            id_to_entry_points[key] = e
        # Replace the currently stored entry point for this key if its distro is None.
        elif id_to_entry_points[key].distro is None and e.distro is not None:
            id_to_entry_points[key] = e
    return list(id_to_entry_points.values())
def _get_entrypoints(name, prune_duplicates=True, path=None) -> [EntryPoint]:
    """Return the entry points registered under the group ``name``.

    :param name: entry-point group to look up.
    :param prune_duplicates: when truthy, the result is passed through
        prune_duplicate_module_entry_points.
    :param path: optional search location(s) forwarded to
        ``entrypoints.get_group_all``; a non-list value is wrapped in a list.
    """
    import entrypoints

    search_path = path
    if search_path is not None and not isinstance(search_path, list):
        search_path = [search_path]
    found = list(entrypoints.get_group_all(name, path=search_path))
    return prune_duplicate_module_entry_points(found) if prune_duplicates else found
[docs]
def loaded_firex_core_entry_points(path=None) -> Dict[EntryPoint, object]:
    """Return the loaded ``firex.core`` entry points, via the per-path cache."""
    group = 'firex.core'
    return _load_firex_entry_points(group, path=path)
[docs]
def loaded_firex_bundles_entry_points(path=None) -> Dict[EntryPoint, object]:
    """Return the loaded ``firex.bundles`` entry points, via the per-path cache."""
    group = 'firex.bundles'
    return _load_firex_entry_points(group, path=path)
[docs]
def loaded_firex_entry_points(path=None):
    """Return the loaded ``firex.core`` and ``firex.bundles`` entry points as one new dict.

    Core entry points are loaded first; if a key appears in both, the bundle value wins.
    """
    merged = dict(loaded_firex_core_entry_points(path=path))
    merged.update(loaded_firex_bundles_entry_points(path=path))
    return merged
def _load_firex_entry_points(entrypoint_name, path=None) -> Dict[EntryPoint, object]:
    """Load and return every entry point in the group ``entrypoint_name``.

    Successfully loaded results are memoised in ``_loaded_firex_bundles``, keyed by
    ``str(path)`` and then by the group name, so repeated calls with the same
    arguments return the cached mapping instead of loading again.

    :param entrypoint_name: entry-point group to load.
    :param path: optional search location(s), forwarded to _get_entrypoints.
    :return: mapping of each EntryPoint to the object returned by its ``load()``.
    """
    # setdefault creates the per-path dict on first use. It is keyed by the *value*
    # of entrypoint_name; dict(entrypoint_name=...) would use the literal string.
    per_path_cache = _loaded_firex_bundles.setdefault(str(path), {})
    if entrypoint_name in per_path_cache:
        return per_path_cache[entrypoint_name]
    eps = _get_entrypoints(entrypoint_name, path=path)
    loaded_eps = {ep: ep.load() for ep in eps}
    per_path_cache[entrypoint_name] = loaded_eps
    return loaded_eps
[docs]
def get_firex_tracking_services_entry_points() -> [EntryPoint]:
    """Return the duplicate-pruned entry points in the ``firex_tracking_service`` group."""
    group = 'firex_tracking_service'
    return _get_entrypoints(group)
[docs]
def get_firex_dependant_package_versions() -> List[PkgVersionInfo]:
    """Describe the version of every loaded firex core/bundle entry point.

    For each loaded object, the version is read from ``__version__`` and the commit
    from ``_version.get_versions()['full-revisionid']``. Either is None when it is
    unavailable, so one odd package cannot break version reporting.

    :return: one PkgVersionInfo per loaded entry point, named after the entry point.
    """
    versions = []
    for ep, loaded_pkg in loaded_firex_entry_points().items():
        version = getattr(loaded_pkg, '__version__', None)
        try:
            commit = loaded_pkg._version.get_versions()['full-revisionid']
        except (AttributeError, KeyError, TypeError):
            # No _version.get_versions(), or its result has no 'full-revisionid'.
            commit = None
        versions.append(PkgVersionInfo(pkg=ep.name, version=version, commit=commit))
    return versions
[docs]
def get_all_pkg_versions() -> [PkgVersionInfo]:
    """Return tracking-service versions followed by firex core/bundle package versions."""
    from firexapp.submit.tracking_service import get_tracking_services_versions

    tracking = get_tracking_services_versions()
    dependants = get_firex_dependant_package_versions()
    return tracking + dependants
[docs]
def get_all_pkg_versions_as_dict() -> dict:
    """Map each package name to its PkgVersionInfo; a later duplicate name overwrites an earlier one."""
    by_name = {}
    for info in get_all_pkg_versions():
        by_name[info.pkg] = info
    return by_name
[docs]
def get_all_pkg_versions_str() -> str:
    """Render every package version as a tab-indented bullet list under a header line."""
    entries = '\n'.join(f'\t - {info}' for info in get_all_pkg_versions())
    return ''.join(('FireX Package Versions:\n', entries, '\n'))
def _find_bundle_pkg_root(path, namespace):
    """Return the parent directory of the ``namespace`` component of ``path``.

    Walks up from ``path`` until the last component equals ``namespace`` and returns
    what precedes it, e.g. ``/site/ns/sub`` with namespace ``ns`` gives ``/site``.

    :raises ValueError: if no component of ``path`` equals ``namespace``.
    """
    current = path
    while True:
        head, tail = os.path.split(current)
        if tail == namespace:
            return head
        parent = os.path.dirname(current)
        if parent == current:
            # dirname() no longer changes the path ('/' or '' reached): the namespace
            # is not in the path, so stop instead of looping forever.
            raise ValueError(f'{namespace!r} is not a component of path {path!r}')
        current = parent
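# Return a list of two-element tuples, where the first element is a
# directory the bundle package was loaded from (an entry of its __path__)
# and the second element is that package's root (the directory containing
# its top-level namespace).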
def _get_firex_bundle_package_locations(path=None) -> List[Tuple[str, str]]:
    """Return ``(package_path, package_root)`` pairs for every loaded firex bundle.

    ``package_path`` ranges over each bundle package's ``__path__``, and
    ``package_root`` is the directory that holds the package's top-level namespace,
    as computed by _find_bundle_pkg_root.
    """
    locations = []
    for bundle in loaded_firex_bundles_entry_points(path=path).values():
        top_level = bundle.__package__.partition('.')[0]
        locations.extend(
            (pkg_dir, _find_bundle_pkg_root(pkg_dir, top_level))
            for pkg_dir in bundle.__path__
        )
    return locations
[docs]
def discover_package_modules(current_path, root_path=None) -> List[str]:
    """Recursively list dotted module names found under ``current_path``.

    A regular file produces one name when its extension is ``.py`` (any case) and its
    file name does not start with ``_``. The name is the path without extension,
    with a leading ``root_path`` removed, separators turned into dots and
    leading/trailing dots stripped. Directories are walked recursively, except
    hidden (dot-prefixed) ones and those whose name contains ``__pycache__``.

    :param current_path: file or directory to inspect.
    :param root_path: prefix that module names are made relative to; defaults to
        the parent directory of ``current_path``.
    :return: list of module names, in the order os.listdir yields entries.
    """
    if root_path is None:
        root_path = os.path.dirname(current_path)

    if os.path.isfile(current_path):
        basename, ext = os.path.splitext(current_path)
        if ext.lower() != ".py" or os.path.basename(current_path).startswith('_'):
            return []
        # Strip root_path only as a leading prefix. str.replace() would also delete
        # any later occurrence of the same text inside the path.
        if basename.startswith(root_path):
            basename = basename[len(root_path):]
        return [basename.replace(os.path.sep, ".").strip(".")]

    if os.path.isdir(current_path):
        base = os.path.basename(current_path)
        if "__pycache__" in base or base.startswith("."):
            return []
        services = []
        for child_name in os.listdir(current_path):
            services += discover_package_modules(os.path.join(current_path, child_name), root_path)
        return services

    # isfile()/isdir() follow symlinks, so this is a broken symlink, a path that
    # does not exist, or a special file (e.g. socket or FIFO).
    return []
[docs]
def find_firex_task_bundles() -> List[str]:
    """Return the module names of all discoverable firex task bundles.

    Modules are collected from every loaded firex bundle package. They are also
    collected from the directory named by the TASKS_DIRECTORY environment variable
    when it is an existing directory; that directory is appended to ``sys.path``
    (if absent) so its modules can be imported. A value that is not a directory
    is ignored with a warning.
    """
    # Look for task modules in dependant packages.
    bundles = []
    for path, root_path in _get_firex_bundle_package_locations():
        bundles += discover_package_modules(path, root_path)

    # Look for task modules in the env-defined location.
    include_location = os.environ.get(TASKS_DIRECTORY)
    if include_location is not None:
        if os.path.isdir(include_location):
            if include_location not in sys.path:
                sys.path.append(include_location)
            bundles += discover_package_modules(include_location, root_path=include_location)
        else:
            # Surface the misconfiguration instead of silently skipping it.
            logger.warning('%s is set to %r, which is not a directory; ignoring it',
                           TASKS_DIRECTORY, include_location)
    return bundles