from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from opentelemetry import context as otel_context
from opentelemetry.propagators.textmap import TextMapPropagator
from typing import Any

# Public API of this module.
__all__ = [
    'get_context',
    'attach_context',
    'ContextCarrier',
]

# Type of the carrier returned by `get_context` and accepted by `attach_context`:
# a read-only, string-keyed mapping.
ContextCarrier = Mapping[str, Any]

def get_context() -> ContextCarrier:
    """Create a fresh, empty carrier dict, inject the context into it, and return it.

    Returns:
        A newly created dict holding the injected context.

    Usage:

    ```py
    import logfire

    logfire_context = logfire.get_context()

    ...

    # later on in another thread, process or service
    with logfire.attach_context(logfire_context):
        ...
    ```

    To inject the context into a mapping you already have, such as headers, merge it in:

    ```py
    import logfire

    existing_headers = {'X-Foobar': 'baz'}
    existing_headers.update(logfire.get_context())
    ...
    ```
    """
    ...
@contextmanager
def attach_context(carrier: ContextCarrier, *, third_party: bool = False, propagator: TextMapPropagator | None = None) -> Iterator[None]:
    """Attach a context as generated by [`get_context`][logfire.propagate.get_context] to the current execution context.

    Since `attach_context` is a context manager, it restores the previous context when exiting.

    Set `third_party` to `True` if using this inside a library intended to be used by others.
    This will respect the [`distributed_tracing` argument of `logfire.configure()`][logfire.configure(distributed_tracing)],
    so users will be warned about unintentional distributed tracing by default and they can suppress it.
    See [Unintentional Distributed Tracing](https://logfire.pydantic.dev/docs/how-to-guides/distributed-tracing/#unintentional-distributed-tracing) for more information.

    Args:
        carrier: The mapping holding the context to attach, typically the value returned by
            [`get_context`][logfire.propagate.get_context].
        third_party: Whether this is being called from library code intended to be used by
            others (see above). Defaults to `False`.
        propagator: An explicit `TextMapPropagator` to extract the context from `carrier` with.
            Defaults to `None`.

    Yields:
        Nothing; the attached context is active for the body of the `with` block.
    """
    # NOTE(review): this stub does not show what happens when `propagator` is None,
    # or whether `third_party` still applies when a propagator is passed explicitly.
    # Confirm both against the implementation and document them above.

@dataclass
class WrapperPropagator(TextMapPropagator):
    """Helper base class to wrap another propagator.

    Subclasses override `extract` to post-process whatever context the wrapped
    propagator extracts.
    """
    # The propagator being wrapped.
    wrapped: TextMapPropagator
    def extract(self, *args: Any, **kwargs: Any) -> otel_context.Context: ...
    # Annotated to match `TextMapPropagator.inject`, which returns None.
    def inject(self, *args: Any, **kwargs: Any) -> None: ...
    # Annotated to match `TextMapPropagator.fields`, a set of carrier keys.
    @property
    def fields(self) -> set[str]: ...

class NoExtractTraceContextPropagator(WrapperPropagator):
    """Propagator that discards any trace context the wrapped propagator extracts.

    This is the propagator used when `logfire.configure(distributed_tracing=False)` is called.
    """
    def extract(
        self,
        carrier: Any,
        context: otel_context.Context | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> otel_context.Context:
        """Extract from `carrier`, ignoring any trace context found by the wrapped propagator."""

@dataclass
class WarnOnExtractTraceContextPropagator(WrapperPropagator):
    """Propagator that issues a warning the first time the wrapped propagator extracts trace context.

    This is the default: it is used when `logfire.configure(distributed_tracing=None)` is called.
    """
    # Tracks whether the one-time warning described above has already been issued.
    warned: bool = ...
    def extract(
        self,
        carrier: Any,
        context: otel_context.Context | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> otel_context.Context:
        """Extract from `carrier`, warning on the first trace context seen from the wrapped propagator."""
