# Copyright 2017 datawire. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import eventlet, functools, sys, os
from contextlib import contextmanager
from eventlet.corolocal import local
from eventlet.green import time
from .sentinel import Sentinel
logging = eventlet.import_patched('logging')
traceback = eventlet.import_patched('traceback')
datetime = eventlet.import_patched('datetime')
output = eventlet.import_patched('forge.output')
emod = eventlet.import_patched('forge.executor')
executor = emod.executor
Result = emod.Result
# XXX: need better default for logfile
def setup(logfile=None):
    """
    Set up the task system: configure logging and start the executor.

    :param logfile: path of the file to log to. When omitted (or
        otherwise falsy) a per-user, per-day file under ``/tmp`` is
        used.

    NOTE(review): the original documentation states that this also
    performs eventlet monkey patching; that presumably happens inside
    ``executor.setup()``, which is not visible from here -- confirm.
    """
    if not logfile:
        logfile = "/tmp/forge-{}-{}.log".format(os.environ.get("USER", os.getuid()),
                                                datetime.date.today().isoformat())
    logging.getLogger("tasks").addFilter(TaskFilter())
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format='%(levelname)s %(task_id)s: %(message)s')
    # The format above needs ``task_id`` on *every* record that reaches
    # the root handlers, but a logger-level filter only sees records
    # logged directly on "tasks". Attach the filter to the handlers as
    # well so records from other loggers format without a KeyError.
    for handler in logging.getLogger().handlers:
        handler.addFilter(TaskFilter())
    executor.setup()
    executor.resize(5)
class TaskError(Exception):
    """
    Used to signal that an anticipated error has occurred. A task error
    will be rendered without its stack trace, so it should include
    enough information in the error message to diagnose the issue.
    """

    # Class-level flag; set to False for this error type.
    report_traceback = False
# Aliases for names defined in the eventlet-patched forge.executor
# module (emod). ERROR is compared against a result's value in
# task.sync().
ChildError = emod.ChildError
PENDING = emod.PENDING
ERROR = emod.ERROR
def elapsed(delta):
    """
    Return a pretty representation of an elapsed time.

    :param delta: a duration in seconds (int or float).
    :returns: a string of the form ``H:MM:SS``; fractional seconds are
        truncated by the ``%d`` formatting.
    """
    minutes, seconds = divmod(delta, 60)
    hours, minutes = divmod(minutes, 60)
    return "%d:%02d:%02d" % (hours, minutes, seconds)
class TaskFilter(logging.Filter):
    """
    This logging filter augments log records with the id of the task
    execution that is current when the log statement is made, so that
    formatters can reference ``%(task_id)s``.
    """

    def filter(self, record):
        """
        Set ``record.task_id`` and let the record through.

        The id is taken from ``execution.current()``; when that returns
        a falsy value the placeholder ``"(none)"`` is used instead.

        :returns: always ``True`` -- this filter never drops a record.
        """
        exe = execution.current()
        record.task_id = exe.id if exe else "(none)"
        return True
