Source code for nesta.core.luigihacks.parameter

"""
parameter
=========

Heavily based on :py:mod:`luigi.parameter`. This package
extends the :py:class:`luigi.DictParameter` to allow dict values
to include :py:class:`luigi.Task`.
"""

import luigi
from luigi.parameter import _DictParamEncoder
import json
from datetime import datetime, date

class _DictParamEncoderPlus(_DictParamEncoder):
    """
    JSON encoder for :py:class:`~DictParameterPlus`, which makes
    :py:class:`luigi.Task` instances (and ``datetime``/``date`` values)
    JSON serializable.
    """
    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj: An object the standard JSON encoder could not serialize.

        Returns:
            The parent encoder's result if it can handle ``obj``. Otherwise,
            the task family name for a :py:class:`luigi.Task` instance, or
            the ISO-8601 string for a ``datetime``/``date``.

        Raises:
            TypeError: If ``obj`` is not of any supported type.
        """
        try:
            return super().default(obj)
        except TypeError:
            if isinstance(obj, luigi.Task):
                return obj.get_task_family()
            if isinstance(obj, (datetime, date)):
                return obj.isoformat()
            # Unsupported type: propagate the original TypeError, as the
            # json.JSONEncoder.default contract requires. Returning None here
            # would make json silently emit ``null`` for the value.
            raise

class DictParameterPlus(luigi.DictParameter):
    """
    Parameter whose value is a ``dict`` and whose values may include
    a :py:class:`luigi.Task`.

    Serialization goes through ``self.encoder``. By default this is
    :py:class:`_DictParamEncoderPlus`, which encodes tasks as their task
    family name and ``datetime``/``date`` values as ISO-8601 strings.

    Args:
        *args: Positional arguments forwarded to
            :py:class:`luigi.DictParameter` (the first one is ``default``).
        encoder: :py:class:`json.JSONEncoder` subclass used by
            :py:meth:`serialize`.
        **kwargs: Keyword arguments forwarded to
            :py:class:`luigi.DictParameter`.
    """
    def __init__(self, *args, encoder=_DictParamEncoderPlus, **kwargs):
        # ``encoder`` used to be the first positional parameter. That
        # silently swallowed a positional ``default`` (luigi's first
        # positional argument). For backward compatibility, a JSONEncoder
        # subclass passed first is still taken as the encoder: such a class
        # can never be a valid dict default.
        if (args and isinstance(args[0], type)
                and issubclass(args[0], json.JSONEncoder)):
            encoder, args = args[0], args[1:]
        super().__init__(*args, **kwargs)
        self.encoder = encoder

    def serialize(self, x):
        """Serialize the dict ``x`` to a JSON string using ``self.encoder``."""
        return json.dumps(x, cls=self.encoder)