from firexapp.submit.console import setup_console_logging
from firexapp.submit.uid import Uid
from logging import INFO, DEBUG, WARNING
import os
import re
import subprocess
import psutil
from firexapp.broker_manager.broker_factory import BrokerFactory
from socket import gethostname
from firexapp.common import poll_until_file_not_empty, poll_until_dir_empty, find_procs
from firexapp.plugins import PLUGINS_ENV_NAME, cdl2list
from firexapp.fileregistry import FileRegistry
from collections.abc import Iterable
from firexapp.common import qualify_firex_bin
logger = setup_console_logging(__name__)
CELERY_LOGS_REGISTRY_KEY = 'celery_logs'
FileRegistry().register_file(CELERY_LOGS_REGISTRY_KEY, os.path.join(Uid.debug_dirname, 'celery'))
CELERY_PIDS_REGISTRY_KEY = 'celery_pids'
FileRegistry().register_file(CELERY_PIDS_REGISTRY_KEY,
os.path.join(FileRegistry().get_relative_path(CELERY_LOGS_REGISTRY_KEY), 'pids'))
MICROSERVICE_LOGS_REGISTRY_KEY = 'microservice_logs'
FileRegistry().register_file(MICROSERVICE_LOGS_REGISTRY_KEY, 'microservice_logs')
[docs]
class CeleryWorkerStartFailed(Exception):
pass
[docs]
class CeleryManager(object):
celery_bin_name = 'celery'
def __init__(self, plugins=None, logs_dir=None, worker_log_level='debug', cap_concurrency=None,
app='firexapp.engine', env=None, broker=None):
if not broker:
self.broker = BrokerFactory.get_broker_url(assert_if_not_set=True)
self.hostname = gethostname()
self.plugins = plugins
self.logs_dir = logs_dir
self.worker_log_level = worker_log_level
self.cap_concurrency = cap_concurrency
self.app = app
self.env = os.environ.copy()
self.env['CELERY_RDBSIG'] = '1'
self.update_env(self.get_plugins_env(plugins))
if env:
self.update_env(env)
self.pid_files = dict()
self._celery_logs_dir = None
self._celery_pids_dir = None
self._workers_logs_dir = None
@property
def celery_bin(self):
return qualify_firex_bin(self.celery_bin_name)
[docs]
@classmethod
def log(cls, msg, header=None, level=DEBUG):
if header is None:
header = cls.__name__
if header:
msg = '[%s] %s' % (header, msg)
logger.log(level, msg)
[docs]
def update_env(self, env):
assert isinstance(env, dict), 'env needs to be a dictionary'
self.env.update({k: str(v) for k, v in env.items()})
[docs]
@staticmethod
def get_plugins_env(plugins):
plugin_files = cdl2list(plugins)
return {PLUGINS_ENV_NAME: ",".join(plugin_files)}
[docs]
@staticmethod
def get_celery_logs_dir(logs_dir):
return FileRegistry().get_file(CELERY_LOGS_REGISTRY_KEY, logs_dir)
[docs]
@staticmethod
def get_celery_pids_dir(logs_dir):
return FileRegistry().get_file(CELERY_PIDS_REGISTRY_KEY, logs_dir)
[docs]
@staticmethod
def get_worker_logs_dir(logs_dir):
return FileRegistry().get_file(MICROSERVICE_LOGS_REGISTRY_KEY, logs_dir)
@property
def celery_logs_dir(self):
if not self._celery_logs_dir:
_celery_logs_dir = self.get_celery_logs_dir(self.logs_dir)
os.makedirs(_celery_logs_dir, exist_ok=True)
self._celery_logs_dir = _celery_logs_dir
return self._celery_logs_dir
@property
def celery_pids_dir(self):
if not self._celery_pids_dir:
_celery_pids_dir = self.get_celery_pids_dir(self.logs_dir)
os.makedirs(_celery_pids_dir, exist_ok=True)
self._celery_pids_dir = _celery_pids_dir
return self._celery_pids_dir
@property
def workers_logs_dir(self):
if not self._workers_logs_dir:
_workers_logs_dir = self.get_worker_logs_dir(self.logs_dir)
os.makedirs(_workers_logs_dir, exist_ok=True)
self._workers_logs_dir = _workers_logs_dir
return self._workers_logs_dir
[docs]
@classmethod
def get_worker_log_file(cls, logs_dir, worker_and_host):
return cls.__get_worker_log_file(cls.get_worker_logs_dir(logs_dir), worker_and_host)
def _get_worker_log_file(self, workername):
queue_and_worker = self.get_worker_and_host(workername, self.hostname)
return self.__get_worker_log_file(self.workers_logs_dir, queue_and_worker)
@staticmethod
def __get_worker_log_file(worker_logs_dir, worker_and_host):
return os.path.join(worker_logs_dir, '%s.html' % worker_and_host)
[docs]
@classmethod
def get_pid_file(cls, logs_dir, workername, hostname=gethostname()):
worker_and_host = cls.get_worker_and_host(workername, hostname)
return cls.__get_pid_file(cls.get_celery_pids_dir(logs_dir), worker_and_host)
def _get_pid_file(self, workername):
worker_and_host = self.get_worker_and_host(workername, self.hostname)
return self.__get_pid_file(self.celery_pids_dir, worker_and_host)
@staticmethod
def __get_pid_file(pids_logs_dir, worker_and_host):
return os.path.join(pids_logs_dir, '%s.pid' % worker_and_host)
def _get_stdout_file(self, workername):
return os.path.join(self.celery_logs_dir, '%s@%s.stdout.txt' % (workername, self.hostname))
[docs]
@staticmethod
def get_worker_and_host(workername, hostname):
return '%s@%s' % (workername, hostname)
[docs]
@classmethod
def get_pid_from_file(cls, pid_file):
try:
with open(pid_file) as f:
pid = f.read().strip()
except FileNotFoundError:
cls.log('No pid file found in %s' % pid_file, level=WARNING)
raise
else:
if pid:
return int(pid)
else:
raise AssertionError('no pid')
[docs]
@classmethod
def get_pid(cls, logs_dir, workername, hostname=gethostname()):
pid_file = cls.get_pid_file(logs_dir, workername, hostname)
return cls.get_pid_from_file(pid_file)
[docs]
@classmethod
def get_worker_pids(cls, logs_dir, hostname, workernames):
hostname = gethostname() if hostname == 'localhost' else hostname
pids = []
for workername in workernames:
try:
pid = cls.get_pid(logs_dir, workername, hostname)
except Exception as e:
cls.log(e)
else:
pids.append(pid)
return pids
[docs]
@staticmethod
def cap_cpu_count(count, cap_concurrency):
return min(count, cap_concurrency) if cap_concurrency else count
[docs]
def wait_until_active(self, pid_file, stdout_file, workername, timeout=15*60):
extra_err_info = ''
try:
poll_until_file_not_empty(pid_file, timeout=timeout)
except AssertionError:
err_list = self.extract_errors_from_celery_logs(stdout_file)
if err_list:
extra_err_info += '\nFound the following errors:\n' + '\n'.join(err_list)
extra_err_info += '\nAttempting to delete the invocation pids'
deleted_pids = subprocess.run(['/bin/pkill', '-e', '-f', pid_file],
capture_output=True,
text=True)
if deleted_pids.stdout:
extra_err_info += f'\nstdout: {deleted_pids.stdout}'
if deleted_pids.stderr:
extra_err_info += f'\nstderr: {deleted_pids.stderr}'
raise CeleryWorkerStartFailed(f'The worker{workername}@{self.hostname} did not come up after'
f' {timeout} seconds.\n'
f'Please look into {stdout_file!r} for details.'
f'{extra_err_info}')
pid = self.get_pid_from_file(pid_file)
self.log('pid %d became active' % pid)
[docs]
def start(self, workername, queues=None, wait=True, timeout=15*60, concurrency=None, worker_log_level=None,
app=None, cap_concurrency=None, cwd=None, soft_time_limit=None, autoscale: tuple = None,
detach: bool = True):
# Override defaults if applicable
worker_log_level = worker_log_level if worker_log_level else self.worker_log_level
app = app if app else self.app
cap_concurrency = cap_concurrency if cap_concurrency else self.cap_concurrency
stdout_file = self._get_stdout_file(workername)
log_file = self._get_worker_log_file(workername)
pid_file = self._get_pid_file(workername)
self.pid_files[workername] = pid_file
cmd = f'{self.celery_bin} ' \
f'--app={app} worker ' \
f'--hostname={workername}@%h ' \
f'--loglevel={worker_log_level} ' \
f'--logfile={log_file} ' \
f'--pidfile={pid_file} ' \
f'--events ' \
f'--without-gossip ' \
f'--without-heartbeat ' \
f'--without-mingle ' \
f'-Ofair'
if queues:
cmd += ' --queues=%s' % queues
if concurrency and autoscale:
raise AssertionError('You can either provide a value of concurrency or autoscale, but not both')
if concurrency:
cmd += ' --concurrency=%d' % self.cap_cpu_count(concurrency, cap_concurrency)
elif autoscale:
assert isinstance(autoscale, Iterable), 'autoscale should be a tuple of (min, max)'
assert len(autoscale) == 2, 'autoscale should be a tuple of two elements (min, max)'
autoscale_v1, autoscale_v2 = autoscale
autoscale_min = min(autoscale_v1, autoscale_v2)
autoscale_max = max(autoscale_v1, autoscale_v2)
autoscale_min = self.cap_cpu_count(autoscale_min, cap_concurrency)
autoscale_max = self.cap_cpu_count(autoscale_max, cap_concurrency)
cmd += f' --autoscale={autoscale_max},{autoscale_min}'
if soft_time_limit:
cmd += f' --soft-time-limit={soft_time_limit}'
# piping to ts is helpful for debugging if available
try:
subprocess.check_call(["which", "ts"], stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
except subprocess.CalledProcessError:
pass
else:
cmd += " | ts '[%Y-%m-%d %H:%M:%S]'"
if detach:
cmd += ' &'
self.log('Starting %s on %s...' % (workername, self.hostname))
self.log(cmd)
if cwd:
self.log('cwd=%s' % cwd)
with open(stdout_file, 'ab') as fp:
subprocess.check_call(cmd, shell=True, stdout=fp, stderr=subprocess.STDOUT, env=self.env,
cwd=cwd)
if detach and wait:
self.wait_until_active(pid_file=pid_file, timeout=timeout, stdout_file=stdout_file, workername=workername)
[docs]
@staticmethod
def find_procs(pid_file):
return find_procs('celery', cmdline_contains='--pidfile=%s' % pid_file)
[docs]
def find_all_procs(self):
procs = []
for pid_file in os.listdir(self.celery_pids_dir):
procs += self.find_procs(os.path.join(self.celery_pids_dir, pid_file))
return procs
[docs]
def kill_all_forked(self, pid_file):
for proc in self.find_procs(pid_file):
self.log('Killing pid %d' % proc.pid, level=INFO)
try:
proc.kill()
except Exception:
self.log('Failed to kill pid %d' % proc.pid, level=WARNING)
[docs]
@classmethod
def terminate(cls, pid, timeout=60):
cls.log('Terminating pid %d' % pid, level=INFO)
p = psutil.Process(pid)
p.terminate()
p.wait(timeout=timeout)
[docs]
def shutdown(self, timeout=60):
if self.pid_files:
name_to_pid_file = self.pid_files
else:
# self.pid_files is only populated when starting celery, so if this manager didn't start the celery
# instance being operated on, fallback to the pid directory.
name_to_pid_file = {pf: os.path.join(self.celery_pids_dir, pf) for pf in os.listdir(self.celery_pids_dir)}
for name, pid_file in name_to_pid_file.items():
self.log('Attempting shutdown of %s' % name)
try:
pid = self.get_pid_from_file(pid_file)
except Exception as e:
self.log(e)
else:
try:
self.terminate(pid, timeout=timeout)
except (psutil.TimeoutExpired, psutil.NoSuchProcess):
self.kill_all_forked(pid_file)
except Exception as e:
self.log(e)
[docs]
def wait_for_shutdown(self, timeout=15):
return poll_until_dir_empty(self.celery_pids_dir, timeout=timeout)