class task(object):
    """A decorator used to mark a given function or method as a task.

    A task can really be any python code, however it is expected that
    tasks will perform scripting, coordination, integration, and
    general glue-like activities that are used to automate tasks on
    behalf of humans.

    This kind of code generally suffers from a number of problems:

     - There is rarely good user feedback for what is happening at any
       given moment.

     - When integration assumptions are violated (e.g. remote system
       barfs) the errors are often swallowed/opaque.

     - Because of the way it is incrementally built via growing
       convenience scripts it is often opaque and difficult to debug.

     - When parallel workflows are needed, they are difficult to code
       in a way that preserves clear user feedback on progress and
       errors.

    Using the task decorator provides a number of conveniences useful
    for this kind of code.

     - Task arguments/results are automatically captured for easy
       debugging.

     - Convenience APIs for executing tasks in parallel.

     - Convenience for safely executing shell and http requests with
       good error reporting and user feedback.

    Any python function can be marked as a task and invoked in the
    normal way you would invoke any function, e.g.::

        @task()
        def normpath(path):
            parts = [p for p in path.split("/") if p]
            normalized = "/".join(parts)
            if path.startswith("/"):
              return "/" + normalized
            else:
              return normalized

        normpath("/foo//bar/baz")  # -> "/foo/bar/baz"

    The decorator however provides several other convenient ways you
    can invoke a task::

        # using normpath.go, I can launch subtasks in parallel
        normalized = normpath.go("asdf"), normpath.go("fdsa"), normpath.go("bleh")
        # now I can fetch the result of an individual subtask:
        result = normalized[0].get()
        # or sync on any outstanding sub tasks:
        task.sync()

    You can also run a task. This will render progress indicators,
    status, and errors to the screen as the task and any subtasks
    proceed::

        normpath.run("/foo//bar/baz")
    """

    def __init__(self, name=None, context=None):
        """
        :param name: name of the task; when ``None`` it is filled in
            with the decorated function's ``__name__`` on application.
        :param context: optional ``str.format`` template, formatted
            with the call's arguments to produce the executor context.
        """
        self.name = name
        self.context_template = context
        self.logger = logging.getLogger("tasks")
        self.count = 0

    @staticmethod
    @contextmanager
    def verbose(value):
        """
        Context manager that sets the current executor's ``verbose``
        attribute to ``value`` for the duration of the ``with`` block
        and restores the previous setting afterwards.
        """
        exe = executor.current()
        saved = exe.verbose
        exe.verbose = value
        try:
            yield
        finally:
            # Restore the *previous* setting, even if the body raised.
            exe.verbose = saved

    @staticmethod
    @contextmanager
    def context(name):
        """
        Context manager that overrides the default context name on the
        current executor for the duration of the ``with`` block.
        """
        exe = executor.current()
        saved = getattr(exe, "_default_name", None)
        exe._default_name = name
        try:
            yield
        finally:
            # Put the earlier override (or None) back, even on error.
            exe._default_name = saved

    def _context(self, args, kwargs):
        """
        Compute the context for one invocation: an override installed
        via ``task.context`` on the current executor wins; otherwise
        the context template is formatted with the call's arguments,
        or ``None`` is returned when there is no template.
        """
        exe = executor.current()
        if exe and getattr(exe, "_default_name", None) is not None:
            return exe._default_name
        if self.context_template is None:
            return None
        return self.context_template.format(*args, **kwargs)

    def generate_id(self):
        """Increment this task's counter and return the new value."""
        self.count += 1
        return self.count

    def __call__(self, function):
        """
        Apply the decorator: remember ``function`` and return a
        ``decorator`` object carrying the function's metadata.
        """
        self.function = function
        if self.name is None:
            self.name = self.function.__name__
        result = decorator(self)
        functools.update_wrapper(result, function)
        return result

    @staticmethod
    def sync():
        """
        Wait until all child tasks have terminated.
        """
        r = executor.current_result()
        r.wait()
        if r.value is ERROR:
            # NOTE(review): get() is presumably called so that the
            # recorded error is raised to the caller -- confirm.
            r.get()

    @staticmethod
    def terminal():
        """Return the ``terminal`` attribute of ``executor.MUXER``."""
        return executor.MUXER.terminal

    @staticmethod
    def echo(*args, **kwargs):
        """Forward to ``echo`` on the current executor."""
        executor.current().echo(*args, **kwargs)

    @staticmethod
    def info(*args, **kwargs):
        """Forward to ``info`` on the current executor."""
        executor.current().info(*args, **kwargs)

    @staticmethod
    def warn(*args, **kwargs):
        """Forward to ``warn`` on the current executor."""
        executor.current().warn(*args, **kwargs)

    @staticmethod
    def error(*args, **kwargs):
        """Forward to ``error`` on the current executor."""
        executor.current().error(*args, **kwargs)
# Marker used as the default for decorator's bound object: it means
# "not bound to any instance", in which case no object is prepended
# to the call arguments.
_UNBOUND = Sentinel("_UNBOUND")
class decorator(object):
    """
    The callable object produced by applying a ``task`` to a function.

    It acts as a descriptor so that a task defined as a method is
    bound to its instance on attribute access, and it offers three ways
    to invoke the underlying function: a plain call, ``go`` and ``run``.
    """

    def __init__(self, task, object = _UNBOUND):
        """
        :param task: the ``task`` instance holding the wrapped function.
        :param object: the instance this decorator is bound to, or
            ``_UNBOUND`` when it is not bound.
        """
        self.task = task
        self.object = object
        self.__name__ = getattr(self.task.function, "__name__", "<unknown>")

    def __get__(self, object, clazz):
        """
        Descriptor hook: return a copy bound to ``object``.

        Access through the class passes ``object=None``; binding to
        ``None`` would prepend ``None`` to every call, so the unbound
        decorator itself is returned in that case.
        """
        if object is None:
            return self
        return decorator(self.task, object)

    def _munge(self, args):
        """Prepend the bound object (if any) to the positional args."""
        if self.object is _UNBOUND:
            return args
        else:
            return (self.object,) + args

    def __call__(self, *args, **kwargs):
        """
        Run the task function through an executor and return the value
        of ``result.get()``.
        """
        exe = executor(self.task._context(args, kwargs))
        result = exe.run(self.task.function, *self._munge(args), **kwargs)
        return result.get()

    def go(self, *args, **kwargs):
        """
        Run the task function through an executor created with
        ``async=True`` and return the result object without calling
        ``get()`` on it.
        """
        # ``async`` is a reserved word from Python 3.7 on, so the keyword
        # is passed via a dict to keep this module importable there.
        exe = executor(self.task._context(args, kwargs), **{"async": True})
        result = exe.run(self.task.function, *self._munge(args), **kwargs)
        return result

    def run(self, *args, **kwargs):
        """
        Launch the task via ``go``, wait for it, echo its report
        through the result's executor and return the result object.
        """
        result = self.go(*args, **kwargs)
        result.wait()
        result.executor.echo(result.report())
        return result
def elide(t):
    """
    Return a form of ``t`` that is safe to display.

    :param t: any value.
    :returns: the literal ``"<ELIDED>"`` for a ``Secret``, the result of
        ``t.elide()`` for an ``Elidable``, and ``t`` itself (unchanged,
        not converted to a string) for anything else.
    """
    if isinstance(t, Secret):
        return "<ELIDED>"
    if isinstance(t, Elidable):
        return t.elide()
    return t
class Elidable(object):
    """
    A value assembled from several parts, some of which may need to be
    hidden when displayed. ``str()`` gives the full text; ``elide()``
    gives the text with each part passed through ``elide`` first.
    """

    def __init__(self, *parts):
        """:param parts: the pieces to concatenate, in order."""
        self.parts = parts

    def elide(self):
        """
        Return the concatenation of the elided form of every part.

        ``elide`` hands back non-secret parts unchanged, so each piece
        is converted with ``str`` before joining -- matching
        ``__str__`` -- instead of failing on non-string parts.
        """
        return "".join(str(elide(p)) for p in self.parts)

    def __str__(self):
        """Return the concatenation of ``str(part)`` for every part."""
        return "".join(str(p) for p in self.parts)
class execution(object):
    """
    Logging helpers for a task invocation.

    Both methods forward to ``self.task.logger``; the ``task``
    attribute is expected to be assigned elsewhere (no constructor is
    visible in this view of the class).
    """

    def log(self, *args, **kwargs):
        """Forward to ``log`` on the task's logger (level, msg, ...)."""
        self.task.logger.log(*args, **kwargs)

    def info(self, *args, **kwargs):
        """Forward to ``info`` on the task's logger."""
        self.task.logger.info(*args, **kwargs)
def gather(sequence):
    """
    Resolve a sequence of asynchronously executed tasks.

    This is a generator: each item of ``sequence`` that is a ``Result``
    is replaced by the value of its ``get()``; every other item is
    yielded unchanged. Items are processed lazily, in order.
    """
    for obj in sequence:
        yield obj.get() if isinstance(obj, Result) else obj
# Marker value: project() drops any task result that is this object
# instead of yielding it.
OMIT = Sentinel("OMIT")
def _taskify(obj):
    """
    Return ``obj`` as something with the task calling conventions.

    A ``decorator`` instance is handed back untouched; any other
    callable is wrapped in a fresh pass-through task that simply
    forwards its arguments to ``obj``.
    """
    if not isinstance(obj, decorator):
        # The inner function name is kept: the task takes its name from it.
        @task()
        def applicator(*args, **kwargs):
            return obj(*args, **kwargs)
        return applicator
    return obj
def project(task, sequence):
    """
    Apply ``task`` to every item of ``sequence`` and yield the results.

    This is a generator. On first iteration it launches the task for
    *every* item via ``go`` before collecting anything, then yields the
    results in input order, skipping any result that is ``OMIT``.

    :param task: a task, or any callable (wrapped via ``_taskify``).
    :param sequence: the items to pass, one per invocation.
    """
    runner = _taskify(task)
    launched = [runner.go(obj) for obj in sequence]
    for pending in launched:
        value = pending.get()
        if value is not OMIT:
            yield value
def cull(task, sequence):
    """
    Yield the items of ``sequence`` for which ``task`` gives a truthy
    result.

    This is a generator. On first iteration it launches the task for
    *every* item via ``go`` before collecting anything, then yields, in
    input order, each original item whose result is truthy.

    :param task: a task, or any callable (wrapped via ``_taskify``).
    :param sequence: the items to test.
    """
    runner = _taskify(task)
    launched = [(runner.go(obj), obj) for obj in sequence]
    for pending, obj in launched:
        if pending.get():
            yield obj
## common tasks
from eventlet.green.subprocess import Popen, STDOUT, PIPE
class SHResult(object):
    """
    The outcome of a shell command: the command as displayed, its exit
    code and its captured output.
    """

    def __init__(self, command, code, output):
        """
        :param command: the command string.
        :param code: the process exit code.
        :param output: the captured output text.
        """
        self.command = command
        self.code = code
        self.output = output

    def __str__(self):
        """
        Return the output for a zero exit code; otherwise return
        ``[exit N]``, followed by ``: <output>`` when there is output.
        """
        if self.code == 0:
            return self.output
        code = "[exit %s]" % self.code
        if self.output:
            return "%s: %s" % (code, self.output)
        return code
@task("CMD")
def sh(*args, **kwargs):
    """
    Run a command, streaming its output to ``task.info``.

    Positional arguments form the command line (each converted with
    ``str``). Stderr is merged into stdout.

    Keyword arguments consumed here:

    :param output_transform: callable applied to each output line
        (without its newline) before it is reported; identity by default.
    :param expected: collection of acceptable exit codes; ``(0,)`` by
        default.
    :param output_buffer: number of lines buffered before they are
        flushed to ``task.info``; 10 by default. Buffered lines are also
        flushed once more than a second has passed since the last flush.

    All remaining keyword arguments (including ``cwd`` and ``env``) are
    passed to ``Popen``. ``cwd`` and ``env`` are additionally used to
    build the displayed command string.

    :returns: an ``SHResult`` when the exit code is in ``expected``.
    :raises TaskError: when the command cannot be started (``OSError``)
        or exits with an unexpected code.
    """
    output_transform = kwargs.pop("output_transform", lambda l: l)
    expected = kwargs.pop("expected", (0,))
    output_buffer = kwargs.pop("output_buffer", 10)
    cmd = tuple(str(a) for a in args)

    # Build the human-readable form of the command. Secrets are elided
    # here only; the real values still go to Popen via cmd/kwargs.
    parts = []
    cwd = kwargs.get("cwd")
    if cwd is not None and not os.path.samefile(cwd, os.getcwd()):
        relcwd = os.path.relpath(cwd)
        abscwd = os.path.abspath(cwd)
        # Show whichever spelling of the directory is shorter.
        parts.append("[%s]" % (relcwd if len(relcwd) < len(abscwd) else abscwd))
    env = kwargs.get("env")
    if env is not None:
        for k, v in env.items():
            # Only show variables that differ from the current environment.
            if v != os.environ.get(k, None):
                parts.append("%s=%s" % (k, elide(v)))
    parts.extend(str(elide(a)) for a in args)
    command = " ".join(parts)

    try:
        p = Popen(cmd, stderr=STDOUT, stdout=PIPE, **kwargs)
        captured = []
        pending = [command]
        last_flush = time.time()
        for line in p.stdout:
            captured.append(line)
            # Strip the newline only when there is one; the final line
            # of the output may not be newline-terminated.
            text = line[:-1] if line.endswith("\n") else line
            pending.append(output_transform(text))
            waited = time.time() - last_flush
            if (len(pending) > output_buffer) or (waited > 1.0):
                for entry in pending:
                    task.info(entry)
                del pending[:]
                # Restart the clock only after an actual flush.
                last_flush = time.time()
        for entry in pending:
            task.info(entry)
        p.wait()
        result = SHResult(command, p.returncode, "".join(captured))
    except OSError as e:
        raise TaskError("error executing command '%s': %s" % (command, e))
    if p.returncode in expected:
        return result
    else:
        raise TaskError("command '%s' failed[%s]: %s" % (command, result.code, result.output))
requests = eventlet.import_patched('requests.__init__') # the .__init__ is a workaround for: https://github.com/eventlet/eventlet/issues/208
def json_patch(response, parser):
    """
    Wrap a JSON-parsing callable so that parse failures show the raw
    response body.

    :param response: object whose ``content`` is echoed on failure.
    :param parser: zero-argument callable that performs the parsing.
    :returns: a zero-argument callable that returns ``parser()``; if
        that raises ``ValueError`` it echoes a notice and
        ``response.content`` via ``task.echo`` and re-raises.
    """
    def patched():
        try:
            return parser()
        except ValueError:
            task.echo("== response could not be parsed as JSON ==")
            task.echo(response.content)
            raise
    return patched
@task("GET")
def get(url, **kwargs):
    """
    Perform an HTTP GET as a task.

    :param url: the URL to fetch (converted with ``str``).
    :param kwargs: passed through to ``requests.get``.
    :returns: the response, with its ``json`` method wrapped by
        ``json_patch`` so unparsable bodies are echoed.
    :raises TaskError: when ``requests`` raises a ``RequestException``.
    """
    task.info("GET %s" % url)
    try:
        response = requests.get(str(url), **kwargs)
    except requests.RequestException as e:
        raise TaskError(e)
    response.json = json_patch(response, response.json)
    return response
import watchdog, watchdog.events
class _Wrapper(watchdog.events.FileSystemEventHandler):
    """
    Event handler that borrows its ``on_*`` callbacks from an arbitrary
    ``action`` object and dispatches each event inside a task.
    """

    def __init__(self, action):
        """
        Copy every handler hook that ``action`` provides (and that is
        truthy) onto this instance; hooks it lacks keep the base-class
        behaviour.
        """
        hooks = ("on_any_event", "on_created", "on_deleted", "on_modified", "on_moved")
        for hook in hooks:
            callback = getattr(action, hook, None)
            if not callback:
                continue
            setattr(self, hook, callback)

    @task()
    def dispatch(self, event):
        """Hand ``event`` to the base-class dispatch, as a task."""
        watchdog.events.FileSystemEventHandler.dispatch(self, event)
@task()
def watch(paths, action):
    """
    Start watching ``paths`` recursively, routing file system events to
    ``action``.

    :param paths: iterable of paths to schedule for watching.
    :param action: object providing any of the ``on_any_event``,
        ``on_created``, ``on_deleted``, ``on_modified`` and ``on_moved``
        callbacks.

    The observer is started and this function returns without waiting
    on it.
    """
    # ``watchdog.observers`` is a submodule that the module-level
    # ``import watchdog, watchdog.events`` does not load, so import it
    # explicitly here.
    from watchdog.observers import Observer
    handler = _Wrapper(action)
    obs = Observer()
    for path in paths:
        obs.schedule(handler, path, recursive=True)
    obs.start()